diff --git a/sCOCA_ML/models/FiLM.py b/sCOCA_ML/models/FiLM.py new file mode 100644 index 0000000..e7056b5 --- /dev/null +++ b/sCOCA_ML/models/FiLM.py @@ -0,0 +1,16 @@ +""" +FiLM: Feature-wise Linear Modulation +For applying FiLM to condition feature maps on style parameters. +""" +import torch.nn as nn + +class FiLM(nn.Module): + def __init__(self, num_features, style_dim): + super(FiLM, self).__init__() + self.gamma = nn.Linear(style_dim, num_features) + self.beta = nn.Linear(style_dim, num_features) + + def forward(self, x, style): + gamma = self.gamma(style).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) + beta = self.beta(style).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) + return gamma * x + beta