Fixes and cleaning up
This commit is contained in:
parent
55b1a72ef4
commit
8544ff07e8
@ -20,6 +20,8 @@ class FieldDataset(Dataset):
|
|||||||
`in_norms` is a list of of functions to normalize the input fields.
|
`in_norms` is a list of of functions to normalize the input fields.
|
||||||
Likewise for `tgt_norms`.
|
Likewise for `tgt_norms`.
|
||||||
|
|
||||||
|
NOTE that vector fields are assumed if numbers of channels and dimensions are equal.
|
||||||
|
|
||||||
Scalar and vector fields can be augmented by flipping and permutating the axes.
|
Scalar and vector fields can be augmented by flipping and permutating the axes.
|
||||||
In 3D these form the full octahedral symmetry, the Oh group of order 48.
|
In 3D these form the full octahedral symmetry, the Oh group of order 48.
|
||||||
In 2D this is the dihedral group D4 of order 8.
|
In 2D this is the dihedral group D4 of order 8.
|
||||||
|
@ -3,7 +3,6 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from .narrow import narrow_like
|
from .narrow import narrow_like
|
||||||
from .swish import Swish
|
|
||||||
|
|
||||||
|
|
||||||
class ConvBlock(nn.Module):
|
class ConvBlock(nn.Module):
|
||||||
|
@ -35,7 +35,10 @@ class VNet(nn.Module):
|
|||||||
self.up_r0 = ConvBlock(64, seq='UBA')
|
self.up_r0 = ConvBlock(64, seq='UBA')
|
||||||
self.conv_r0 = ResBlock(128, out_chan, seq='CAC')
|
self.conv_r0 = ResBlock(128, out_chan, seq='CAC')
|
||||||
|
|
||||||
self.bypass = in_chan == out_chan
|
if bypass is None:
|
||||||
|
self.bypass = in_chan == out_chan
|
||||||
|
else:
|
||||||
|
self.bypass = bypass
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
if self.bypass:
|
if self.bypass:
|
||||||
|
@ -5,6 +5,7 @@ import torch
|
|||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from .data import FieldDataset
|
from .data import FieldDataset
|
||||||
|
from .data import norms
|
||||||
from . import models
|
from . import models
|
||||||
from .models import narrow_cast
|
from .models import narrow_cast
|
||||||
from .utils import import_attr, load_model_state_dict
|
from .utils import import_attr, load_model_state_dict
|
||||||
@ -68,11 +69,13 @@ def test(args):
|
|||||||
if args.in_norms is not None:
|
if args.in_norms is not None:
|
||||||
start = 0
|
start = 0
|
||||||
for norm, stop in zip(test_dataset.in_norms, np.cumsum(in_chan)):
|
for norm, stop in zip(test_dataset.in_norms, np.cumsum(in_chan)):
|
||||||
|
norm = import_attr(norm, norms, callback_at=args.callback_at)
|
||||||
norm(input[:, start:stop], undo=True)
|
norm(input[:, start:stop], undo=True)
|
||||||
start = stop
|
start = stop
|
||||||
if args.tgt_norms is not None:
|
if args.tgt_norms is not None:
|
||||||
start = 0
|
start = 0
|
||||||
for norm, stop in zip(test_dataset.tgt_norms, np.cumsum(out_chan)):
|
for norm, stop in zip(test_dataset.tgt_norms, np.cumsum(out_chan)):
|
||||||
|
norm = import_attr(norm, norms, callback_at=args.callback_at)
|
||||||
norm(output[:, start:stop], undo=True)
|
norm(output[:, start:stop], undo=True)
|
||||||
norm(target[:, start:stop], undo=True)
|
norm(target[:, start:stop], undo=True)
|
||||||
start = stop
|
start = stop
|
||||||
|
@ -397,8 +397,7 @@ def dist_init(rank, args):
|
|||||||
|
|
||||||
with open(dist_file, mode='w') as f:
|
with open(dist_file, mode='w') as f:
|
||||||
f.write(args.dist_addr)
|
f.write(args.dist_addr)
|
||||||
|
else:
|
||||||
if rank != 0:
|
|
||||||
while not os.path.exists(dist_file):
|
while not os.path.exists(dist_file):
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user