Fixes and cleaning up

This commit is contained in:
Yin Li 2021-03-18 14:16:17 -04:00
parent 55b1a72ef4
commit 8544ff07e8
5 changed files with 10 additions and 4 deletions

View File

@ -20,6 +20,8 @@ class FieldDataset(Dataset):
`in_norms` is a list of of functions to normalize the input fields. `in_norms` is a list of of functions to normalize the input fields.
Likewise for `tgt_norms`. Likewise for `tgt_norms`.
NOTE that vector fields are assumed if numbers of channels and dimensions are equal.
Scalar and vector fields can be augmented by flipping and permutating the axes. Scalar and vector fields can be augmented by flipping and permutating the axes.
In 3D these form the full octahedral symmetry, the Oh group of order 48. In 3D these form the full octahedral symmetry, the Oh group of order 48.
In 2D this is the dihedral group D4 of order 8. In 2D this is the dihedral group D4 of order 8.

View File

@ -3,7 +3,6 @@ import torch
import torch.nn as nn import torch.nn as nn
from .narrow import narrow_like from .narrow import narrow_like
from .swish import Swish
class ConvBlock(nn.Module): class ConvBlock(nn.Module):

View File

@ -35,7 +35,10 @@ class VNet(nn.Module):
self.up_r0 = ConvBlock(64, seq='UBA') self.up_r0 = ConvBlock(64, seq='UBA')
self.conv_r0 = ResBlock(128, out_chan, seq='CAC') self.conv_r0 = ResBlock(128, out_chan, seq='CAC')
self.bypass = in_chan == out_chan if bypass is None:
self.bypass = in_chan == out_chan
else:
self.bypass = bypass
def forward(self, x): def forward(self, x):
if self.bypass: if self.bypass:

View File

@ -5,6 +5,7 @@ import torch
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from .data import FieldDataset from .data import FieldDataset
from .data import norms
from . import models from . import models
from .models import narrow_cast from .models import narrow_cast
from .utils import import_attr, load_model_state_dict from .utils import import_attr, load_model_state_dict
@ -68,11 +69,13 @@ def test(args):
if args.in_norms is not None: if args.in_norms is not None:
start = 0 start = 0
for norm, stop in zip(test_dataset.in_norms, np.cumsum(in_chan)): for norm, stop in zip(test_dataset.in_norms, np.cumsum(in_chan)):
norm = import_attr(norm, norms, callback_at=args.callback_at)
norm(input[:, start:stop], undo=True) norm(input[:, start:stop], undo=True)
start = stop start = stop
if args.tgt_norms is not None: if args.tgt_norms is not None:
start = 0 start = 0
for norm, stop in zip(test_dataset.tgt_norms, np.cumsum(out_chan)): for norm, stop in zip(test_dataset.tgt_norms, np.cumsum(out_chan)):
norm = import_attr(norm, norms, callback_at=args.callback_at)
norm(output[:, start:stop], undo=True) norm(output[:, start:stop], undo=True)
norm(target[:, start:stop], undo=True) norm(target[:, start:stop], undo=True)
start = stop start = stop

View File

@ -397,8 +397,7 @@ def dist_init(rank, args):
with open(dist_file, mode='w') as f: with open(dist_file, mode='w') as f:
f.write(args.dist_addr) f.write(args.dist_addr)
else:
if rank != 0:
while not os.path.exists(dist_file): while not os.path.exists(dist_file):
time.sleep(1) time.sleep(1)