xavier init + skip connection
This commit is contained in:
parent
c0b1f656ce
commit
6a97fd27ec
1 changed files with 4 additions and 5 deletions
|
@ -33,10 +33,9 @@ class FourierSpaceBlock3D(nn.Module):
|
|||
self.out_channels = out_channels
|
||||
self.film = FiLM(out_channels, style_dim) if style_dim else None
|
||||
self.filtering = filtering if filtering is not None else (None, None, None)
|
||||
self.scale = (1 / (in_channels * out_channels))
|
||||
self.weights = nn.Parameter(
|
||||
self.scale * torch.randn(in_channels, out_channels, self.filtering[0] if self.filtering[0] else 1, self.filtering[1] if self.filtering[1] else 1, self.filtering[2] if self.filtering[2] else 1, 2)
|
||||
)
|
||||
weights = torch.empty(in_channels, out_channels, self.filtering[0] if self.filtering[0] else 1, self.filtering[1] if self.filtering[1] else 1, self.filtering[2] if self.filtering[2] else 1, 2)
|
||||
nn.init.xavier_uniform_(weights)
|
||||
self.weights = nn.Parameter(weights)
|
||||
|
||||
def compl_mul3d(self, input, weights):
|
||||
# input: (B, I, X, Y, Z), weights: (I, O, X, Y, Z)
|
||||
|
@ -191,6 +190,6 @@ class FNO3D(BaseModel):
|
|||
"""
|
||||
x = self.embedding(x)
|
||||
for block in self.fno_blocks:
|
||||
x = block(x, style)
|
||||
x = x + block(x, style)
|
||||
x = self.final_conv(x)
|
||||
return x
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue