Add separate input and target normalization
Parent: 46a3a3a97d
Commit: c68b9928ee
@@ -16,8 +16,10 @@ def get_args():
 
 
 def add_common_args(parser):
-    parser.add_argument('--norms', type=str_list, help='comma-sep. list '
-            'of normalization functions from .data.norms')
+    parser.add_argument('--in-norms', type=str_list, help='comma-sep. list '
+            'of input normalization functions from .data.norms')
+    parser.add_argument('--tgt-norms', type=str_list, help='comma-sep. list '
+            'of target normalization functions from .data.norms')
     parser.add_argument('--crop', type=int,
             help='size to crop the input and target data')
     parser.add_argument('--pad', default=0, type=int,
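Note: both flags reuse the `str_list` argparse type, whose definition is outside this diff. A minimal sketch of the assumed behavior, purely for illustration:

```python
# Assumed behavior of the `str_list` argparse type (not shown in this diff):
# split a comma-separated flag value into a list of strings.
def str_list(s):
    return s.split(',')

# e.g. --in-norms cosmology.dis,cosmology.vel
# would yield ['cosmology.dis', 'cosmology.vel']
```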
@@ -14,7 +14,8 @@ class FieldDataset(Dataset):
     Likewise `tgt_patterns` is for target fields.
     Input and target samples are matched by sorting the globbed files.
 
-    `norms` can be a list of callables to normalize each field.
+    `in_norms` is a list of functions to normalize the input fields.
+    Likewise for `tgt_norms`.
 
     Data augmentations are supported for scalar and vector fields.
 
@@ -25,7 +26,7 @@ class FieldDataset(Dataset):
     `div_data` enables data division, useful when combined with caching.
     """
     def __init__(self, in_patterns, tgt_patterns,
-            norms=None, augment=False, crop=None, pad=0,
+            in_norms=None, tgt_norms=None, augment=False, crop=None, pad=0,
             cache=False, div_data=False, rank=None, world_size=None,
             **kwargs):
         in_file_lists = [sorted(glob(p)) for p in in_patterns]
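With the new signature, input and target normalizations are specified independently. A hypothetical construction (the glob patterns and norm choices here are invented for illustration):

```python
# Hypothetical usage of the new signature; patterns are placeholders.
dataset = FieldDataset(
    in_patterns=['in/*.npy'],
    tgt_patterns=['tgt/*.npy'],
    in_norms=['cosmology.dis'],   # strings are resolved via import_norm
    tgt_norms=['cosmology.dis'],  # may now differ from in_norms
    crop=128, pad=20,
)
```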
@@ -44,18 +45,24 @@ class FieldDataset(Dataset):
             self.in_files = self.in_files[rank * files : (rank + 1) * files]
             self.tgt_files = self.tgt_files[rank * files : (rank + 1) * files]
 
-        self.in_channels = sum(np.load(f).shape[0] for f in self.in_files[0])
-        self.tgt_channels = sum(np.load(f).shape[0] for f in self.tgt_files[0])
+        self.in_chan = [np.load(f).shape[0] for f in self.in_files[0]]
+        self.tgt_chan = [np.load(f).shape[0] for f in self.tgt_files[0]]
 
         self.size = np.load(self.in_files[0][0]).shape[1:]
         self.size = np.asarray(self.size)
         self.ndim = len(self.size)
 
-        if norms is not None:  # FIXME: in_norms, tgt_norms
-            assert len(in_patterns) == len(norms), \
-                'numbers of normalization callables and input fields do not match'
-            norms = [import_norm(norm) for norm in norms if isinstance(norm, str)]
-        self.norms = norms
+        if in_norms is not None:
+            assert len(in_patterns) == len(in_norms), \
+                'numbers of input normalization functions and fields do not match'
+            in_norms = [import_norm(norm) for norm in in_norms]
+        self.in_norms = in_norms
+
+        if tgt_norms is not None:
+            assert len(tgt_patterns) == len(tgt_norms), \
+                'numbers of target normalization functions and fields do not match'
+            tgt_norms = [import_norm(norm) for norm in tgt_norms]
+        self.tgt_norms = tgt_norms
 
         self.augment = augment
         if self.ndim == 1 and self.augment:
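`in_chan` and `tgt_chan` are now per-field lists of channel counts rather than pre-summed totals, which is what lets the model sizing and the per-field denormalization later in this commit slice channels correctly. For example:

```python
import numpy as np

# e.g. a 3-channel displacement field plus a 1-channel density field
in_chan = [3, 1]
sum(in_chan)        # 4: total channels, used to size the model
np.cumsum(in_chan)  # array([3, 4]): per-field boundaries along dim 1
```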
@@ -80,19 +87,9 @@ class FieldDataset(Dataset):
     def __len__(self):
         return len(self.in_files) * self.tot_reps
 
-    @property
-    def channels(self):
-        return self.in_channels, self.tgt_channels
-
     def __getitem__(self, idx):
         idx, sub_idx = idx // self.tot_reps, idx % self.tot_reps
         start = np.unravel_index(sub_idx, self.reps) * self.crop
-        #print('==================================================')
-        #print(f'idx = {idx}, sub_idx = {sub_idx}, start = {start}')
-        #print(f'self.reps = {self.reps}, self.tot_reps = {self.tot_reps}')
-        #print(f'self.crop = {self.crop}, self.size = {self.size}')
-        #print(f'self.ndim = {self.ndim}, self.channels = {self.channels}')
-        #print(f'self.pad = {self.pad}')
 
         if self.cache:
             try:
@@ -125,10 +122,12 @@ class FieldDataset(Dataset):
             in_fields = perm(in_fields, perm_axes, self.ndim)
             tgt_fields = perm(tgt_fields, perm_axes, self.ndim)
 
-        if self.norms is not None:
-            for norm, ifield, tfield in zip(self.norms, in_fields, tgt_fields):
-                norm(ifield)
-                norm(tfield)
+        if self.in_norms is not None:
+            for norm, x in zip(self.in_norms, in_fields):
+                norm(x)
+        if self.tgt_norms is not None:
+            for norm, x in zip(self.tgt_norms, tgt_fields):
+                norm(x)
 
         in_fields = torch.cat(in_fields, dim=0)
         tgt_fields = torch.cat(tgt_fields, dim=0)
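The loop assumes each norm is an in-place callable; from the `undo=True` calls in test.py later in this commit, the contract is roughly `norm(x, undo=False)`. A made-up norm following that assumed contract:

```python
# Invented example of the assumed normalization contract (not one of the
# actual .data.norms): mutate the tensor in place, reversing when undo=True.
def log_norm(x, undo=False):
    if undo:
        x.expm1_()  # inverse of log1p
    else:
        x.log1p_()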
@@ -3,8 +3,11 @@ from importlib import import_module
 from . import cosmology
 
 
-def import_norm(path):
-    mod, fun = path.rsplit('.', 1)
+def import_norm(norm):
+    if callable(norm):
+        return norm
+
+    mod, fun = norm.rsplit('.', 1)
     mod = import_module('.' + mod, __name__)
     fun = getattr(mod, fun)
     return fun
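`import_norm` now accepts either form, so callers need not distinguish strings from callables:

```python
norm = import_norm('cosmology.dis')  # dotted path, resolved in .data.norms
norm = import_norm(my_norm)          # `my_norm` is any callable (hypothetical
                                     # name); it is returned unchanged
```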
@@ -7,20 +7,20 @@ from .swish import Swish
 class ConvBlock(nn.Module):
     """Convolution blocks of the form specified by `seq`.
     """
-    def __init__(self, in_channels, out_channels=None, mid_channels=None,
+    def __init__(self, in_chan, out_chan=None, mid_chan=None,
             kernel_size=3, seq='CBA'):
         super().__init__()
 
-        if out_channels is None:
-            out_channels = in_channels
+        if out_chan is None:
+            out_chan = in_chan
 
-        self.in_channels = in_channels
-        self.out_channels = out_channels
-        if mid_channels is None:
-            self.mid_channels = max(in_channels, out_channels)
+        self.in_chan = in_chan
+        self.out_chan = out_chan
+        if mid_chan is None:
+            self.mid_chan = max(in_chan, out_chan)
         self.kernel_size = kernel_size
 
-        self.bn_channels = in_channels
+        self.norm_chan = in_chan
         self.idx_conv = 0
         self.num_conv = sum([seq.count(l) for l in ['U', 'D', 'C']])
 
@@ -30,18 +30,18 @@ class ConvBlock(nn.Module):
 
     def _get_layer(self, l):
         if l == 'U':
-            in_channels, out_channels = self._setup_conv()
-            return nn.ConvTranspose3d(in_channels, out_channels, 2, stride=2)
+            in_chan, out_chan = self._setup_conv()
+            return nn.ConvTranspose3d(in_chan, out_chan, 2, stride=2)
         elif l == 'D':
-            in_channels, out_channels = self._setup_conv()
-            return nn.Conv3d(in_channels, out_channels, 2, stride=2)
+            in_chan, out_chan = self._setup_conv()
+            return nn.Conv3d(in_chan, out_chan, 2, stride=2)
         elif l == 'C':
-            in_channels, out_channels = self._setup_conv()
-            return nn.Conv3d(in_channels, out_channels, self.kernel_size)
+            in_chan, out_chan = self._setup_conv()
+            return nn.Conv3d(in_chan, out_chan, self.kernel_size)
         elif l == 'B':
-            #return nn.BatchNorm3d(self.bn_channels)
-            #return nn.InstanceNorm3d(self.bn_channels, affine=True, track_running_stats=True)
-            return nn.InstanceNorm3d(self.bn_channels)
+            #return nn.BatchNorm3d(self.norm_chan)
+            #return nn.InstanceNorm3d(self.norm_chan, affine=True, track_running_stats=True)
+            return nn.InstanceNorm3d(self.norm_chan)
         elif l == 'A':
             return nn.LeakyReLU()
         else:
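From the branches above, each `seq` letter maps to one layer: `'U'`/`'D'` are stride-2 transposed/strided convolutions, `'C'` a kernel-size convolution, `'B'` an instance norm over the channel count tracked in `norm_chan`, and `'A'` a LeakyReLU. For example:

```python
block = ConvBlock(3, 64, seq='CBA')
# expands to Conv3d(3, 64, 3) -> InstanceNorm3d(64) -> LeakyReLU()
```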
@@ -50,15 +50,15 @@ class ConvBlock(nn.Module):
     def _setup_conv(self):
         self.idx_conv += 1
 
-        in_channels = out_channels = self.mid_channels
+        in_chan = out_chan = self.mid_chan
         if self.idx_conv == 1:
-            in_channels = self.in_channels
+            in_chan = self.in_chan
         if self.idx_conv == self.num_conv:
-            out_channels = self.out_channels
+            out_chan = self.out_chan
 
-        self.bn_channels = out_channels
+        self.norm_chan = out_chan
 
-        return in_channels, out_channels
+        return in_chan, out_chan
 
     def forward(self, x):
         return self.convs(x)
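`_setup_conv` routes `in_chan` into the first convolution and `out_chan` out of the last, with `mid_chan` everywhere between. Tracing an example:

```python
block = ConvBlock(8, 16, seq='CAC')  # num_conv == 2
# conv 1: 8 -> 16   (idx_conv == 1, and mid_chan defaults to max(8, 16) == 16)
# conv 2: 16 -> 16  (idx_conv == num_conv, so out_chan == 16 applies)
```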
@@ -68,16 +68,16 @@ class ResBlock(ConvBlock):
     """Residual convolution blocks of the form specified by `seq`. Input is added
     to the residual followed by an optional activation (trailing `'A'` in `seq`).
     """
-    def __init__(self, in_channels, out_channels=None, mid_channels=None,
+    def __init__(self, in_chan, out_chan=None, mid_chan=None,
             seq='CBACBA'):
-        super().__init__(in_channels, out_channels=out_channels,
-                mid_channels=mid_channels,
+        super().__init__(in_chan, out_chan=out_chan,
+                mid_chan=mid_chan,
                 seq=seq[:-1] if seq[-1] == 'A' else seq)
 
-        if out_channels is None:
+        if out_chan is None:
             self.skip = None
         else:
-            self.skip = nn.Conv3d(in_channels, out_channels, 1)
+            self.skip = nn.Conv3d(in_chan, out_chan, 1)
 
         if 'U' in seq or 'D' in seq:
             raise NotImplementedError('upsample and downsample layers '
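When `out_chan` differs from `in_chan`, the identity branch is projected through a 1x1x1 convolution so the residual addition is shape-compatible:

```python
res = ResBlock(64, 128, seq='CBACBA')
# res.skip is nn.Conv3d(64, 128, 1); per the docstring above, the trailing
# 'A' activation is applied after the skip and residual branches are added
```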
@@ -5,10 +5,10 @@ from .conv import ConvBlock, narrow_like
 
 
 class UNet(nn.Module):
-    def __init__(self, in_channels, out_channels):
+    def __init__(self, in_chan, out_chan):
         super().__init__()
 
-        self.conv_l0 = ConvBlock(in_channels, 64, seq='CAC')
+        self.conv_l0 = ConvBlock(in_chan, 64, seq='CAC')
         self.down_l0 = ConvBlock(64, seq='BADBA')
         self.conv_l1 = ConvBlock(64, seq='CBAC')
         self.down_l1 = ConvBlock(64, seq='BADBA')
@@ -18,7 +18,7 @@ class UNet(nn.Module):
         self.up_r1 = ConvBlock(64, seq='BAUBA')
         self.conv_r1 = ConvBlock(128, 64, seq='CBAC')
         self.up_r0 = ConvBlock(64, seq='BAUBA')
-        self.conv_r0 = ConvBlock(128, out_channels, seq='CAC')
+        self.conv_r0 = ConvBlock(128, out_chan, seq='CAC')
 
     def forward(self, x):
         y0 = self.conv_l0(x)
@@ -5,10 +5,10 @@ from .conv import ConvBlock, ResBlock, narrow_like
 
 
 class VNet(nn.Module):
-    def __init__(self, in_channels, out_channels):
+    def __init__(self, in_chan, out_chan):
         super().__init__()
 
-        self.conv_l0 = ResBlock(in_channels, 64, seq='CAC')
+        self.conv_l0 = ResBlock(in_chan, 64, seq='CAC')
         self.down_l0 = ConvBlock(64, seq='BADBA')
         self.conv_l1 = ResBlock(64, seq='CBAC')
         self.down_l1 = ConvBlock(64, seq='BADBA')
@@ -18,7 +18,7 @@ class VNet(nn.Module):
         self.up_r1 = ConvBlock(64, seq='BAUBA')
         self.conv_r1 = ResBlock(128, 64, seq='CBAC')
         self.up_r0 = ConvBlock(64, seq='BAUBA')
-        self.conv_r0 = ResBlock(128, out_channels, seq='CAC')
+        self.conv_r0 = ResBlock(128, out_chan, seq='CAC')
 
     def forward(self, x):
         y0 = self.conv_l0(x)
@@ -45,11 +45,11 @@ class VNet(nn.Module):
 
 
 class VNetFat(nn.Module):
-    def __init__(self, in_channels, out_channels):
+    def __init__(self, in_chan, out_chan):
         super().__init__()
 
         self.conv_l0 = nn.Sequential(
-            ResBlock(in_channels, 64, seq='CACBA'),
+            ResBlock(in_chan, 64, seq='CACBA'),
             ResBlock(64, seq='CBACBA'),
         )
         self.down_l0 = ConvBlock(64, seq='DBA')
@@ -72,7 +72,7 @@ class VNetFat(nn.Module):
         self.up_r0 = ConvBlock(128, 64, seq='UBA')
         self.conv_r0 = nn.Sequential(
             ResBlock(128, seq='CBACBA'),
-            ResBlock(128, out_channels, seq='CAC'),
+            ResBlock(128, out_chan, seq='CAC'),
         )
 
     def forward(self, x):
@@ -23,10 +23,10 @@ def test(args):
         num_workers=args.loader_workers,
     )
 
-    in_channels, out_channels = test_dataset.channels
+    in_chan, out_chan = test_dataset.in_chan, test_dataset.tgt_chan
 
     model = getattr(models, args.model)
-    model = model(in_channels, out_channels)
+    model = model(sum(in_chan), sum(out_chan))
     criterion = getattr(torch.nn, args.criterion)
     criterion = criterion()
 
@@ -53,11 +53,17 @@ def test(args):
 
             print('sample {} loss: {}'.format(i, loss.item()))
 
-            if args.norms is not None:
-                norm = test_dataset.norms[0]  # FIXME
-                norm(input, undo=True)
-                norm(output, undo=True)
-                norm(target, undo=True)
+            if args.in_norms is not None:
+                start = 0
+                for norm, stop in zip(test_dataset.in_norms, np.cumsum(in_chan)):
+                    norm(input[:, start:stop], undo=True)
+                    start = stop
+            if args.tgt_norms is not None:
+                start = 0
+                for norm, stop in zip(test_dataset.tgt_norms, np.cumsum(out_chan)):
+                    norm(output[:, start:stop], undo=True)
+                    norm(target[:, start:stop], undo=True)
+                    start = stop
 
             np.savez('{}.npz'.format(i), input=input.numpy(),
                     output=output.numpy(), target=target.numpy())
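The denormalization loop walks the channel axis using the per-field boundaries from `np.cumsum`, so each field gets its own `undo`. A worked trace with example channel counts:

```python
import numpy as np

in_chan = [3, 1]          # example per-field channel counts
list(np.cumsum(in_chan))  # [3, 4]
# loop iterations: (in_norms[0], stop=3) undoes input[:, 0:3],
#                  (in_norms[1], stop=4) undoes input[:, 3:4]
```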
@@ -81,10 +81,10 @@ def gpu_worker(local_rank, args):
         pin_memory=True
     )
 
-    in_channels, out_channels = train_dataset.channels
+    in_chan, out_chan = train_dataset.in_chan, train_dataset.tgt_chan
 
     model = getattr(models, args.model)
-    model = model(in_channels, out_channels)
+    model = model(sum(in_chan), sum(out_chan))
     model.to(args.device)
    model = DistributedDataParallel(model, device_ids=[args.device],
            process_group=dist.new_group())
@@ -108,8 +108,8 @@ def gpu_worker(local_rank, args):
     args.adv = args.adv_model is not None
     if args.adv:
         adv_model = getattr(models, args.adv_model)
-        adv_model = adv_model(in_channels + out_channels
-                if args.cgan else out_channels, 1)
+        adv_model = adv_model(sum(in_chan + out_chan)
+                if args.cgan else sum(out_chan), 1)
         adv_model.to(args.device)
         adv_model = DistributedDataParallel(adv_model, device_ids=[args.device],
                 process_group=dist.new_group())
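Since `in_chan` and `out_chan` are now lists, `in_chan + out_chan` is list concatenation, and summing it gives the conditional critic the input and output channels combined:

```python
in_chan, out_chan = [3], [3]  # e.g. one 3-channel field on each side
sum(in_chan + out_chan)       # 6 channels when --cgan conditions on the input
sum(out_chan)                 # 3 channels otherwise
```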
@@ -35,7 +35,7 @@ tgt_files="$files"
 m2m.py test \
     --test-in-patterns "$data_root_dir/$in_dir/$test_dirs/$in_files" \
     --test-tgt-patterns "$data_root_dir/$tgt_dir/$test_dirs/$tgt_files" \
-    --norms cosmology.dis --crop 256 --pad 20 \
+    --in-norms cosmology.dis --tgt-norms cosmology.dis --crop 256 --pad 20 \
     --model VNet \
     --load-state best_model.pth \
     --batches 1 --loader-workers 0 \
@@ -40,7 +40,7 @@ srun m2m.py train \
     --train-tgt-patterns "$data_root_dir/$tgt_dir/$train_dirs/$tgt_files" \
     --val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files" \
     --val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \
-    --norms cosmology.dis --augment --crop 128 --pad 20 \
+    --in-norms cosmology.dis --tgt-norms cosmology.dis --augment --crop 128 --pad 20 \
     --model VNet --adv-model UNet --cgan \
     --lr 0.0001 --adv-lr 0.0004 --batches 1 --loader-workers 0 \
     --epochs 128 --seed $RANDOM \
@@ -35,7 +35,7 @@ tgt_files="$files"
 m2m.py test \
     --test-in-patterns "$data_root_dir/$in_dir/$test_dirs/$in_files" \
     --test-tgt-patterns "$data_root_dir/$tgt_dir/$test_dirs/$tgt_files" \
-    --norms cosmology.vel --crop 256 --pad 20 \
+    --in-norms cosmology.vel --tgt-norms cosmology.vel --crop 256 --pad 20 \
     --model VNet \
     --load-state best_model.pth \
     --batches 1 --loader-workers 0 \
@@ -40,7 +40,7 @@ srun m2m.py train \
     --train-tgt-patterns "$data_root_dir/$tgt_dir/$train_dirs/$tgt_files" \
     --val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files" \
     --val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \
-    --norms cosmology.vel --augment --crop 128 --pad 20 \
+    --in-norms cosmology.vel --tgt-norms cosmology.vel --augment --crop 128 --pad 20 \
     --model VNet --adv-model UNet --cgan \
     --lr 0.0001 --adv-lr 0.0004 --batches 1 --loader-workers 0 \
     --epochs 128 --seed $RANDOM \