"""FiLM: Feature-wise Linear Modulation.

Conditions feature maps on a style vector by applying a learned affine
transform per feature channel, ``gamma * x + beta``, where ``gamma`` and
``beta`` are predicted from the style vector by linear layers.
"""

import torch.nn as nn


class FiLM(nn.Module):
    """Feature-wise linear modulation of a feature tensor by a style vector.

    Args:
        num_features: Number of feature channels to modulate; this is the
            output size of the ``gamma`` and ``beta`` linear layers.
        style_dim: Size of the last dimension of the style tensor.
        spatial_dims: Number of trailing singleton dimensions appended to
            ``gamma``/``beta`` so they broadcast over the spatial axes of
            ``x``. Defaults to 3 (the historical behaviour), which suits
            inputs laid out as ``(batch, channels, D, H, W)``. Pass ``None``
            to infer it on every call as ``x.dim() - 2``; this assumes ``x``
            is laid out as ``(batch, channels, *spatial)``.

    Raises:
        ValueError: If ``spatial_dims`` is negative.
    """

    def __init__(self, num_features, style_dim, spatial_dims=3):
        super().__init__()
        if spatial_dims is not None and spatial_dims < 0:
            raise ValueError(
                f"spatial_dims must be a non-negative int or None, got {spatial_dims!r}"
            )
        self.spatial_dims = spatial_dims
        # Linear heads that predict the per-channel scale (gamma) and shift (beta).
        self.gamma = nn.Linear(style_dim, num_features)
        self.beta = nn.Linear(style_dim, num_features)

    def _num_trailing_dims(self, x):
        """Return how many singleton dims to append to gamma/beta for ``x``."""
        if self.spatial_dims is not None:
            return self.spatial_dims
        # Inferred mode: everything after (batch, channels) is spatial.
        n = x.dim() - 2
        if n < 0:
            raise ValueError(
                "cannot infer spatial dims: expected x with at least 2 dims "
                f"(batch, channels, ...), got {x.dim()}"
            )
        return n

    def forward(self, x, style):
        """Modulate ``x`` with scale and shift predicted from ``style``.

        Args:
            x: Feature tensor to modulate.
            style: Style tensor whose last dimension has size ``style_dim``.

        Returns:
            ``gamma * x + beta``, where ``gamma`` and ``beta`` have shape
            ``(*style.shape[:-1], num_features)`` followed by the configured
            number of singleton dims, so they broadcast over ``x``.
        """
        gamma = self.gamma(style)
        beta = self.beta(style)
        # One reshape replaces a fixed chain of unsqueeze(-1) calls, so the
        # number of broadcast dims can follow the input's rank.
        shape = tuple(gamma.shape) + (1,) * self._num_trailing_dims(x)
        return gamma.reshape(shape) * x + beta.reshape(shape)