From 6a97fd27ec09193feb7d99f508b9442da4ed32f9 Mon Sep 17 00:00:00 2001 From: Mayeul Aubin Date: Wed, 25 Jun 2025 15:48:49 +0200 Subject: [PATCH] xavier init + skip connection --- sCOCA_ML/models/FNO_models.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/sCOCA_ML/models/FNO_models.py b/sCOCA_ML/models/FNO_models.py index ff52a2f..f8234af 100644 --- a/sCOCA_ML/models/FNO_models.py +++ b/sCOCA_ML/models/FNO_models.py @@ -33,10 +33,9 @@ class FourierSpaceBlock3D(nn.Module): self.out_channels = out_channels self.film = FiLM(out_channels, style_dim) if style_dim else None self.filtering = filtering if filtering is not None else (None, None, None) - self.scale = (1 / (in_channels * out_channels)) - self.weights = nn.Parameter( - self.scale * torch.randn(in_channels, out_channels, self.filtering[0] if self.filtering[0] else 1, self.filtering[1] if self.filtering[1] else 1, self.filtering[2] if self.filtering[2] else 1, 2) - ) + weights = torch.empty(in_channels, out_channels, self.filtering[0] if self.filtering[0] else 1, self.filtering[1] if self.filtering[1] else 1, self.filtering[2] if self.filtering[2] else 1, 2) + nn.init.xavier_uniform_(weights) + self.weights = nn.Parameter(weights) def compl_mul3d(self, input, weights): # input: (B, I, X, Y, Z), weights: (I, O, X, Y, Z) @@ -191,6 +190,6 @@ class FNO3D(BaseModel): """ x = self.embedding(x) for block in self.fno_blocks: - x = block(x, style) + x = x + block(x, style) x = self.final_conv(x) return x