import torch


class BaseModel(torch.nn.Module):
    """Base class for all models.

    Subclasses implement :meth:`forward`, mapping an input tensor of shape
    ``(B, in_channels, N, N, N)`` and a style tensor of shape
    ``(B, style_parameters)`` to an output of shape
    ``(B, out_channels, N, N, N)``.
    """

    def __init__(
        self,
        N: int = 128,
        in_channels: int = 2,
        out_channels: int = 1,
        style_parameters: int = 2,
        device: torch.device = torch.device('cpu'),
    ) -> None:
        """Store the model configuration and move the module to ``device``.

        Parameters:
        - N: Spatial size of the input data: data will have the shape
          (B, C, N, N, N) with B the batch size and C the number of channels.
        - in_channels: Number of input channels (default is 2).
        - out_channels: Number of output channels (default is 1).
        - style_parameters: Number of style parameters (default is 2).
        - device: Target device for the model (default is CPU).

        Note: the ``self.to(self.device)`` call below runs before a subclass
        has created its layers, so it only moves parameters/buffers that are
        registered at that point (this base class registers none). Subclasses
        should call ``self.to(self.device)`` again once their layers exist.
        ``self.device`` is a plain attribute: later ``.to()``/``.cuda()``
        calls on the module do not update it.
        """
        super().__init__()
        self.N = N
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.style_parameters = style_parameters
        self.device = device
        # Effective only for parameters already registered (see note above).
        self.to(self.device)

    def forward(self, x: torch.Tensor, style: torch.Tensor) -> torch.Tensor:
        """Forward pass of the model. Must be implemented in subclasses.

        Parameters:
        - x: Input tensor of shape (B, C, N, N, N) where B is the batch size
          and C is the number of input channels.
        - style: Style parameters tensor of shape (B, S) where S is the number
          of style parameters.

        Returns:
        - Output tensor of shape (B, C_out, N, N, N) where C_out is the number
          of output channels.

        Raises:
        - NotImplementedError: always, in this base class.
        """
        raise NotImplementedError("Forward method must be implemented in subclasses.")