Merge branch 'master' into lag2eul

This commit is contained in:
Yin Li 2020-07-14 18:28:10 -04:00
commit 3437b20ed8
21 changed files with 272 additions and 239 deletions

View File

@ -33,19 +33,22 @@ For all command line options look at `map2map/args.py` or do `m2m.py -h`.
Put each field in one npy file. Put each field in one npy file.
Structure your data to start with the channel axis and then the spatial Structure your data to start with the channel axis and then the spatial
dimensions. dimensions, e.g. `(2, 64, 64)` for a 2D vector field of size `64^2` and
For example a 2D vector field of size `64^2` should have shape `(2, 64, `(1, 32, 32, 32)` for a 3D scalar field of size `32^3`.
64)`.
Specify the data path with Specify the data path with
[glob patterns](https://docs.python.org/3/library/glob.html). [glob patterns](https://docs.python.org/3/library/glob.html).
During training, pairs of input and target fields are loaded. During training, pairs of input and target fields are loaded.
Both input and target data can consist of multiple fields, which are Both input and target data can consist of multiple fields, which are
then concatenated along the channel axis. then concatenated along the channel axis.
#### Data cropping
If the size of a pair of input and target fields is too large to fit in If the size of a pair of input and target fields is too large to fit in
a GPU, we can crop part of them to form pairs of samples (see `--crop`). a GPU, we can crop part of them to form pairs of samples.
Each field can be cropped multiple times, along each dimension, Each field can be cropped multiple times, along each dimension.
controlled by the spacing between two adjacent crops (see `--step`). See `--crop`, `--crop-start`, `--crop-stop`, and `--crop-step`.
The total sample size is the number of input and target pairs multiplied The total sample size is the number of input and target pairs multiplied
by the number of cropped samples per pair. by the number of cropped samples per pair.

View File

@ -40,39 +40,41 @@ def add_common_args(parser):
parser.add_argument('--tgt-norms', type=str_list, help='comma-sep. list ' parser.add_argument('--tgt-norms', type=str_list, help='comma-sep. list '
'of target normalization functions from .data.norms') 'of target normalization functions from .data.norms')
parser.add_argument('--crop', type=int, parser.add_argument('--crop', type=int,
help='size to crop the input and target data') help='size to crop the input and target data. Default is the '
'field size')
parser.add_argument('--crop-start', type=int,
help='starting point of the first crop. Default is the origin')
parser.add_argument('--crop-stop', type=int,
help='stopping point of the last crop. Default is the opposite '
'corner to the origin')
parser.add_argument('--crop-step', type=int,
help='spacing between crops. Default is the crop size')
parser.add_argument('--pad', default=0, type=int, parser.add_argument('--pad', default=0, type=int,
help='size to pad the input data beyond the crop size, assuming ' help='size to pad the input data beyond the crop size, assuming '
'periodic boundary condition') 'periodic boundary condition')
parser.add_argument('--scale-factor', default=1, type=int, parser.add_argument('--scale-factor', default=1, type=int,
help='input upsampling factor for super-resolution purpose, in ' help='upsampling factor for super-resolution, in which case '
'which case crop and pad will be taken at the original resolution') 'crop and pad are sizes of the input resolution')
parser.add_argument('--model', required=True, type=str, parser.add_argument('--model', type=str, required=True,
help='model from .models') help='model from .models')
parser.add_argument('--criterion', default='MSELoss', type=str, parser.add_argument('--criterion', default='MSELoss', type=str,
help='model criterion from torch.nn') help='model criterion from torch.nn')
parser.add_argument('--load-state', default=ckpt_link, type=str, parser.add_argument('--load-state', default=ckpt_link, type=str,
help='path to load the states of model, optimizer, rng, etc. ' help='path to load the states of model, optimizer, rng, etc. '
'Default is the checkpoint. ' 'Default is the checkpoint. '
'Start from scratch if set empty or the checkpoint is missing') 'Start from scratch in case of empty string or missing checkpoint')
parser.add_argument('--load-state-non-strict', action='store_false', parser.add_argument('--load-state-non-strict', action='store_false',
help='allow incompatible keys when loading model states', help='allow incompatible keys when loading model states',
dest='load_state_strict') dest='load_state_strict')
parser.add_argument('--batches', default=1, type=int, parser.add_argument('--batches', type=int, required=True,
help='mini-batch size, per GPU in training or in total in testing') help='mini-batch size, per GPU in training or in total in testing')
parser.add_argument('--loader-workers', type=int, parser.add_argument('--loader-workers', default=-8, type=int,
help='number of data loading workers, per GPU in training or ' help='number of subprocesses per data loader. '
'in total in testing. ' '0 to disable multiprocessing; '
'Default is the batch size or 0 for batch size 1') 'negative number to multiply by the batch size')
parser.add_argument('--cache', action='store_true',
help='enable LRU cache of input and target fields to reduce I/O')
parser.add_argument('--cache-maxsize', type=int,
help='maximum pairs of fields in cache, unlimited by default. '
'This only applies to training if not None, '
'in which case the testing cache maxsize is 1')
parser.add_argument('--callback-at', type=lambda s: os.path.abspath(s), parser.add_argument('--callback-at', type=lambda s: os.path.abspath(s),
help='directory of custorm code defining callbacks for models, ' help='directory of custorm code defining callbacks for models, '
'norms, criteria, and optimizers. Disabled if not set. ' 'norms, criteria, and optimizers. Disabled if not set. '
@ -93,6 +95,10 @@ def add_train_args(parser):
help='comma-sep. list of glob patterns for validation target data') help='comma-sep. list of glob patterns for validation target data')
parser.add_argument('--augment', action='store_true', parser.add_argument('--augment', action='store_true',
help='enable data augmentation of axis flipping and permutation') help='enable data augmentation of axis flipping and permutation')
parser.add_argument('--aug-shift', type=int,
help='data augmentation by shifting [0, aug_shift) pixels, '
'useful for models that treat neighboring pixels differently, '
'e.g. with strided convolutions')
parser.add_argument('--aug-add', type=float, parser.add_argument('--aug-add', type=float,
help='additive data augmentation, (normal) std, ' help='additive data augmentation, (normal) std, '
'same factor for all fields') 'same factor for all fields')
@ -106,8 +112,6 @@ def add_train_args(parser):
help='enable spectral normalization on the adversary model') help='enable spectral normalization on the adversary model')
parser.add_argument('--adv-criterion', default='BCEWithLogitsLoss', type=str, parser.add_argument('--adv-criterion', default='BCEWithLogitsLoss', type=str,
help='adversarial criterion from torch.nn') help='adversarial criterion from torch.nn')
parser.add_argument('--min-reduction', action='store_true',
help='enable minimum reduction in adversarial criterion')
parser.add_argument('--cgan', action='store_true', parser.add_argument('--cgan', action='store_true',
help='enable conditional GAN') help='enable conditional GAN')
parser.add_argument('--adv-start', default=0, type=int, parser.add_argument('--adv-start', default=0, type=int,
@ -124,7 +128,7 @@ def add_train_args(parser):
parser.add_argument('--optimizer', default='Adam', type=str, parser.add_argument('--optimizer', default='Adam', type=str,
help='optimizer from torch.optim') help='optimizer from torch.optim')
parser.add_argument('--lr', default=0.001, type=float, parser.add_argument('--lr', type=float, required=True,
help='initial learning rate') help='initial learning rate')
# parser.add_argument('--momentum', default=0.9, type=float, # parser.add_argument('--momentum', default=0.9, type=float,
# help='momentum') # help='momentum')
@ -143,8 +147,6 @@ def add_train_args(parser):
parser.add_argument('--seed', default=42, type=int, parser.add_argument('--seed', default=42, type=int,
help='seed for initializing training') help='seed for initializing training')
parser.add_argument('--div-data', action='store_true',
help='enable data division among GPUs, useful with cache')
parser.add_argument('--dist-backend', default='nccl', type=str, parser.add_argument('--dist-backend', default='nccl', type=str,
choices=['gloo', 'nccl'], help='distributed backend') choices=['gloo', 'nccl'], help='distributed backend')
parser.add_argument('--log-interval', default=100, type=int, parser.add_argument('--log-interval', default=100, type=int,
@ -175,21 +177,8 @@ def str_list(s):
def set_common_args(args): def set_common_args(args):
if args.loader_workers is None: if args.loader_workers < 0:
args.loader_workers = 0 args.loader_workers *= - args.batches
if args.batches > 1:
args.loader_workers = args.batches
if not args.cache and args.cache_maxsize is not None:
args.cache_maxsize = None
warnings.warn('Resetting cache maxsize given cache is disabled',
RuntimeWarning)
if (args.cache and args.cache_maxsize is not None
and args.cache_maxsize < 1):
args.cache = False
args.cache_maxsize = None
warnings.warn('Disabling cache given cache maxsize < 1',
RuntimeWarning)
def set_train_args(args): def set_train_args(args):
@ -206,6 +195,11 @@ def set_train_args(args):
if args.adv_weight_decay is None: if args.adv_weight_decay is None:
args.adv_weight_decay = args.weight_decay args.adv_weight_decay = args.weight_decay
if args.cgan and not args.adv:
args.cgan =False
warnings.warn('Disabling cgan given adversary is disabled',
RuntimeWarning)
def set_test_args(args): def set_test_args(args):
set_common_args(args) set_common_args(args)

View File

@ -1,5 +1,4 @@
from glob import glob from glob import glob
from functools import lru_cache
import numpy as np import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
@ -14,6 +13,7 @@ class FieldDataset(Dataset):
`in_patterns` is a list of glob patterns for the input field files. `in_patterns` is a list of glob patterns for the input field files.
For example, `in_patterns=['/train/field1_*.npy', '/train/field2_*.npy']`. For example, `in_patterns=['/train/field1_*.npy', '/train/field2_*.npy']`.
Each pattern in the list is a new field.
Likewise `tgt_patterns` is for target fields. Likewise `tgt_patterns` is for target fields.
Input and target fields are matched by sorting the globbed files. Input and target fields are matched by sorting the globbed files.
@ -21,29 +21,29 @@ class FieldDataset(Dataset):
Likewise for `tgt_norms`. Likewise for `tgt_norms`.
Scalar and vector fields can be augmented by flipping and permutating the axes. Scalar and vector fields can be augmented by flipping and permutating the axes.
In 3D these form the full octahedral symmetry known as the Oh point group. In 3D these form the full octahedral symmetry, the Oh group of order 48.
In 2D this is the dihedral group D4 of order 8.
1D is not supported, but can be done easily by preprocessing.
Fields can be augmented by random shift by a few pixels, useful for models
that treat neighboring pixels differently, e.g. with strided convolutions.
Additive and multiplicative augmentation are also possible, but with all fields Additive and multiplicative augmentation are also possible, but with all fields
added or multiplied by the same factor. added or multiplied by the same factor.
Input and target fields can be cropped. Input and target fields can be cropped, to return multiple slices of size
Input fields can be padded assuming periodic boundary condition. `crop` from each field.
The crop anchors are controlled by `crop_start`, `crop_stop`, and `crop_step`.
Input (but not target) fields can be padded beyond the crop size assuming
periodic boundary condition.
Setting integer `scale_factor` greater than 1 will crop target bigger than Setting integer `scale_factor` greater than 1 will crop target bigger than
the input for super-resolution, in which case `crop` and `pad` are sizes of the input for super-resolution, in which case `crop` and `pad` are sizes of
the input resolution. the input resolution.
`cache` enables LRU cache of the input and target fields, up to `cache_maxsize`
pairs (unlimited by default).
`div_data` enables data division, to be used with `cache`, so that different
fields are cached in different GPU processes.
This saves CPU RAM but limits stochasticity.
""" """
def __init__(self, in_patterns, tgt_patterns, def __init__(self, in_patterns, tgt_patterns,
in_norms=None, tgt_norms=None, callback_at=None, in_norms=None, tgt_norms=None, callback_at=None,
augment=False, aug_add=None, aug_mul=None, augment=False, aug_shift=None, aug_add=None, aug_mul=None,
crop=None, pad=0, scale_factor=1, crop=None, crop_start=None, crop_stop=None, crop_step=None,
cache=False, cache_maxsize=None, div_data=False, pad=0, scale_factor=1):
rank=None, world_size=None):
in_file_lists = [sorted(glob(p)) for p in in_patterns] in_file_lists = [sorted(glob(p)) for p in in_patterns]
self.in_files = list(zip(* in_file_lists)) self.in_files = list(zip(* in_file_lists))
@ -54,12 +54,14 @@ class FieldDataset(Dataset):
'number of input and target fields do not match' 'number of input and target fields do not match'
self.nfile = len(self.in_files) self.nfile = len(self.in_files)
assert self.nfile > 0, 'file not found' assert self.nfile > 0, 'file not found for {}'.format(in_patterns)
self.in_chan = [np.load(f).shape[0] for f in self.in_files[0]] self.in_chan = [np.load(f, mmap_mode='r').shape[0]
self.tgt_chan = [np.load(f).shape[0] for f in self.tgt_files[0]] for f in self.in_files[0]]
self.tgt_chan = [np.load(f, mmap_mode='r').shape[0]
for f in self.tgt_files[0]]
self.size = np.load(self.in_files[0][0]).shape[1:] self.size = np.load(self.in_files[0][0], mmap_mode='r').shape[1:]
self.size = np.asarray(self.size) self.size = np.asarray(self.size)
self.ndim = len(self.size) self.ndim = len(self.size)
@ -80,16 +82,35 @@ class FieldDataset(Dataset):
self.augment = augment self.augment = augment
if self.ndim == 1 and self.augment: if self.ndim == 1 and self.augment:
raise ValueError('cannot augment 1D fields') raise ValueError('cannot augment 1D fields')
self.aug_shift = np.broadcast_to(aug_shift, (self.ndim,))
self.aug_add = aug_add self.aug_add = aug_add
self.aug_mul = aug_mul self.aug_mul = aug_mul
if crop is None: if crop is None:
self.crop = self.size self.crop = self.size
self.reps = np.ones_like(self.size)
else: else:
self.crop = np.broadcast_to(crop, self.size.shape) self.crop = np.broadcast_to(crop, (self.ndim,))
self.reps = self.size // self.crop
self.ncrop = int(np.prod(self.reps)) if crop_start is None:
crop_start = np.zeros_like(self.size)
else:
crop_start = np.broadcast_to(crop_start, (self.ndim,))
if crop_stop is None:
crop_stop = self.size
else:
crop_stop = np.broadcast_to(crop_stop, (self.ndim,))
if crop_step is None:
crop_step = self.crop
else:
crop_step = np.broadcast_to(crop_step, (self.ndim,))
self.anchors = np.stack(np.mgrid[tuple(
slice(crop_start[d], crop_stop[d], crop_step[d])
for d in range(self.ndim)
)], axis=-1).reshape(-1, self.ndim)
self.ncrop = len(self.anchors)
assert isinstance(pad, int), 'only support symmetric padding for now' assert isinstance(pad, int), 'only support symmetric padding for now'
self.pad = np.broadcast_to(pad, (self.ndim, 2)) self.pad = np.broadcast_to(pad, (self.ndim, 2))
@ -98,52 +119,25 @@ class FieldDataset(Dataset):
'only support integer upsampling' 'only support integer upsampling'
self.scale_factor = scale_factor self.scale_factor = scale_factor
if cache:
self.get_fields = lru_cache(maxsize=cache_maxsize)(self.get_fields)
if div_data:
self.samples = []
# first add full fields when num_fields > num_GPU
for i in range(rank, self.nfile // world_size * world_size,
world_size):
self.samples.append(
range(i * self.ncrop, (i + 1) * self.ncrop)
)
# then split the rest into fractions of fields
# drop the last incomplete batch of samples
frac_start = self.nfile // world_size * world_size * self.ncrop
frac_samples = self.nfile % world_size * self.ncrop // world_size
self.samples.append(range(frac_start + rank * frac_samples,
frac_start + (rank + 1) * frac_samples))
self.samples = np.concatenate(self.samples)
else:
self.samples = np.arange(self.nfile * self.ncrop)
self.nsample = len(self.samples)
self.rank = rank
def get_fields(self, idx):
in_fields = [np.load(f) for f in self.in_files[idx]]
tgt_fields = [np.load(f) for f in self.tgt_files[idx]]
return in_fields, tgt_fields
def __len__(self): def __len__(self):
return self.nsample return self.nfile * self.ncrop
def __getitem__(self, idx): def __getitem__(self, idx):
idx = self.samples[idx] ifile, icrop = divmod(idx, self.ncrop)
in_fields, tgt_fields = self.get_fields(idx // self.ncrop) in_fields = [np.load(f, mmap_mode='r') for f in self.in_files[ifile]]
tgt_fields = [np.load(f, mmap_mode='r') for f in self.tgt_files[ifile]]
start = np.unravel_index(idx % self.ncrop, self.reps) * self.crop anchor = self.anchors[icrop]
in_fields = crop(in_fields, start, self.crop, self.pad) for d, shift in enumerate(self.aug_shift):
tgt_fields = crop(tgt_fields, start * self.scale_factor, if shift is not None:
anchor[d] += torch.randint(shift, (1,))
in_fields = crop(in_fields, anchor, self.crop, self.pad, self.size)
tgt_fields = crop(tgt_fields, anchor * self.scale_factor,
self.crop * self.scale_factor, self.crop * self.scale_factor,
np.zeros_like(self.pad)) np.zeros_like(self.pad), self.size)
in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields] in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]
tgt_fields = [torch.from_numpy(f).to(torch.float32) for f in tgt_fields] tgt_fields = [torch.from_numpy(f).to(torch.float32) for f in tgt_fields]
@ -176,12 +170,21 @@ class FieldDataset(Dataset):
return in_fields, tgt_fields return in_fields, tgt_fields
def crop(fields, start, crop, pad): def crop(fields, anchor, crop, pad, size):
ndim = len(size)
assert all(len(x) == ndim for x in [anchor, crop, pad, size]), 'inconsistent ndim'
new_fields = [] new_fields = []
for x in fields: for x in fields:
for d, (i, c, (p0, p1)) in enumerate(zip(start, crop, pad)): ind = [slice(None)]
begin, end = i - p0, i + c + p1 for d, (a, c, (p0, p1), s) in enumerate(zip(anchor, crop, pad, size)):
x = x.take(range(begin, end), axis=1 + d, mode='wrap') i = np.arange(a - p0, a + c + p1)
i %= s
i = i.reshape((-1,) + (1,) * (ndim - d - 1))
ind.append(i)
x = x[tuple(ind)]
x.setflags(write=True) # workaround numpy bug before 1.18
new_fields.append(x) new_fields.append(x)

View File

@ -9,14 +9,18 @@ from matplotlib.colors import Normalize, LogNorm, SymLogNorm
from matplotlib.cm import ScalarMappable from matplotlib.cm import ScalarMappable
def fig3d(*fields, size=64, title=None, cmap=None, norm=None): def plt_slices(*fields, size=64, title=None, cmap=None, norm=None):
"""Plot slices of fields of more than 2 spatial dimensions.
"""
fields = [field.detach().cpu().numpy() if isinstance(field, torch.Tensor) fields = [field.detach().cpu().numpy() if isinstance(field, torch.Tensor)
else field for field in fields] else field for field in fields]
assert all(isinstance(field, np.ndarray) for field in fields) assert all(isinstance(field, np.ndarray) for field in fields)
assert all(field.ndim == fields[0].ndim for field in fields)
nc = max(field.shape[0] for field in fields) nc = max(field.shape[0] for field in fields)
nf = len(fields) nf = len(fields)
nd = fields[0].ndim - 1
if title is not None: if title is not None:
assert len(title) == nf assert len(title) == nf
@ -73,8 +77,8 @@ def fig3d(*fields, size=64, title=None, cmap=None, norm=None):
norm_ = norm norm_ = norm
for c in range(field.shape[0]): for c in range(field.shape[0]):
axes[c, f].pcolormesh(field[c, 0, :size, :size], s = (c,) + (0,) * (nd - 2) + (slice(64),) * 2
cmap=cmap_, norm=norm_) axes[c, f].pcolormesh(field[s], cmap=cmap_, norm=norm_)
axes[c, f].set_aspect('equal') axes[c, f].set_aspect('equal')

View File

@ -3,7 +3,10 @@ from .vnet import VNet, VNetFat
from .pyramid import PyramidNet from .pyramid import PyramidNet
from .patchgan import PatchGAN, PatchGAN42 from .patchgan import PatchGAN, PatchGAN42
from .conv import narrow_like from .narrow import narrow_by, narrow_cast, narrow_like
from .resample import resample, Resampler
from .lag2eul import Lag2Eul
from .lag2eul import Lag2Eul from .lag2eul import Lag2Eul

View File

@ -19,7 +19,6 @@ def adv_criterion_wrapper(module):
"""Wrap an adversarial criterion to: """Wrap an adversarial criterion to:
* also take lists of Tensors as target, used to split the input Tensor * also take lists of Tensors as target, used to split the input Tensor
along the batch dimension along the batch dimension
* enable min reduction on input
* expand target shape as that of input * expand target shape as that of input
* return a list of losses, one for each pair of input and target Tensors * return a list of losses, one for each pair of input and target Tensors
""" """
@ -34,17 +33,8 @@ def adv_criterion_wrapper(module):
input = self.split_input(input, target) input = self.split_input(input, target)
assert len(input) == len(target) assert len(input) == len(target)
if self.reduction == 'min':
input = [torch.min(i).unsqueeze(0) for i in input]
target = [t.expand_as(i) for i, t in zip(input, target)] target = [t.expand_as(i) for i, t in zip(input, target)]
if self.reduction == 'min':
self.reduction = 'mean' # average over batches
loss = [super(new_module, self).forward(i, t)
for i, t in zip(input, target)]
self.reduction = 'min'
else:
loss = [super(new_module, self).forward(i, t) loss = [super(new_module, self).forward(i, t)
for i, t in zip(input, target)] for i, t in zip(input, target)]

View File

@ -1,6 +1,7 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
from .narrow import narrow_like
from .swish import Swish from .swish import Swish
@ -114,16 +115,3 @@ class ResBlock(ConvBlock):
x = self.act(x) x = self.act(x)
return x return x
def narrow_like(a, b):
"""Narrow a to be like b.
Try to be symmetric but cut more on the right for odd difference,
consistent with the downsampling.
"""
for d in range(2, a.dim()):
width = a.shape[d] - b.shape[d]
half_width = width // 2
a = a.narrow(d, half_width, a.shape[d] - width)
return a

View File

@ -1,17 +1,14 @@
from math import log
import torch import torch
class InstanceNoise: class InstanceNoise:
"""Instance noise, with a heuristic annealing schedule """Instance noise, with a linear decaying schedule
""" """
def __init__(self, init_std, batches): def __init__(self, init_std, batches):
assert init_std >= 0, 'Noise std cannot be negative'
self.init_std = init_std self.init_std = init_std
self.anneal = 1 self._std = init_std
self.ln2 = log(2)
self.batches = batches self.batches = batches
def std(self, adv_loss): def std(self):
self.anneal -= adv_loss / self.ln2 / self.batches self._std -= self.init_std / self.batches
std = self.anneal * self.init_std return max(self._std, 0)
std = std if std > 0 else 0
return std

43
map2map/models/narrow.py Normal file
View File

@ -0,0 +1,43 @@
import torch
import torch.nn as nn
def narrow_by(a, c):
"""Narrow a by size c symmetrically on all edges.
"""
for d in range(2, a.dim()):
a = a.narrow(d, c, a.shape[d] - 2 * c)
return a
def narrow_cast(*tensors):
"""Narrow each tensor to the minimum length in each dimension.
Try to be symmetric but cut more on the right for odd difference
"""
dim_max = max(a.dim() for a in tensors)
len_min = {d: min(a.shape[d] for a in tensors) for d in range(2, dim_max)}
casted_tensors = []
for a in tensors:
for d in range(2, dim_max):
width = a.shape[d] - len_min[d]
half_width = width // 2
a = a.narrow(d, half_width, a.shape[d] - width)
casted_tensors.append(a)
return casted_tensors
def narrow_like(a, b):
"""Narrow a to be like b.
Try to be symmetric but cut more on the right for odd difference
"""
for d in range(2, a.dim()):
width = a.shape[d] - b.shape[d]
half_width = width // 2
a = a.narrow(d, half_width, a.shape[d] - width)
return a

View File

@ -0,0 +1,46 @@
import torch.nn as nn
import torch.nn.functional as F
from .narrow import narrow_by
def resample(x, scale_factor, narrow=True):
modes = {1: 'linear', 2: 'bilinear', 3: 'trilinear'}
ndim = x.dim() - 2
mode = modes[ndim]
x = F.interpolate(x, scale_factor=scale_factor,
mode=mode, align_corners=False)
if scale_factor > 1 and narrow == True:
edges = round(scale_factor) // 2
edges = max(edges, 1)
x = narrow_by(x, edges)
return x
class Resampler(nn.Module):
"""Resampling, upsampling or downsampling.
By default discard the inaccurate edges when upsampling.
"""
def __init__(self, ndim, scale_factor, narrow=True):
super().__init__()
modes = {1: 'linear', 2: 'bilinear', 3: 'trilinear'}
self.mode = modes[ndim]
self.scale_factor = scale_factor
self.narrow = narrow
def forward(self, x):
x = F.interpolate(x, scale_factor=self.scale_factor,
mode=self.mode, align_corners=False)
if self.scale_factor > 1 and self.narrow == True:
edges = round(self.scale_factor) // 2
edges = max(edges, 1)
x = narrow_by(x, edges)
return x

View File

@ -1,7 +1,8 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
from .conv import ConvBlock, narrow_like from .conv import ConvBlock
from .narrow import narrow_like
class UNet(nn.Module): class UNet(nn.Module):

View File

@ -1,7 +1,8 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
from .conv import ConvBlock, ResBlock, narrow_like from .conv import ConvBlock, ResBlock
from .narrow import narrow_like
class VNet(nn.Module): class VNet(nn.Module):

View File

@ -1,12 +1,12 @@
import sys
from pprint import pprint from pprint import pprint
import numpy as np import numpy as np
import torch import torch
import sys
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from .data import FieldDataset from .data import FieldDataset
from . import models from . import models
from .models import narrow_like from .models import narrow_cast
from .utils import import_attr, load_model_state_dict from .utils import import_attr, load_model_state_dict
@ -21,12 +21,15 @@ def test(args):
tgt_norms=args.tgt_norms, tgt_norms=args.tgt_norms,
callback_at=args.callback_at, callback_at=args.callback_at,
augment=False, augment=False,
aug_shift=None,
aug_add=None, aug_add=None,
aug_mul=None, aug_mul=None,
crop=args.crop, crop=args.crop,
crop_start=args.crop_start,
crop_stop=args.crop_stop,
crop_step=args.crop_step,
pad=args.pad, pad=args.pad,
scale_factor=args.scale_factor, scale_factor=args.scale_factor,
cache=args.cache,
) )
test_loader = DataLoader( test_loader = DataLoader(
test_dataset, test_dataset,
@ -54,12 +57,7 @@ def test(args):
with torch.no_grad(): with torch.no_grad():
for i, (input, target) in enumerate(test_loader): for i, (input, target) in enumerate(test_loader):
output = model(input) output = model(input)
if args.pad > 0: # FIXME input, output, target = narrow_cast(input, output, target)
output = narrow_like(output, target)
input = narrow_like(input, target)
else:
target = narrow_like(target, output)
input = narrow_like(input, output)
loss = criterion(output, target) loss = criterion(output, target)

View File

@ -15,9 +15,9 @@ from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from .data import FieldDataset, GroupedRandomSampler from .data import FieldDataset, GroupedRandomSampler
from .data.figures import fig3d from .data.figures import plt_slices
from . import models from . import models
from .models import (narrow_like, Lag2Eul, from .models import (narrow_like, resample, Lag2Eul,
adv_model_wrapper, adv_criterion_wrapper, adv_model_wrapper, adv_criterion_wrapper,
add_spectral_norm, rm_spectral_norm, add_spectral_norm, rm_spectral_norm,
InstanceNoise) InstanceNoise)
@ -65,28 +65,17 @@ def gpu_worker(local_rank, node, args):
tgt_norms=args.tgt_norms, tgt_norms=args.tgt_norms,
callback_at=args.callback_at, callback_at=args.callback_at,
augment=args.augment, augment=args.augment,
aug_shift=args.aug_shift,
aug_add=args.aug_add, aug_add=args.aug_add,
aug_mul=args.aug_mul, aug_mul=args.aug_mul,
crop=args.crop, crop=args.crop,
crop_start=args.crop_start,
crop_stop=args.crop_stop,
crop_step=args.crop_step,
pad=args.pad, pad=args.pad,
scale_factor=args.scale_factor, scale_factor=args.scale_factor,
cache=args.cache,
cache_maxsize=args.cache_maxsize,
div_data=args.div_data,
rank=rank,
world_size=args.world_size,
) )
if args.div_data:
train_sampler = GroupedRandomSampler(
train_dataset,
group_size=None if args.cache_maxsize is None else
args.cache_maxsize * train_dataset.ncrop,
)
else:
try:
train_sampler = DistributedSampler(train_dataset, shuffle=True) train_sampler = DistributedSampler(train_dataset, shuffle=True)
except TypeError:
train_sampler = DistributedSampler(train_dataset) # old pytorch
train_loader = DataLoader( train_loader = DataLoader(
train_dataset, train_dataset,
batch_size=args.batches, batch_size=args.batches,
@ -104,24 +93,17 @@ def gpu_worker(local_rank, node, args):
tgt_norms=args.tgt_norms, tgt_norms=args.tgt_norms,
callback_at=args.callback_at, callback_at=args.callback_at,
augment=False, augment=False,
aug_shift=None,
aug_add=None, aug_add=None,
aug_mul=None, aug_mul=None,
crop=args.crop, crop=args.crop,
crop_start=args.crop_start,
crop_stop=args.crop_stop,
crop_step=args.crop_step,
pad=args.pad, pad=args.pad,
scale_factor=args.scale_factor, scale_factor=args.scale_factor,
cache=args.cache,
cache_maxsize=None if args.cache_maxsize is None else 1,
div_data=args.div_data,
rank=rank,
world_size=args.world_size,
) )
if args.div_data:
val_sampler = None
else:
try:
val_sampler = DistributedSampler(val_dataset, shuffle=False) val_sampler = DistributedSampler(val_dataset, shuffle=False)
except TypeError:
val_sampler = DistributedSampler(val_dataset) # old pytorch
val_loader = DataLoader( val_loader = DataLoader(
val_dataset, val_dataset,
batch_size=args.batches, batch_size=args.batches,
@ -170,7 +152,7 @@ def gpu_worker(local_rank, node, args):
adv_criterion = import_attr(args.adv_criterion, nn.__name__, args.callback_at) adv_criterion = import_attr(args.adv_criterion, nn.__name__, args.callback_at)
adv_criterion = adv_criterion_wrapper(adv_criterion) adv_criterion = adv_criterion_wrapper(adv_criterion)
adv_criterion = adv_criterion(reduction='min' if args.min_reduction else 'mean') adv_criterion = adv_criterion()
adv_criterion.to(device) adv_criterion.to(device)
adv_optimizer = import_attr(args.optimizer, optim.__name__, args.callback_at) adv_optimizer = import_attr(args.optimizer, optim.__name__, args.callback_at)
@ -246,7 +228,6 @@ def gpu_worker(local_rank, node, args):
args.instance_noise_batches) args.instance_noise_batches)
for epoch in range(start_epoch, args.epochs): for epoch in range(start_epoch, args.epochs):
if not args.div_data:
train_sampler.set_epoch(epoch) train_sampler.set_epoch(epoch)
train_loss = train(epoch, train_loader, train_loss = train(epoch, train_loader,
@ -267,10 +248,7 @@ def gpu_worker(local_rank, node, args):
adv_scheduler.step(epoch_loss[0]) adv_scheduler.step(epoch_loss[0])
if rank == 0: if rank == 0:
try:
logger.flush() logger.flush()
except AttributeError:
logger.close() # old pytorch
if ((min_loss is None or epoch_loss[0] < min_loss[0]) if ((min_loss is None or epoch_loss[0] < min_loss[0])
and epoch >= args.adv_start): and epoch >= args.adv_start):
@ -293,12 +271,6 @@ def gpu_worker(local_rank, node, args):
os.symlink(state_file, tmp_link) # workaround to overwrite os.symlink(state_file, tmp_link) # workaround to overwrite
os.rename(tmp_link, ckpt_link) os.rename(tmp_link, ckpt_link)
if args.cache:
print('rank {} train data: {}'.format(
rank, train_dataset.get_fields.cache_info()))
print('rank {} val data: {}'.format(
rank, val_dataset.get_fields.cache_info()))
dist.destroy_process_group() dist.destroy_process_group()
@ -327,12 +299,15 @@ def train(epoch, loader, model, dis2den, criterion, optimizer, scheduler,
target = target.to(device, non_blocking=True) target = target.to(device, non_blocking=True)
output = model(input) output = model(input)
if epoch == 0 and i == 0 and rank == 0:
print('input.shape =', input.shape)
print('output.shape =', output.shape)
print('target.shape =', target.shape, flush=True)
target = narrow_like(target, output) # FIXME pad if (hasattr(model.module, 'scale_factor')
if hasattr(model, 'scale_factor') and model.scale_factor != 1: and model.module.scale_factor != 1):
input = F.interpolate(input, input = resample(input, model.module.scale_factor, narrow=False)
scale_factor=model.scale_factor, mode='nearest') input, output, target = narrow_cast(input, output, target)
input = narrow_like(input, output)
output, target = dis2den(output, target) output, target = dis2den(output, target)
@ -340,13 +315,11 @@ def train(epoch, loader, model, dis2den, criterion, optimizer, scheduler,
epoch_loss[0] += loss.item() epoch_loss[0] += loss.item()
if args.adv and epoch >= args.adv_start: if args.adv and epoch >= args.adv_start:
try: noise_std = args.instance_noise.std()
noise_std = args.instance_noise.std(adv_loss)
except NameError:
noise_std = args.instance_noise.std(0)
if noise_std > 0: if noise_std > 0:
noise = noise_std * torch.randn_like(output) noise = noise_std * torch.randn_like(output)
output = output + noise.detach() output = output + noise.detach()
noise = noise_std * torch.randn_like(target)
target = target + noise.detach() target = target + noise.detach()
del noise del noise
@ -405,21 +378,19 @@ def train(epoch, loader, model, dis2den, criterion, optimizer, scheduler,
if '.weight' in n) if '.weight' in n)
grads = [grads[0], grads[-1]] grads = [grads[0], grads[-1]]
grads = [g.detach().norm().item() for g in grads] grads = [g.detach().norm().item() for g in grads]
logger.add_scalars('grad', { logger.add_scalar('grad/first', grads[0], global_step=batch)
'first': grads[0], logger.add_scalar('grad/last', grads[-1], global_step=batch)
'last': grads[-1],
}, global_step=batch)
if args.adv and epoch >= args.adv_start: if args.adv and epoch >= args.adv_start:
grads = list(p.grad for n, p in adv_model.named_parameters() grads = list(p.grad for n, p in adv_model.named_parameters()
if '.weight' in n) if '.weight' in n)
grads = [grads[0], grads[-1]] grads = [grads[0], grads[-1]]
grads = [g.detach().norm().item() for g in grads] grads = [g.detach().norm().item() for g in grads]
logger.add_scalars('grad/adv', { logger.add_scalars('grad/adv/first', grads[0],
'first': grads[0], global_step=batch)
'last': grads[-1], logger.add_scalars('grad/adv/last', grads[-1],
}, global_step=batch) global_step=batch)
if args.adv and epoch >= args.adv_start: if args.adv and epoch >= args.adv_start and noise_std > 0:
logger.add_scalar('instance_noise', noise_std, logger.add_scalar('instance_noise', noise_std,
global_step=batch) global_step=batch)
@ -440,7 +411,7 @@ def train(epoch, loader, model, dis2den, criterion, optimizer, scheduler,
skip_chan = 0 skip_chan = 0
if args.adv and epoch >= args.adv_start and args.cgan: if args.adv and epoch >= args.adv_start and args.cgan:
skip_chan = sum(args.in_chan) skip_chan = sum(args.in_chan)
logger.add_figure('fig/epoch/train', fig3d( logger.add_figure('fig/epoch/train', plt_slices(
input[-1], input[-1],
output[-1, skip_chan:], output[-1, skip_chan:],
target[-1, skip_chan:], target[-1, skip_chan:],
@ -471,11 +442,10 @@ def validate(epoch, loader, model, dis2den, criterion, adv_model, adv_criterion,
output = model(input) output = model(input)
target = narrow_like(target, output) # FIXME pad if (hasattr(model.module, 'scale_factor')
if hasattr(model, 'scale_factor') and model.scale_factor != 1: and model.module.scale_factor != 1):
input = F.interpolate(input, input = resample(input, model.module.scale_factor, narrow=False)
scale_factor=model.scale_factor, mode='nearest') input, output, target = narrow_cast(input, output, target)
input = narrow_like(input, output)
output, target = dis2den(output, target) output, target = dis2den(output, target)
@ -517,7 +487,7 @@ def validate(epoch, loader, model, dis2den, criterion, adv_model, adv_criterion,
skip_chan = 0 skip_chan = 0
if args.adv and epoch >= args.adv_start and args.cgan: if args.adv and epoch >= args.adv_start and args.cgan:
skip_chan = sum(args.in_chan) skip_chan = sum(args.in_chan)
logger.add_figure('fig/epoch/val', fig3d( logger.add_figure('fig/epoch/val', plt_slices(
input[-1], input[-1],
output[-1, skip_chan:], output[-1, skip_chan:],
target[-1, skip_chan:], target[-1, skip_chan:],

View File

@ -38,8 +38,7 @@ srun m2m.py train \
--in-norms cosmology.dis --tgt-norms torch.log1p --augment --crop 128 --pad 20 \ --in-norms cosmology.dis --tgt-norms torch.log1p --augment --crop 128 --pad 20 \
--model UNet \ --model UNet \
--lr 0.0001 --batches 1 --loader-workers 0 \ --lr 0.0001 --batches 1 --loader-workers 0 \
--epochs 1024 --seed $RANDOM \ --epochs 1024 --seed $RANDOM
--cache --div-data
date date

View File

@ -38,8 +38,7 @@ m2m.py test \
--in-norms cosmology.dis --tgt-norms cosmology.dis --crop 256 --pad 20 \ --in-norms cosmology.dis --tgt-norms cosmology.dis --crop 256 --pad 20 \
--model VNet \ --model VNet \
--load-state best_model.pt \ --load-state best_model.pt \
--batches 1 --loader-workers 0 \ --batches 1 --loader-workers 0
--cache
date date

View File

@ -39,8 +39,7 @@ srun m2m.py train \
--in-norms cosmology.dis --tgt-norms cosmology.dis --augment --crop 128 --pad 20 \ --in-norms cosmology.dis --tgt-norms cosmology.dis --augment --crop 128 --pad 20 \
--model VNet --adv-model UNet --cgan \ --model VNet --adv-model UNet --cgan \
--lr 0.0001 --adv-lr 0.0004 --batches 1 --loader-workers 0 \ --lr 0.0001 --adv-lr 0.0004 --batches 1 --loader-workers 0 \
--epochs 1024 --seed $RANDOM \ --epochs 1024 --seed $RANDOM
--cache --div-data
date date

View File

@ -39,10 +39,7 @@ srun m2m.py train \
--in-norms cosmology.dis,cosmology.vel --tgt-norms cosmology.dis,cosmology.vel --augment --crop 88 --pad 20 --scale-factor 2 \ --in-norms cosmology.dis,cosmology.vel --tgt-norms cosmology.dis,cosmology.vel --augment --crop 88 --pad 20 --scale-factor 2 \
--model VNet --adv-model PatchGAN --cgan \ --model VNet --adv-model PatchGAN --cgan \
--lr 0.0001 --adv-lr 0.0004 --batches 1 --loader-workers 0 \ --lr 0.0001 --adv-lr 0.0004 --batches 1 --loader-workers 0 \
--epochs 1024 --seed $RANDOM \ --epochs 1024 --seed $RANDOM
--cache --div-data
# --val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files_1,$data_root_dir/$in_dir/$val_dirs/$in_files_2" \
# --val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files_1,$data_root_dir/$tgt_dir/$val_dirs/$tgt_files_2" \
date date

View File

@ -38,8 +38,7 @@ m2m.py test \
--in-norms cosmology.vel --tgt-norms cosmology.vel --crop 256 --pad 20 \ --in-norms cosmology.vel --tgt-norms cosmology.vel --crop 256 --pad 20 \
--model VNet \ --model VNet \
--load-state best_model.pt \ --load-state best_model.pt \
--batches 1 --loader-workers 0 \ --batches 1 --loader-workers 0
--cache
date date

View File

@ -39,8 +39,7 @@ srun m2m.py train \
--in-norms cosmology.vel --tgt-norms cosmology.vel --augment --crop 128 --pad 20 \ --in-norms cosmology.vel --tgt-norms cosmology.vel --augment --crop 128 --pad 20 \
--model VNet --adv-model UNet --cgan \ --model VNet --adv-model UNet --cgan \
--lr 0.0001 --adv-lr 0.0004 --batches 1 --loader-workers 0 \ --lr 0.0001 --adv-lr 0.0004 --batches 1 --loader-workers 0 \
--epochs 1024 --seed $RANDOM \ --epochs 1024 --seed $RANDOM
--cache --div-data
date date

View File

@ -10,7 +10,7 @@ setup(
packages=find_packages(), packages=find_packages(),
python_requires='>=3.6', python_requires='>=3.6',
install_requires=[ install_requires=[
'torch', 'torch>=1.2',
'numpy', 'numpy',
'scipy', 'scipy',
], ],