ChatGPT models

This commit is contained in:
Mayeul Aubin 2025-06-05 17:30:26 +02:00
parent 2b9830211e
commit 26af105195
4 changed files with 193 additions and 0 deletions

View file

@ -0,0 +1,38 @@
import torch
class BaseModel(torch.nn.Module):
def __init__(self,
N:int=128,
in_channels:int=2,
out_channels:int=1,
style_parameters:int=2,
device: torch.device = torch.device('cpu')):
"""
Base class for all models.
Parameters:
- N: Size of the input data: data will have the shape (B, C, N,N,N) with B the batch size, C the number of channels.
- in_channels: Number of input channels (default is 2).
- out_channels: Number of output channels (default is 1).
- style_parameters: Number of style parameters (default is 2).
- device: Device to load the model onto (default is CPU).
"""
super().__init__()
self.N = N
self.in_channels = in_channels
self.out_channels = out_channels
self.style_parameters = style_parameters
self.device = device
self.to(self.device)
def forward(self, x, style):
"""
Forward pass of the model.
Should be implemented in subclasses.
Parameters:
- x: Input tensor of shape (B, C, N, N, N) where B is the batch size, C is the number of channels.
- style: Style parameters tensor of shape (B, S) where S is the number of style parameters.
Returns:
- Output tensor of shape (B, C_out, N, N, N) where C_out is the number of output channels.
"""
raise NotImplementedError("Forward method must be implemented in subclasses.")