ChatGPT models
This commit is contained in:
parent
2b9830211e
commit
26af105195
4 changed files with 193 additions and 0 deletions
38
sCOCA_ML/models/base_class_models.py
Normal file
38
sCOCA_ML/models/base_class_models.py
Normal file
|
@ -0,0 +1,38 @@
|
|||
import torch
|
||||
|
||||
|
||||
class BaseModel(torch.nn.Module):
|
||||
def __init__(self,
|
||||
N:int=128,
|
||||
in_channels:int=2,
|
||||
out_channels:int=1,
|
||||
style_parameters:int=2,
|
||||
device: torch.device = torch.device('cpu')):
|
||||
"""
|
||||
Base class for all models.
|
||||
Parameters:
|
||||
- N: Size of the input data: data will have the shape (B, C, N,N,N) with B the batch size, C the number of channels.
|
||||
- in_channels: Number of input channels (default is 2).
|
||||
- out_channels: Number of output channels (default is 1).
|
||||
- style_parameters: Number of style parameters (default is 2).
|
||||
- device: Device to load the model onto (default is CPU).
|
||||
"""
|
||||
super().__init__()
|
||||
self.N = N
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.style_parameters = style_parameters
|
||||
self.device = device
|
||||
self.to(self.device)
|
||||
|
||||
def forward(self, x, style):
|
||||
"""
|
||||
Forward pass of the model.
|
||||
Should be implemented in subclasses.
|
||||
Parameters:
|
||||
- x: Input tensor of shape (B, C, N, N, N) where B is the batch size, C is the number of channels.
|
||||
- style: Style parameters tensor of shape (B, S) where S is the number of style parameters.
|
||||
Returns:
|
||||
- Output tensor of shape (B, C_out, N, N, N) where C_out is the number of output channels.
|
||||
"""
|
||||
raise NotImplementedError("Forward method must be implemented in subclasses.")
|
Loading…
Add table
Add a link
Reference in a new issue