xavier init + skip connection

This commit is contained in:
Mayeul Aubin 2025-06-25 15:48:49 +02:00
parent c0b1f656ce
commit 6a97fd27ec

View file

@ -33,10 +33,9 @@ class FourierSpaceBlock3D(nn.Module):
self.out_channels = out_channels
self.film = FiLM(out_channels, style_dim) if style_dim else None
self.filtering = filtering if filtering is not None else (None, None, None)
self.scale = (1 / (in_channels * out_channels))
self.weights = nn.Parameter(
self.scale * torch.randn(in_channels, out_channels, self.filtering[0] if self.filtering[0] else 1, self.filtering[1] if self.filtering[1] else 1, self.filtering[2] if self.filtering[2] else 1, 2)
)
weights = torch.empty(in_channels, out_channels, self.filtering[0] if self.filtering[0] else 1, self.filtering[1] if self.filtering[1] else 1, self.filtering[2] if self.filtering[2] else 1, 2)
nn.init.xavier_uniform_(weights)
self.weights = nn.Parameter(weights)
def compl_mul3d(self, input, weights):
# input: (B, I, X, Y, Z), weights: (I, O, X, Y, Z)
@ -191,6 +190,6 @@ class FNO3D(BaseModel):
"""
x = self.embedding(x)
for block in self.fno_blocks:
x = block(x, style)
x = x + block(x, style)
x = self.final_conv(x)
return x