From 8544ff07e8eb11ec6a38b2006fff145e2a3dbd6c Mon Sep 17 00:00:00 2001 From: Yin Li Date: Thu, 18 Mar 2021 14:16:17 -0400 Subject: [PATCH] Fixes and cleaning up --- map2map/data/fields.py | 2 ++ map2map/models/conv.py | 1 - map2map/models/vnet.py | 5 ++++- map2map/test.py | 3 +++ map2map/train.py | 3 +-- 5 files changed, 10 insertions(+), 4 deletions(-) diff --git a/map2map/data/fields.py b/map2map/data/fields.py index 798a67b..e80a0f9 100644 --- a/map2map/data/fields.py +++ b/map2map/data/fields.py @@ -20,6 +20,8 @@ class FieldDataset(Dataset): `in_norms` is a list of of functions to normalize the input fields. Likewise for `tgt_norms`. + NOTE that vector fields are assumed if numbers of channels and dimensions are equal. + Scalar and vector fields can be augmented by flipping and permutating the axes. In 3D these form the full octahedral symmetry, the Oh group of order 48. In 2D this is the dihedral group D4 of order 8. diff --git a/map2map/models/conv.py b/map2map/models/conv.py index 08d1f6f..7d8fc9f 100644 --- a/map2map/models/conv.py +++ b/map2map/models/conv.py @@ -3,7 +3,6 @@ import torch import torch.nn as nn from .narrow import narrow_like -from .swish import Swish class ConvBlock(nn.Module): diff --git a/map2map/models/vnet.py b/map2map/models/vnet.py index f862da8..a2f1fee 100644 --- a/map2map/models/vnet.py +++ b/map2map/models/vnet.py @@ -35,7 +35,10 @@ class VNet(nn.Module): self.up_r0 = ConvBlock(64, seq='UBA') self.conv_r0 = ResBlock(128, out_chan, seq='CAC') - self.bypass = in_chan == out_chan + if bypass is None: + self.bypass = in_chan == out_chan + else: + self.bypass = bypass def forward(self, x): if self.bypass: diff --git a/map2map/test.py b/map2map/test.py index d2b86a1..0eb5149 100644 --- a/map2map/test.py +++ b/map2map/test.py @@ -5,6 +5,7 @@ import torch from torch.utils.data import DataLoader from .data import FieldDataset +from .data import norms from . import models from .models import narrow_cast from .utils import import_attr, load_model_state_dict @@ -68,11 +69,13 @@ def test(args): if args.in_norms is not None: start = 0 for norm, stop in zip(test_dataset.in_norms, np.cumsum(in_chan)): + norm = import_attr(norm, norms, callback_at=args.callback_at) norm(input[:, start:stop], undo=True) start = stop if args.tgt_norms is not None: start = 0 for norm, stop in zip(test_dataset.tgt_norms, np.cumsum(out_chan)): + norm = import_attr(norm, norms, callback_at=args.callback_at) norm(output[:, start:stop], undo=True) norm(target[:, start:stop], undo=True) start = stop diff --git a/map2map/train.py b/map2map/train.py index cc1af8d..39788b7 100644 --- a/map2map/train.py +++ b/map2map/train.py @@ -397,8 +397,7 @@ def dist_init(rank, args): with open(dist_file, mode='w') as f: f.write(args.dist_addr) - - if rank != 0: + else: while not os.path.exists(dist_file): time.sleep(1)