Merge branch 'master' into lag2eul

This commit is contained in:
Yin Li 2020-07-14 18:28:10 -04:00
commit 3437b20ed8
21 changed files with 272 additions and 239 deletions

View File

@ -33,19 +33,22 @@ For all command line options look at `map2map/args.py` or do `m2m.py -h`.
Put each field in one npy file.
Structure your data to start with the channel axis and then the spatial
dimensions.
For example a 2D vector field of size `64^2` should have shape `(2, 64,
64)`.
dimensions, e.g. `(2, 64, 64)` for a 2D vector field of size `64^2` and
`(1, 32, 32, 32)` for a 3D scalar field of size `32^3`.
Specify the data path with
[glob patterns](https://docs.python.org/3/library/glob.html).
During training, pairs of input and target fields are loaded.
Both input and target data can consist of multiple fields, which are
then concatenated along the channel axis.
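For illustration, a minimal sketch of preparing such files with NumPy (the directory and file names here are hypothetical, following the `field1_*.npy` pattern used in the docstrings):

import os
import numpy as np

os.makedirs('train', exist_ok=True)

# a hypothetical 2D vector field: channel axis first, then the spatial axes
disp = np.random.randn(2, 64, 64).astype(np.float32)
np.save('train/field1_000.npy', disp)       # shape (2, 64, 64)

# a hypothetical 3D scalar field of size 32^3
dens = np.random.randn(1, 32, 32, 32).astype(np.float32)
np.save('train/field2_000.npy', dens)       # shape (1, 32, 32, 32)

# such files can then be matched by glob patterns like 'train/field1_*.npy'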
#### Data cropping
If the size of a pair of input and target fields is too large to fit in
a GPU, we can crop part of them to form pairs of samples (see `--crop`).
Each field can be cropped multiple times, along each dimension,
controlled by the spacing between two adjacent crops (see `--step`).
a GPU, we can crop part of them to form pairs of samples.
Each field can be cropped multiple times, along each dimension.
See `--crop`, `--crop-start`, `--crop-stop`, and `--crop-step`.
The total sample size is the number of input and target pairs multiplied
by the number of cropped samples per pair.
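As a rough worked example of the sample count (all numbers here are made up):

# 3D fields of size 512^3, --crop 128, with crop start, stop, and step at their defaults
nanchor_per_dim = len(range(0, 512, 128))   # crop_start=0, crop_stop=512, crop_step=crop
ncrop = nanchor_per_dim ** 3                # 4^3 = 64 cropped samples per field pair
npair = 10                                  # number of input and target pairs (made up)
nsample = npair * ncrop                     # 640 samples in total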

View File

@ -40,39 +40,41 @@ def add_common_args(parser):
parser.add_argument('--tgt-norms', type=str_list, help='comma-sep. list '
'of target normalization functions from .data.norms')
parser.add_argument('--crop', type=int,
help='size to crop the input and target data')
help='size to crop the input and target data. Default is the '
'field size')
parser.add_argument('--crop-start', type=int,
help='starting point of the first crop. Default is the origin')
parser.add_argument('--crop-stop', type=int,
help='stopping point of the last crop. Default is the opposite '
'corner to the origin')
parser.add_argument('--crop-step', type=int,
help='spacing between crops. Default is the crop size')
parser.add_argument('--pad', default=0, type=int,
help='size to pad the input data beyond the crop size, assuming '
'periodic boundary condition')
parser.add_argument('--scale-factor', default=1, type=int,
help='input upsampling factor for super-resolution purpose, in '
'which case crop and pad will be taken at the original resolution')
help='upsampling factor for super-resolution, in which case '
'crop and pad are sizes at the input resolution')
parser.add_argument('--model', required=True, type=str,
parser.add_argument('--model', type=str, required=True,
help='model from .models')
parser.add_argument('--criterion', default='MSELoss', type=str,
help='model criterion from torch.nn')
parser.add_argument('--load-state', default=ckpt_link, type=str,
help='path to load the states of model, optimizer, rng, etc. '
'Default is the checkpoint. '
'Start from scratch if set empty or the checkpoint is missing')
'Start from scratch in case of empty string or missing checkpoint')
parser.add_argument('--load-state-non-strict', action='store_false',
help='allow incompatible keys when loading model states',
dest='load_state_strict')
parser.add_argument('--batches', default=1, type=int,
parser.add_argument('--batches', type=int, required=True,
help='mini-batch size, per GPU in training or in total in testing')
parser.add_argument('--loader-workers', type=int,
help='number of data loading workers, per GPU in training or '
'in total in testing. '
'Default is the batch size or 0 for batch size 1')
parser.add_argument('--loader-workers', default=-8, type=int,
help='number of subprocesses per data loader. '
'0 to disable multiprocessing; '
'negative number to multiply by the batch size')
parser.add_argument('--cache', action='store_true',
help='enable LRU cache of input and target fields to reduce I/O')
parser.add_argument('--cache-maxsize', type=int,
help='maximum pairs of fields in cache, unlimited by default. '
'This only applies to training if not None, '
'in which case the testing cache maxsize is 1')
parser.add_argument('--callback-at', type=lambda s: os.path.abspath(s),
help='directory of custom code defining callbacks for models, '
'norms, criteria, and optimizers. Disabled if not set. '
@ -93,6 +95,10 @@ def add_train_args(parser):
help='comma-sep. list of glob patterns for validation target data')
parser.add_argument('--augment', action='store_true',
help='enable data augmentation of axis flipping and permutation')
parser.add_argument('--aug-shift', type=int,
help='data augmentation by shifting [0, aug_shift) pixels, '
'useful for models that treat neighboring pixels differently, '
'e.g. with strided convolutions')
parser.add_argument('--aug-add', type=float,
help='additive data augmentation, (normal) std, '
'same factor for all fields')
@ -106,8 +112,6 @@ def add_train_args(parser):
help='enable spectral normalization on the adversary model')
parser.add_argument('--adv-criterion', default='BCEWithLogitsLoss', type=str,
help='adversarial criterion from torch.nn')
parser.add_argument('--min-reduction', action='store_true',
help='enable minimum reduction in adversarial criterion')
parser.add_argument('--cgan', action='store_true',
help='enable conditional GAN')
parser.add_argument('--adv-start', default=0, type=int,
@ -124,7 +128,7 @@ def add_train_args(parser):
parser.add_argument('--optimizer', default='Adam', type=str,
help='optimizer from torch.optim')
parser.add_argument('--lr', default=0.001, type=float,
parser.add_argument('--lr', type=float, required=True,
help='initial learning rate')
# parser.add_argument('--momentum', default=0.9, type=float,
# help='momentum')
@ -143,8 +147,6 @@ def add_train_args(parser):
parser.add_argument('--seed', default=42, type=int,
help='seed for initializing training')
parser.add_argument('--div-data', action='store_true',
help='enable data division among GPUs, useful with cache')
parser.add_argument('--dist-backend', default='nccl', type=str,
choices=['gloo', 'nccl'], help='distributed backend')
parser.add_argument('--log-interval', default=100, type=int,
@ -175,21 +177,8 @@ def str_list(s):
def set_common_args(args):
if args.loader_workers is None:
args.loader_workers = 0
if args.batches > 1:
args.loader_workers = args.batches
if not args.cache and args.cache_maxsize is not None:
args.cache_maxsize = None
warnings.warn('Resetting cache maxsize given cache is disabled',
RuntimeWarning)
if (args.cache and args.cache_maxsize is not None
and args.cache_maxsize < 1):
args.cache = False
args.cache_maxsize = None
warnings.warn('Disabling cache given cache maxsize < 1',
RuntimeWarning)
if args.loader_workers < 0:
args.loader_workers *= - args.batches
def set_train_args(args):
@ -206,6 +195,11 @@ def set_train_args(args):
if args.adv_weight_decay is None:
args.adv_weight_decay = args.weight_decay
if args.cgan and not args.adv:
args.cgan = False
warnings.warn('Disabling cgan given adversary is disabled',
RuntimeWarning)
def set_test_args(args):
set_common_args(args)
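A rough sketch of how the new `--loader-workers` default resolves, following the logic above (the numbers are made up):

batches, loader_workers = 4, -8     # --batches 4 with the default --loader-workers -8
if loader_workers < 0:
    loader_workers *= -batches      # -8 -> 32 worker subprocesses per data loader
assert loader_workers == 32
# --loader-workers 0 disables loader multiprocessing entirely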

View File

@ -1,5 +1,4 @@
from glob import glob
from functools import lru_cache
import numpy as np
import torch
import torch.nn.functional as F
@ -14,6 +13,7 @@ class FieldDataset(Dataset):
`in_patterns` is a list of glob patterns for the input field files.
For example, `in_patterns=['/train/field1_*.npy', '/train/field2_*.npy']`.
Each pattern in the list is a new field.
Likewise `tgt_patterns` is for target fields.
Input and target fields are matched by sorting the globbed files.
@ -21,29 +21,29 @@ class FieldDataset(Dataset):
Likewise for `tgt_norms`.
Scalar and vector fields can be augmented by flipping and permuting the axes.
In 3D these form the full octahedral symmetry known as the Oh point group.
In 3D these form the full octahedral symmetry, the Oh group of order 48.
In 2D this is the dihedral group D4 of order 8.
1D is not supported, but can be done easily by preprocessing.
Fields can be augmented by a random shift of a few pixels, useful for models
that treat neighboring pixels differently, e.g. with strided convolutions.
Additive and multiplicative augmentation are also possible, but with all fields
added or multiplied by the same factor.
Input and target fields can be cropped.
Input fields can be padded assuming periodic boundary condition.
Input and target fields can be cropped to return multiple crops of size
`crop` from each field.
The crop anchors are controlled by `crop_start`, `crop_stop`, and `crop_step`.
Input (but not target) fields can be padded beyond the crop size assuming
periodic boundary condition.
Setting an integer `scale_factor` greater than 1 will crop the target larger
than the input for super-resolution, in which case `crop` and `pad` are sizes
at the input resolution.
`cache` enables LRU cache of the input and target fields, up to `cache_maxsize`
pairs (unlimited by default).
`div_data` enables data division, to be used with `cache`, so that different
fields are cached in different GPU processes.
This saves CPU RAM but limits stochasticity.
"""
def __init__(self, in_patterns, tgt_patterns,
in_norms=None, tgt_norms=None, callback_at=None,
augment=False, aug_add=None, aug_mul=None,
crop=None, pad=0, scale_factor=1,
cache=False, cache_maxsize=None, div_data=False,
rank=None, world_size=None):
augment=False, aug_shift=None, aug_add=None, aug_mul=None,
crop=None, crop_start=None, crop_stop=None, crop_step=None,
pad=0, scale_factor=1):
in_file_lists = [sorted(glob(p)) for p in in_patterns]
self.in_files = list(zip(* in_file_lists))
@ -54,12 +54,14 @@ class FieldDataset(Dataset):
'number of input and target fields do not match'
self.nfile = len(self.in_files)
assert self.nfile > 0, 'file not found'
assert self.nfile > 0, 'file not found for {}'.format(in_patterns)
self.in_chan = [np.load(f).shape[0] for f in self.in_files[0]]
self.tgt_chan = [np.load(f).shape[0] for f in self.tgt_files[0]]
self.in_chan = [np.load(f, mmap_mode='r').shape[0]
for f in self.in_files[0]]
self.tgt_chan = [np.load(f, mmap_mode='r').shape[0]
for f in self.tgt_files[0]]
self.size = np.load(self.in_files[0][0]).shape[1:]
self.size = np.load(self.in_files[0][0], mmap_mode='r').shape[1:]
self.size = np.asarray(self.size)
self.ndim = len(self.size)
@ -80,16 +82,35 @@ class FieldDataset(Dataset):
self.augment = augment
if self.ndim == 1 and self.augment:
raise ValueError('cannot augment 1D fields')
self.aug_shift = np.broadcast_to(aug_shift, (self.ndim,))
self.aug_add = aug_add
self.aug_mul = aug_mul
if crop is None:
self.crop = self.size
self.reps = np.ones_like(self.size)
else:
self.crop = np.broadcast_to(crop, self.size.shape)
self.reps = self.size // self.crop
self.ncrop = int(np.prod(self.reps))
self.crop = np.broadcast_to(crop, (self.ndim,))
if crop_start is None:
crop_start = np.zeros_like(self.size)
else:
crop_start = np.broadcast_to(crop_start, (self.ndim,))
if crop_stop is None:
crop_stop = self.size
else:
crop_stop = np.broadcast_to(crop_stop, (self.ndim,))
if crop_step is None:
crop_step = self.crop
else:
crop_step = np.broadcast_to(crop_step, (self.ndim,))
self.anchors = np.stack(np.mgrid[tuple(
slice(crop_start[d], crop_stop[d], crop_step[d])
for d in range(self.ndim)
)], axis=-1).reshape(-1, self.ndim)
self.ncrop = len(self.anchors)
assert isinstance(pad, int), 'only support symmetric padding for now'
self.pad = np.broadcast_to(pad, (self.ndim, 2))
@ -98,52 +119,25 @@ class FieldDataset(Dataset):
'only support integer upsampling'
self.scale_factor = scale_factor
if cache:
self.get_fields = lru_cache(maxsize=cache_maxsize)(self.get_fields)
if div_data:
self.samples = []
# first add full fields when num_fields > num_GPU
for i in range(rank, self.nfile // world_size * world_size,
world_size):
self.samples.append(
range(i * self.ncrop, (i + 1) * self.ncrop)
)
# then split the rest into fractions of fields
# drop the last incomplete batch of samples
frac_start = self.nfile // world_size * world_size * self.ncrop
frac_samples = self.nfile % world_size * self.ncrop // world_size
self.samples.append(range(frac_start + rank * frac_samples,
frac_start + (rank + 1) * frac_samples))
self.samples = np.concatenate(self.samples)
else:
self.samples = np.arange(self.nfile * self.ncrop)
self.nsample = len(self.samples)
self.rank = rank
def get_fields(self, idx):
in_fields = [np.load(f) for f in self.in_files[idx]]
tgt_fields = [np.load(f) for f in self.tgt_files[idx]]
return in_fields, tgt_fields
def __len__(self):
return self.nsample
return self.nfile * self.ncrop
def __getitem__(self, idx):
idx = self.samples[idx]
ifile, icrop = divmod(idx, self.ncrop)
in_fields, tgt_fields = self.get_fields(idx // self.ncrop)
in_fields = [np.load(f, mmap_mode='r') for f in self.in_files[ifile]]
tgt_fields = [np.load(f, mmap_mode='r') for f in self.tgt_files[ifile]]
start = np.unravel_index(idx % self.ncrop, self.reps) * self.crop
anchor = self.anchors[icrop]
in_fields = crop(in_fields, start, self.crop, self.pad)
tgt_fields = crop(tgt_fields, start * self.scale_factor,
for d, shift in enumerate(self.aug_shift):
if shift is not None:
anchor[d] += torch.randint(shift, (1,))
in_fields = crop(in_fields, anchor, self.crop, self.pad, self.size)
tgt_fields = crop(tgt_fields, anchor * self.scale_factor,
self.crop * self.scale_factor,
np.zeros_like(self.pad))
np.zeros_like(self.pad), self.size)
in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]
tgt_fields = [torch.from_numpy(f).to(torch.float32) for f in tgt_fields]
@ -176,12 +170,21 @@ class FieldDataset(Dataset):
return in_fields, tgt_fields
def crop(fields, start, crop, pad):
def crop(fields, anchor, crop, pad, size):
ndim = len(size)
assert all(len(x) == ndim for x in [anchor, crop, pad, size]), 'inconsistent ndim'
new_fields = []
for x in fields:
for d, (i, c, (p0, p1)) in enumerate(zip(start, crop, pad)):
begin, end = i - p0, i + c + p1
x = x.take(range(begin, end), axis=1 + d, mode='wrap')
ind = [slice(None)]
for d, (a, c, (p0, p1), s) in enumerate(zip(anchor, crop, pad, size)):
i = np.arange(a - p0, a + c + p1)
i %= s
i = i.reshape((-1,) + (1,) * (ndim - d - 1))
ind.append(i)
x = x[tuple(ind)]
x.setflags(write=True) # workaround numpy bug before 1.18
new_fields.append(x)
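A rough illustration of the rewritten crop indexing with periodic wrapping (field size and crop parameters are made up):

import numpy as np

x = np.arange(64, dtype=np.float32).reshape(1, 8, 8)    # a 1-channel 2D field of size 8^2
anchor, crop_size, pad, size = [6, 6], [4, 4], [[1, 1], [1, 1]], [8, 8]

ind = [slice(None)]                                      # keep the channel axis intact
for d, (a, c, (p0, p1), s) in enumerate(zip(anchor, crop_size, pad, size)):
    i = np.arange(a - p0, a + c + p1) % s                # wrap around the periodic boundary
    ind.append(i.reshape((-1,) + (1,) * (len(size) - d - 1)))
print(x[tuple(ind)].shape)                               # (1, 6, 6): crop of 4 plus 1-pixel pad on both sides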

View File

@ -9,14 +9,18 @@ from matplotlib.colors import Normalize, LogNorm, SymLogNorm
from matplotlib.cm import ScalarMappable
def fig3d(*fields, size=64, title=None, cmap=None, norm=None):
def plt_slices(*fields, size=64, title=None, cmap=None, norm=None):
"""Plot slices of fields of more than 2 spatial dimensions.
"""
fields = [field.detach().cpu().numpy() if isinstance(field, torch.Tensor)
else field for field in fields]
assert all(isinstance(field, np.ndarray) for field in fields)
assert all(field.ndim == fields[0].ndim for field in fields)
nc = max(field.shape[0] for field in fields)
nf = len(fields)
nd = fields[0].ndim - 1
if title is not None:
assert len(title) == nf
@ -73,8 +77,8 @@ def fig3d(*fields, size=64, title=None, cmap=None, norm=None):
norm_ = norm
for c in range(field.shape[0]):
axes[c, f].pcolormesh(field[c, 0, :size, :size],
cmap=cmap_, norm=norm_)
s = (c,) + (0,) * (nd - 2) + (slice(size),) * 2
axes[c, f].pcolormesh(field[s], cmap=cmap_, norm=norm_)
axes[c, f].set_aspect('equal')
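A hypothetical usage sketch of `plt_slices`, assuming it returns the matplotlib figure (as its use with `logger.add_figure` in `train.py` suggests):

import torch
from map2map.data.figures import plt_slices

x = torch.randn(3, 64, 64, 64)      # a 3-channel 3D field
y = torch.randn(1, 64, 64, 64)      # a 1-channel 3D field
fig = plt_slices(x, y, title=['input', 'target'])
fig.savefig('slices.png')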

View File

@ -3,7 +3,10 @@ from .vnet import VNet, VNetFat
from .pyramid import PyramidNet
from .patchgan import PatchGAN, PatchGAN42
from .conv import narrow_like
from .narrow import narrow_by, narrow_cast, narrow_like
from .resample import resample, Resampler
from .lag2eul import Lag2Eul

View File

@ -19,7 +19,6 @@ def adv_criterion_wrapper(module):
"""Wrap an adversarial criterion to:
* also take lists of Tensors as target, used to split the input Tensor
along the batch dimension
* enable min reduction on input
* expand target shape as that of input
* return a list of losses, one for each pair of input and target Tensors
"""
@ -34,19 +33,10 @@ def adv_criterion_wrapper(module):
input = self.split_input(input, target)
assert len(input) == len(target)
if self.reduction == 'min':
input = [torch.min(i).unsqueeze(0) for i in input]
target = [t.expand_as(i) for i, t in zip(input, target)]
if self.reduction == 'min':
self.reduction = 'mean' # average over batches
loss = [super(new_module, self).forward(i, t)
for i, t in zip(input, target)]
self.reduction = 'min'
else:
loss = [super(new_module, self).forward(i, t)
for i, t in zip(input, target)]
loss = [super(new_module, self).forward(i, t)
for i, t in zip(input, target)]
return loss

View File

@ -1,6 +1,7 @@
import torch
import torch.nn as nn
from .narrow import narrow_like
from .swish import Swish
@ -114,16 +115,3 @@ class ResBlock(ConvBlock):
x = self.act(x)
return x
def narrow_like(a, b):
"""Narrow a to be like b.
Try to be symmetric but cut more on the right for odd difference,
consistent with the downsampling.
"""
for d in range(2, a.dim()):
width = a.shape[d] - b.shape[d]
half_width = width // 2
a = a.narrow(d, half_width, a.shape[d] - width)
return a

View File

@ -1,17 +1,14 @@
from math import log
import torch
class InstanceNoise:
"""Instance noise, with a heuristic annealing schedule
"""Instance noise, with a linear decaying schedule
"""
def __init__(self, init_std, batches):
assert init_std >= 0, 'Noise std cannot be negative'
self.init_std = init_std
self.anneal = 1
self.ln2 = log(2)
self._std = init_std
self.batches = batches
def std(self, adv_loss):
self.anneal -= adv_loss / self.ln2 / self.batches
std = self.anneal * self.init_std
std = std if std > 0 else 0
return std
def std(self):
self._std -= self.init_std / self.batches
return max(self._std, 0)
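A minimal sketch of the decay, assuming `InstanceNoise` is imported from `map2map.models` as in `train.py` (numbers made up):

from map2map.models import InstanceNoise

noise = InstanceNoise(init_std=0.5, batches=4)  # anneals to 0 over 4 batches
print([noise.std() for _ in range(6)])          # [0.375, 0.25, 0.125, 0.0, 0, 0]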

43
map2map/models/narrow.py Normal file
View File

@ -0,0 +1,43 @@
import torch
import torch.nn as nn
def narrow_by(a, c):
"""Narrow a by size c symmetrically on all edges.
"""
for d in range(2, a.dim()):
a = a.narrow(d, c, a.shape[d] - 2 * c)
return a
def narrow_cast(*tensors):
"""Narrow each tensor to the minimum length in each dimension.
Try to be symmetric but cut more on the right for odd difference
"""
dim_max = max(a.dim() for a in tensors)
len_min = {d: min(a.shape[d] for a in tensors) for d in range(2, dim_max)}
casted_tensors = []
for a in tensors:
for d in range(2, dim_max):
width = a.shape[d] - len_min[d]
half_width = width // 2
a = a.narrow(d, half_width, a.shape[d] - width)
casted_tensors.append(a)
return casted_tensors
def narrow_like(a, b):
"""Narrow a to be like b.
Try to be symmetric but cut more on the right for odd difference
"""
for d in range(2, a.dim()):
width = a.shape[d] - b.shape[d]
half_width = width // 2
a = a.narrow(d, half_width, a.shape[d] - width)
return a
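A quick usage sketch of the new narrowing helpers, using the re-exports added to `map2map/models/__init__.py` above:

import torch
from map2map.models import narrow_by, narrow_cast, narrow_like

a = torch.zeros(1, 3, 16, 16)       # (batch, channel, *spatial)
b = torch.zeros(1, 3, 11, 11)

print(narrow_by(a, 2).shape)        # torch.Size([1, 3, 12, 12])
print(narrow_like(a, b).shape)      # torch.Size([1, 3, 11, 11]), cutting 2 on the left and 3 on the right
aa, bb = narrow_cast(a, b)
print(aa.shape, bb.shape)           # both narrowed to the common spatial size 11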

View File

@ -0,0 +1,46 @@
import torch.nn as nn
import torch.nn.functional as F
from .narrow import narrow_by
def resample(x, scale_factor, narrow=True):
modes = {1: 'linear', 2: 'bilinear', 3: 'trilinear'}
ndim = x.dim() - 2
mode = modes[ndim]
x = F.interpolate(x, scale_factor=scale_factor,
mode=mode, align_corners=False)
if scale_factor > 1 and narrow:
edges = round(scale_factor) // 2
edges = max(edges, 1)
x = narrow_by(x, edges)
return x
class Resampler(nn.Module):
"""Resampling, upsampling or downsampling.
By default discard the inaccurate edges when upsampling.
"""
def __init__(self, ndim, scale_factor, narrow=True):
super().__init__()
modes = {1: 'linear', 2: 'bilinear', 3: 'trilinear'}
self.mode = modes[ndim]
self.scale_factor = scale_factor
self.narrow = narrow
def forward(self, x):
x = F.interpolate(x, scale_factor=self.scale_factor,
mode=self.mode, align_corners=False)
if self.scale_factor > 1 and self.narrow:
edges = round(self.scale_factor) // 2
edges = max(edges, 1)
x = narrow_by(x, edges)
return x
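A quick usage sketch of the resampler, again via the package re-exports (shapes are made up):

import torch
from map2map.models import resample, Resampler

x = torch.randn(1, 3, 32, 32, 32)   # (batch, channel, *spatial), 3D
print(resample(x, 2).shape)         # torch.Size([1, 3, 62, 62, 62]): upsampled by 2, 1-voxel edges discarded
print(Resampler(3, 0.5)(x).shape)   # torch.Size([1, 3, 16, 16, 16]): downsampling keeps all edges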

View File

@ -1,7 +1,8 @@
import torch
import torch.nn as nn
from .conv import ConvBlock, narrow_like
from .conv import ConvBlock
from .narrow import narrow_like
class UNet(nn.Module):

View File

@ -1,7 +1,8 @@
import torch
import torch.nn as nn
from .conv import ConvBlock, ResBlock, narrow_like
from .conv import ConvBlock, ResBlock
from .narrow import narrow_like
class VNet(nn.Module):

View File

@ -1,12 +1,12 @@
import sys
from pprint import pprint
import numpy as np
import torch
import sys
from torch.utils.data import DataLoader
from .data import FieldDataset
from . import models
from .models import narrow_like
from .models import narrow_cast
from .utils import import_attr, load_model_state_dict
@ -21,12 +21,15 @@ def test(args):
tgt_norms=args.tgt_norms,
callback_at=args.callback_at,
augment=False,
aug_shift=None,
aug_add=None,
aug_mul=None,
crop=args.crop,
crop_start=args.crop_start,
crop_stop=args.crop_stop,
crop_step=args.crop_step,
pad=args.pad,
scale_factor=args.scale_factor,
cache=args.cache,
)
test_loader = DataLoader(
test_dataset,
@ -54,12 +57,7 @@ def test(args):
with torch.no_grad():
for i, (input, target) in enumerate(test_loader):
output = model(input)
if args.pad > 0: # FIXME
output = narrow_like(output, target)
input = narrow_like(input, target)
else:
target = narrow_like(target, output)
input = narrow_like(input, output)
input, output, target = narrow_cast(input, output, target)
loss = criterion(output, target)

View File

@ -15,9 +15,9 @@ from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from .data import FieldDataset, GroupedRandomSampler
from .data.figures import fig3d
from .data.figures import plt_slices
from . import models
from .models import (narrow_like, Lag2Eul,
from .models import (narrow_like, resample, Lag2Eul,
adv_model_wrapper, adv_criterion_wrapper,
add_spectral_norm, rm_spectral_norm,
InstanceNoise)
@ -65,28 +65,17 @@ def gpu_worker(local_rank, node, args):
tgt_norms=args.tgt_norms,
callback_at=args.callback_at,
augment=args.augment,
aug_shift=args.aug_shift,
aug_add=args.aug_add,
aug_mul=args.aug_mul,
crop=args.crop,
crop_start=args.crop_start,
crop_stop=args.crop_stop,
crop_step=args.crop_step,
pad=args.pad,
scale_factor=args.scale_factor,
cache=args.cache,
cache_maxsize=args.cache_maxsize,
div_data=args.div_data,
rank=rank,
world_size=args.world_size,
)
if args.div_data:
train_sampler = GroupedRandomSampler(
train_dataset,
group_size=None if args.cache_maxsize is None else
args.cache_maxsize * train_dataset.ncrop,
)
else:
try:
train_sampler = DistributedSampler(train_dataset, shuffle=True)
except TypeError:
train_sampler = DistributedSampler(train_dataset) # old pytorch
train_sampler = DistributedSampler(train_dataset, shuffle=True)
train_loader = DataLoader(
train_dataset,
batch_size=args.batches,
@ -104,24 +93,17 @@ def gpu_worker(local_rank, node, args):
tgt_norms=args.tgt_norms,
callback_at=args.callback_at,
augment=False,
aug_shift=None,
aug_add=None,
aug_mul=None,
crop=args.crop,
crop_start=args.crop_start,
crop_stop=args.crop_stop,
crop_step=args.crop_step,
pad=args.pad,
scale_factor=args.scale_factor,
cache=args.cache,
cache_maxsize=None if args.cache_maxsize is None else 1,
div_data=args.div_data,
rank=rank,
world_size=args.world_size,
)
if args.div_data:
val_sampler = None
else:
try:
val_sampler = DistributedSampler(val_dataset, shuffle=False)
except TypeError:
val_sampler = DistributedSampler(val_dataset) # old pytorch
val_sampler = DistributedSampler(val_dataset, shuffle=False)
val_loader = DataLoader(
val_dataset,
batch_size=args.batches,
@ -170,7 +152,7 @@ def gpu_worker(local_rank, node, args):
adv_criterion = import_attr(args.adv_criterion, nn.__name__, args.callback_at)
adv_criterion = adv_criterion_wrapper(adv_criterion)
adv_criterion = adv_criterion(reduction='min' if args.min_reduction else 'mean')
adv_criterion = adv_criterion()
adv_criterion.to(device)
adv_optimizer = import_attr(args.optimizer, optim.__name__, args.callback_at)
@ -246,8 +228,7 @@ def gpu_worker(local_rank, node, args):
args.instance_noise_batches)
for epoch in range(start_epoch, args.epochs):
if not args.div_data:
train_sampler.set_epoch(epoch)
train_sampler.set_epoch(epoch)
train_loss = train(epoch, train_loader,
model, dis2den, criterion, optimizer, scheduler,
@ -267,10 +248,7 @@ def gpu_worker(local_rank, node, args):
adv_scheduler.step(epoch_loss[0])
if rank == 0:
try:
logger.flush()
except AttributeError:
logger.close() # old pytorch
logger.flush()
if ((min_loss is None or epoch_loss[0] < min_loss[0])
and epoch >= args.adv_start):
@ -293,12 +271,6 @@ def gpu_worker(local_rank, node, args):
os.symlink(state_file, tmp_link) # workaround to overwrite
os.rename(tmp_link, ckpt_link)
if args.cache:
print('rank {} train data: {}'.format(
rank, train_dataset.get_fields.cache_info()))
print('rank {} val data: {}'.format(
rank, val_dataset.get_fields.cache_info()))
dist.destroy_process_group()
@ -327,12 +299,15 @@ def train(epoch, loader, model, dis2den, criterion, optimizer, scheduler,
target = target.to(device, non_blocking=True)
output = model(input)
if epoch == 0 and i == 0 and rank == 0:
print('input.shape =', input.shape)
print('output.shape =', output.shape)
print('target.shape =', target.shape, flush=True)
target = narrow_like(target, output) # FIXME pad
if hasattr(model, 'scale_factor') and model.scale_factor != 1:
input = F.interpolate(input,
scale_factor=model.scale_factor, mode='nearest')
input = narrow_like(input, output)
if (hasattr(model.module, 'scale_factor')
and model.module.scale_factor != 1):
input = resample(input, model.module.scale_factor, narrow=False)
input, output, target = narrow_cast(input, output, target)
output, target = dis2den(output, target)
@ -340,13 +315,11 @@ def train(epoch, loader, model, dis2den, criterion, optimizer, scheduler,
epoch_loss[0] += loss.item()
if args.adv and epoch >= args.adv_start:
try:
noise_std = args.instance_noise.std(adv_loss)
except NameError:
noise_std = args.instance_noise.std(0)
noise_std = args.instance_noise.std()
if noise_std > 0:
noise = noise_std * torch.randn_like(output)
output = output + noise.detach()
noise = noise_std * torch.randn_like(target)
target = target + noise.detach()
del noise
@ -405,21 +378,19 @@ def train(epoch, loader, model, dis2den, criterion, optimizer, scheduler,
if '.weight' in n)
grads = [grads[0], grads[-1]]
grads = [g.detach().norm().item() for g in grads]
logger.add_scalars('grad', {
'first': grads[0],
'last': grads[-1],
}, global_step=batch)
logger.add_scalar('grad/first', grads[0], global_step=batch)
logger.add_scalar('grad/last', grads[-1], global_step=batch)
if args.adv and epoch >= args.adv_start:
grads = list(p.grad for n, p in adv_model.named_parameters()
if '.weight' in n)
grads = [grads[0], grads[-1]]
grads = [g.detach().norm().item() for g in grads]
logger.add_scalars('grad/adv', {
'first': grads[0],
'last': grads[-1],
}, global_step=batch)
logger.add_scalar('grad/adv/first', grads[0],
global_step=batch)
logger.add_scalar('grad/adv/last', grads[-1],
global_step=batch)
if args.adv and epoch >= args.adv_start:
if args.adv and epoch >= args.adv_start and noise_std > 0:
logger.add_scalar('instance_noise', noise_std,
global_step=batch)
@ -440,7 +411,7 @@ def train(epoch, loader, model, dis2den, criterion, optimizer, scheduler,
skip_chan = 0
if args.adv and epoch >= args.adv_start and args.cgan:
skip_chan = sum(args.in_chan)
logger.add_figure('fig/epoch/train', fig3d(
logger.add_figure('fig/epoch/train', plt_slices(
input[-1],
output[-1, skip_chan:],
target[-1, skip_chan:],
@ -471,11 +442,10 @@ def validate(epoch, loader, model, dis2den, criterion, adv_model, adv_criterion,
output = model(input)
target = narrow_like(target, output) # FIXME pad
if hasattr(model, 'scale_factor') and model.scale_factor != 1:
input = F.interpolate(input,
scale_factor=model.scale_factor, mode='nearest')
input = narrow_like(input, output)
if (hasattr(model.module, 'scale_factor')
and model.module.scale_factor != 1):
input = resample(input, model.module.scale_factor, narrow=False)
input, output, target = narrow_cast(input, output, target)
output, target = dis2den(output, target)
@ -517,7 +487,7 @@ def validate(epoch, loader, model, dis2den, criterion, adv_model, adv_criterion,
skip_chan = 0
if args.adv and epoch >= args.adv_start and args.cgan:
skip_chan = sum(args.in_chan)
logger.add_figure('fig/epoch/val', fig3d(
logger.add_figure('fig/epoch/val', plt_slices(
input[-1],
output[-1, skip_chan:],
target[-1, skip_chan:],

View File

@ -38,8 +38,7 @@ srun m2m.py train \
--in-norms cosmology.dis --tgt-norms torch.log1p --augment --crop 128 --pad 20 \
--model UNet \
--lr 0.0001 --batches 1 --loader-workers 0 \
--epochs 1024 --seed $RANDOM \
--cache --div-data
--epochs 1024 --seed $RANDOM
date

View File

@ -38,8 +38,7 @@ m2m.py test \
--in-norms cosmology.dis --tgt-norms cosmology.dis --crop 256 --pad 20 \
--model VNet \
--load-state best_model.pt \
--batches 1 --loader-workers 0 \
--cache
--batches 1 --loader-workers 0
date

View File

@ -39,8 +39,7 @@ srun m2m.py train \
--in-norms cosmology.dis --tgt-norms cosmology.dis --augment --crop 128 --pad 20 \
--model VNet --adv-model UNet --cgan \
--lr 0.0001 --adv-lr 0.0004 --batches 1 --loader-workers 0 \
--epochs 1024 --seed $RANDOM \
--cache --div-data
--epochs 1024 --seed $RANDOM
date

View File

@ -39,10 +39,7 @@ srun m2m.py train \
--in-norms cosmology.dis,cosmology.vel --tgt-norms cosmology.dis,cosmology.vel --augment --crop 88 --pad 20 --scale-factor 2 \
--model VNet --adv-model PatchGAN --cgan \
--lr 0.0001 --adv-lr 0.0004 --batches 1 --loader-workers 0 \
--epochs 1024 --seed $RANDOM \
--cache --div-data
# --val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files_1,$data_root_dir/$in_dir/$val_dirs/$in_files_2" \
# --val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files_1,$data_root_dir/$tgt_dir/$val_dirs/$tgt_files_2" \
--epochs 1024 --seed $RANDOM
date

View File

@ -38,8 +38,7 @@ m2m.py test \
--in-norms cosmology.vel --tgt-norms cosmology.vel --crop 256 --pad 20 \
--model VNet \
--load-state best_model.pt \
--batches 1 --loader-workers 0 \
--cache
--batches 1 --loader-workers 0
date

View File

@ -39,8 +39,7 @@ srun m2m.py train \
--in-norms cosmology.vel --tgt-norms cosmology.vel --augment --crop 128 --pad 20 \
--model VNet --adv-model UNet --cgan \
--lr 0.0001 --adv-lr 0.0004 --batches 1 --loader-workers 0 \
--epochs 1024 --seed $RANDOM \
--cache --div-data
--epochs 1024 --seed $RANDOM
date

View File

@ -10,7 +10,7 @@ setup(
packages=find_packages(),
python_requires='>=3.6',
install_requires=[
'torch',
'torch>=1.2',
'numpy',
'scipy',
],