ML_GravPotBCs/sCOCA_ML/models/FiLM.py
2025-06-25 09:35:46 +02:00

16 lines
569 B
Python

"""
FiLM: Feature-wise Linear Modulation
Provides a FiLM layer that conditions feature maps on style parameters via a per-channel scale and shift.
"""
import torch.nn as nn
class FiLM(nn.Module):
    """Feature-wise Linear Modulation layer.

    Two independent linear layers map a style vector to a per-feature scale
    ``gamma`` and shift ``beta``; the layer returns ``gamma * x + beta``.
    The scale is applied as-is (there is no ``1 + gamma`` offset).

    Args:
        num_features: Number of features (channels) of ``x`` to modulate;
            the output size of both linear layers.
        style_dim: Size of the last dimension of ``style``.
    """

    def __init__(self, num_features, style_dim):
        super().__init__()
        # Attribute names are kept as ``gamma``/``beta`` so existing
        # checkpoints (state_dict keys) still load.
        self.gamma = nn.Linear(style_dim, num_features)
        self.beta = nn.Linear(style_dim, num_features)

    def forward(self, x, style):
        """Modulate ``x`` with a scale and shift predicted from ``style``.

        Args:
            x: Channel-first feature tensor, e.g. ``(B, num_features, D, H, W)``.
                Any number of trailing spatial dimensions is supported
                (including none, i.e. ``(B, num_features)``).
            style: Conditioning tensor whose last dimension is ``style_dim``;
                presumably ``(B, style_dim)`` — a 1-D ``(style_dim,)`` vector
                is also accepted and is shared across the batch.

        Returns:
            ``gamma * x + beta``, where ``gamma``/``beta`` are broadcast over
            the spatial dimensions of ``x``.
        """
        gamma = self.gamma(style)
        beta = self.beta(style)
        # Append one singleton axis per spatial dim of ``x`` so the feature
        # axis of gamma/beta lines up with ``x``'s channel axis (dim 1).
        # For 5-D ``x`` this is exactly three axes, matching the original
        # hard-coded ``.unsqueeze(-1)`` x3.
        trailing = (1,) * max(x.dim() - 2, 0)
        gamma = gamma.reshape(*gamma.shape, *trailing)
        beta = beta.reshape(*beta.shape, *trailing)
        return gamma * x + beta