38 lines
1.5 KiB
Python
38 lines
1.5 KiB
Python
import torch
|
|
|
|
|
|
class BaseModel(torch.nn.Module):
    """Base class for all models operating on volumes of shape (B, C, N, N, N).

    It stores the model's size/channel/style configuration and the target
    device. Subclasses must override :meth:`forward`; the base implementation
    raises ``NotImplementedError``.
    """

    def __init__(self,
                 N: int = 128,
                 in_channels: int = 2,
                 out_channels: int = 1,
                 style_parameters: int = 2,
                 device: torch.device = torch.device('cpu')) -> None:
        """
        Base class for all models.

        Parameters:
        - N: Size of the input data: data will have the shape (B, C, N, N, N)
          with B the batch size, C the number of channels.
        - in_channels: Number of input channels (default is 2).
        - out_channels: Number of output channels (default is 1).
        - style_parameters: Number of style parameters (default is 2).
        - device: Device the model is intended to run on (default is CPU).
          It is stored as ``self.device``.
        """
        super().__init__()
        self.N = N
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.style_parameters = style_parameters
        self.device = device
        # BaseModel itself registers no parameters or buffers, so this call
        # moves nothing at construction time. Layers that a subclass creates
        # after super().__init__() are not affected by it. NOTE(review):
        # such subclasses presumably need to call self.to(self.device) once
        # their layers exist; confirm against subclass implementations.
        self.to(self.device)

    def forward(self, x: torch.Tensor, style: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the model.

        Should be implemented in subclasses.

        Parameters:
        - x: Input tensor of shape (B, C, N, N, N) where B is the batch size,
          C is the number of channels.
        - style: Style parameters tensor of shape (B, S) where S is the
          number of style parameters.

        Returns:
        - Output tensor of shape (B, C_out, N, N, N) where C_out is the
          number of output channels.

        Raises:
        - NotImplementedError: always, in this base class.
        """
        raise NotImplementedError("Forward method must be implemented in subclasses.")
|