Fixes and cleaning up

This commit is contained in:
Yin Li 2021-03-18 14:16:17 -04:00
parent 55b1a72ef4
commit 8544ff07e8
5 changed files with 10 additions and 4 deletions

View file

@ -20,6 +20,8 @@ class FieldDataset(Dataset):
`in_norms` is a list of of functions to normalize the input fields.
Likewise for `tgt_norms`.
NOTE that vector fields are assumed if numbers of channels and dimensions are equal.
Scalar and vector fields can be augmented by flipping and permutating the axes.
In 3D these form the full octahedral symmetry, the Oh group of order 48.
In 2D this is the dihedral group D4 of order 8.

View file

@ -3,7 +3,6 @@ import torch
import torch.nn as nn
from .narrow import narrow_like
from .swish import Swish
class ConvBlock(nn.Module):

View file

@ -35,7 +35,10 @@ class VNet(nn.Module):
self.up_r0 = ConvBlock(64, seq='UBA')
self.conv_r0 = ResBlock(128, out_chan, seq='CAC')
self.bypass = in_chan == out_chan
if bypass is None:
self.bypass = in_chan == out_chan
else:
self.bypass = bypass
def forward(self, x):
if self.bypass:

View file

@ -5,6 +5,7 @@ import torch
from torch.utils.data import DataLoader
from .data import FieldDataset
from .data import norms
from . import models
from .models import narrow_cast
from .utils import import_attr, load_model_state_dict
@ -68,11 +69,13 @@ def test(args):
if args.in_norms is not None:
start = 0
for norm, stop in zip(test_dataset.in_norms, np.cumsum(in_chan)):
norm = import_attr(norm, norms, callback_at=args.callback_at)
norm(input[:, start:stop], undo=True)
start = stop
if args.tgt_norms is not None:
start = 0
for norm, stop in zip(test_dataset.tgt_norms, np.cumsum(out_chan)):
norm = import_attr(norm, norms, callback_at=args.callback_at)
norm(output[:, start:stop], undo=True)
norm(target[:, start:stop], undo=True)
start = stop

View file

@ -397,8 +397,7 @@ def dist_init(rank, args):
with open(dist_file, mode='w') as f:
f.write(args.dist_addr)
if rank != 0:
else:
while not os.path.exists(dist_file):
time.sleep(1)