Fixes and cleaning up
This commit is contained in:
parent
55b1a72ef4
commit
8544ff07e8
@ -20,6 +20,8 @@ class FieldDataset(Dataset):
|
||||
`in_norms` is a list of of functions to normalize the input fields.
|
||||
Likewise for `tgt_norms`.
|
||||
|
||||
NOTE that vector fields are assumed if numbers of channels and dimensions are equal.
|
||||
|
||||
Scalar and vector fields can be augmented by flipping and permutating the axes.
|
||||
In 3D these form the full octahedral symmetry, the Oh group of order 48.
|
||||
In 2D this is the dihedral group D4 of order 8.
|
||||
|
@ -3,7 +3,6 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .narrow import narrow_like
|
||||
from .swish import Swish
|
||||
|
||||
|
||||
class ConvBlock(nn.Module):
|
||||
|
@ -35,7 +35,10 @@ class VNet(nn.Module):
|
||||
self.up_r0 = ConvBlock(64, seq='UBA')
|
||||
self.conv_r0 = ResBlock(128, out_chan, seq='CAC')
|
||||
|
||||
self.bypass = in_chan == out_chan
|
||||
if bypass is None:
|
||||
self.bypass = in_chan == out_chan
|
||||
else:
|
||||
self.bypass = bypass
|
||||
|
||||
def forward(self, x):
|
||||
if self.bypass:
|
||||
|
@ -5,6 +5,7 @@ import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from .data import FieldDataset
|
||||
from .data import norms
|
||||
from . import models
|
||||
from .models import narrow_cast
|
||||
from .utils import import_attr, load_model_state_dict
|
||||
@ -68,11 +69,13 @@ def test(args):
|
||||
if args.in_norms is not None:
|
||||
start = 0
|
||||
for norm, stop in zip(test_dataset.in_norms, np.cumsum(in_chan)):
|
||||
norm = import_attr(norm, norms, callback_at=args.callback_at)
|
||||
norm(input[:, start:stop], undo=True)
|
||||
start = stop
|
||||
if args.tgt_norms is not None:
|
||||
start = 0
|
||||
for norm, stop in zip(test_dataset.tgt_norms, np.cumsum(out_chan)):
|
||||
norm = import_attr(norm, norms, callback_at=args.callback_at)
|
||||
norm(output[:, start:stop], undo=True)
|
||||
norm(target[:, start:stop], undo=True)
|
||||
start = stop
|
||||
|
@ -397,8 +397,7 @@ def dist_init(rank, args):
|
||||
|
||||
with open(dist_file, mode='w') as f:
|
||||
f.write(args.dist_addr)
|
||||
|
||||
if rank != 0:
|
||||
else:
|
||||
while not os.path.exists(dist_file):
|
||||
time.sleep(1)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user