Merge branch 'master' into lag2eul
commit 3437b20ed8

README.md
@@ -33,19 +33,22 @@ For all command line options look at `map2map/args.py` or do `m2m.py -h`.
 Put each field in one npy file.
 Structure your data to start with the channel axis and then the spatial
-dimensions.
-For example a 2D vector field of size `64^2` should have shape `(2, 64,
-64)`.
+dimensions, e.g. `(2, 64, 64)` for a 2D vector field of size `64^2` and
+`(1, 32, 32, 32)` for a 3D scalar field of size `32^3`.
 Specify the data path with
 [glob patterns](https://docs.python.org/3/library/glob.html).

 During training, pairs of input and target fields are loaded.
 Both input and target data can consist of multiple fields, which are
 then concatenated along the channel axis.


+#### Data cropping

 If the size of a pair of input and target fields is too large to fit in
-a GPU, we can crop part of them to form pairs of samples (see `--crop`).
-Each field can be cropped multiple times, along each dimension,
-controlled by the spacing between two adjacent crops (see `--step`).
+a GPU, we can crop part of them to form pairs of samples.
+Each field can be cropped multiple times, along each dimension.
+See `--crop`, `--crop-start`, `--crop-stop`, and `--crop-step`.
 The total sample size is the number of input and target pairs multiplied
 by the number of cropped samples per pair.
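The data layout described in the README hunk above (channel axis first, then spatial dimensions, one field per npy file) can be produced with plain NumPy. A minimal sketch; the file names and field sizes here are made up for illustration:

```python
import numpy as np

# a 3D scalar field of size 32^3: channel axis first, shape (1, 32, 32, 32)
scalar = np.random.randn(1, 32, 32, 32).astype(np.float32)
np.save('train/field1_000.npy', scalar)  # hypothetical path

# a 3D vector field of the same size: shape (3, 32, 32, 32)
vector = np.random.randn(3, 32, 32, 32).astype(np.float32)
np.save('train/field2_000.npy', vector)  # hypothetical path
```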
@@ -40,39 +40,41 @@ def add_common_args(parser):
     parser.add_argument('--tgt-norms', type=str_list, help='comma-sep. list '
             'of target normalization functions from .data.norms')
     parser.add_argument('--crop', type=int,
-            help='size to crop the input and target data')
+            help='size to crop the input and target data. Default is the '
+            'field size')
+    parser.add_argument('--crop-start', type=int,
+            help='starting point of the first crop. Default is the origin')
+    parser.add_argument('--crop-stop', type=int,
+            help='stopping point of the last crop. Default is the opposite '
+            'corner to the origin')
+    parser.add_argument('--crop-step', type=int,
+            help='spacing between crops. Default is the crop size')
     parser.add_argument('--pad', default=0, type=int,
             help='size to pad the input data beyond the crop size, assuming '
             'periodic boundary condition')
     parser.add_argument('--scale-factor', default=1, type=int,
-            help='input upsampling factor for super-resolution purpose, in '
-            'which case crop and pad will be taken at the original resolution')
+            help='upsampling factor for super-resolution, in which case '
+            'crop and pad are sizes of the input resolution')

-    parser.add_argument('--model', required=True, type=str,
+    parser.add_argument('--model', type=str, required=True,
             help='model from .models')
     parser.add_argument('--criterion', default='MSELoss', type=str,
             help='model criterion from torch.nn')
     parser.add_argument('--load-state', default=ckpt_link, type=str,
             help='path to load the states of model, optimizer, rng, etc. '
             'Default is the checkpoint. '
-            'Start from scratch if set empty or the checkpoint is missing')
+            'Start from scratch in case of empty string or missing checkpoint')
     parser.add_argument('--load-state-non-strict', action='store_false',
             help='allow incompatible keys when loading model states',
             dest='load_state_strict')

-    parser.add_argument('--batches', default=1, type=int,
+    parser.add_argument('--batches', type=int, required=True,
             help='mini-batch size, per GPU in training or in total in testing')
-    parser.add_argument('--loader-workers', type=int,
-            help='number of data loading workers, per GPU in training or '
-            'in total in testing. '
-            'Default is the batch size or 0 for batch size 1')
+    parser.add_argument('--loader-workers', default=-8, type=int,
+            help='number of subprocesses per data loader. '
+            '0 to disable multiprocessing; '
+            'negative number to multiply by the batch size')

-    parser.add_argument('--cache', action='store_true',
-            help='enable LRU cache of input and target fields to reduce I/O')
-    parser.add_argument('--cache-maxsize', type=int,
-            help='maximum pairs of fields in cache, unlimited by default. '
-            'This only applies to training if not None, '
-            'in which case the testing cache maxsize is 1')
     parser.add_argument('--callback-at', type=lambda s: os.path.abspath(s),
             help='directory of custorm code defining callbacks for models, '
             'norms, criteria, and optimizers. Disabled if not set. '
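A quick illustration of how `--crop`, `--pad`, and `--scale-factor` relate, following the help text above; the numbers are hypothetical:

```python
crop, pad, scale_factor = 64, 20, 2   # made-up values, given at the input resolution

# each input sample is the crop plus symmetric periodic padding
input_size = crop + 2 * pad           # 104 per spatial dimension

# the matching target crop is scaled up for super-resolution
target_size = crop * scale_factor     # 128 per spatial dimension
print(input_size, target_size)
```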
@@ -93,6 +95,10 @@ def add_train_args(parser):
             help='comma-sep. list of glob patterns for validation target data')
     parser.add_argument('--augment', action='store_true',
             help='enable data augmentation of axis flipping and permutation')
+    parser.add_argument('--aug-shift', type=int,
+            help='data augmentation by shifting [0, aug_shift) pixels, '
+            'useful for models that treat neighboring pixels differently, '
+            'e.g. with strided convolutions')
     parser.add_argument('--aug-add', type=float,
             help='additive data augmentation, (normal) std, '
             'same factor for all fields')
@@ -106,8 +112,6 @@ def add_train_args(parser):
             help='enable spectral normalization on the adversary model')
     parser.add_argument('--adv-criterion', default='BCEWithLogitsLoss', type=str,
             help='adversarial criterion from torch.nn')
-    parser.add_argument('--min-reduction', action='store_true',
-            help='enable minimum reduction in adversarial criterion')
     parser.add_argument('--cgan', action='store_true',
             help='enable conditional GAN')
     parser.add_argument('--adv-start', default=0, type=int,
@@ -124,7 +128,7 @@ def add_train_args(parser):

     parser.add_argument('--optimizer', default='Adam', type=str,
             help='optimizer from torch.optim')
-    parser.add_argument('--lr', default=0.001, type=float,
+    parser.add_argument('--lr', type=float, required=True,
             help='initial learning rate')
     # parser.add_argument('--momentum', default=0.9, type=float,
     #         help='momentum')
@@ -143,8 +147,6 @@ def add_train_args(parser):
     parser.add_argument('--seed', default=42, type=int,
             help='seed for initializing training')

-    parser.add_argument('--div-data', action='store_true',
-            help='enable data division among GPUs, useful with cache')
     parser.add_argument('--dist-backend', default='nccl', type=str,
             choices=['gloo', 'nccl'], help='distributed backend')
     parser.add_argument('--log-interval', default=100, type=int,
@@ -175,21 +177,8 @@ def str_list(s):


 def set_common_args(args):
-    if args.loader_workers is None:
-        args.loader_workers = 0
-        if args.batches > 1:
-            args.loader_workers = args.batches
-
-    if not args.cache and args.cache_maxsize is not None:
-        args.cache_maxsize = None
-        warnings.warn('Resetting cache maxsize given cache is disabled',
-                RuntimeWarning)
-    if (args.cache and args.cache_maxsize is not None
-            and args.cache_maxsize < 1):
-        args.cache = False
-        args.cache_maxsize = None
-        warnings.warn('Disabling cache given cache maxsize < 1',
-                RuntimeWarning)
+    if args.loader_workers < 0:
+        args.loader_workers *= - args.batches


 def set_train_args(args):
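The new `--loader-workers` default of `-8` is interpreted as a multiplier on the batch size, as the rewritten `set_common_args` shows. A small sketch of how the value resolves, using a toy namespace rather than the actual parser:

```python
from argparse import Namespace

def resolve_loader_workers(args):
    # negative values are a per-batch multiplier; 0 disables multiprocessing
    if args.loader_workers < 0:
        args.loader_workers *= -args.batches
    return args.loader_workers

print(resolve_loader_workers(Namespace(loader_workers=-8, batches=2)))  # 16
print(resolve_loader_workers(Namespace(loader_workers=0, batches=2)))   # 0
```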
@@ -206,6 +195,11 @@ def set_train_args(args):
     if args.adv_weight_decay is None:
         args.adv_weight_decay = args.weight_decay

+    if args.cgan and not args.adv:
+        args.cgan =False
+        warnings.warn('Disabling cgan given adversary is disabled',
+                RuntimeWarning)
+

 def set_test_args(args):
     set_common_args(args)
@@ -1,5 +1,4 @@
 from glob import glob
-from functools import lru_cache
 import numpy as np
 import torch
 import torch.nn.functional as F
@@ -14,6 +13,7 @@ class FieldDataset(Dataset):

     `in_patterns` is a list of glob patterns for the input field files.
     For example, `in_patterns=['/train/field1_*.npy', '/train/field2_*.npy']`.
+    Each pattern in the list is a new field.
     Likewise `tgt_patterns` is for target fields.
     Input and target fields are matched by sorting the globbed files.
@@ -21,29 +21,29 @@ class FieldDataset(Dataset):
     Likewise for `tgt_norms`.

     Scalar and vector fields can be augmented by flipping and permutating the axes.
-    In 3D these form the full octahedral symmetry known as the Oh point group.
+    In 3D these form the full octahedral symmetry, the Oh group of order 48.
+    In 2D this is the dihedral group D4 of order 8.
+    1D is not supported, but can be done easily by preprocessing.
+    Fields can be augmented by random shift by a few pixels, useful for models
+    that treat neighboring pixels differently, e.g. with strided convolutions.
     Additive and multiplicative augmentation are also possible, but with all fields
     added or multiplied by the same factor.

-    Input and target fields can be cropped.
-    Input fields can be padded assuming periodic boundary condition.
+    Input and target fields can be cropped, to return multiple slices of size
+    `crop` from each field.
+    The crop anchors are controlled by `crop_start`, `crop_stop`, and `crop_step`.
+    Input (but not target) fields can be padded beyond the crop size assuming
+    periodic boundary condition.

     Setting integer `scale_factor` greater than 1 will crop target bigger than
     the input for super-resolution, in which case `crop` and `pad` are sizes of
     the input resolution.

-    `cache` enables LRU cache of the input and target fields, up to `cache_maxsize`
-    pairs (unlimited by default).
-    `div_data` enables data division, to be used with `cache`, so that different
-    fields are cached in different GPU processes.
-    This saves CPU RAM but limits stochasticity.
     """
     def __init__(self, in_patterns, tgt_patterns,
                  in_norms=None, tgt_norms=None, callback_at=None,
-                 augment=False, aug_add=None, aug_mul=None,
-                 crop=None, pad=0, scale_factor=1,
-                 cache=False, cache_maxsize=None, div_data=False,
-                 rank=None, world_size=None):
+                 augment=False, aug_shift=None, aug_add=None, aug_mul=None,
+                 crop=None, crop_start=None, crop_stop=None, crop_step=None,
+                 pad=0, scale_factor=1):
         in_file_lists = [sorted(glob(p)) for p in in_patterns]
         self.in_files = list(zip(* in_file_lists))
@@ -54,12 +54,14 @@ class FieldDataset(Dataset):
                 'number of input and target fields do not match'
         self.nfile = len(self.in_files)

-        assert self.nfile > 0, 'file not found'
+        assert self.nfile > 0, 'file not found for {}'.format(in_patterns)

-        self.in_chan = [np.load(f).shape[0] for f in self.in_files[0]]
-        self.tgt_chan = [np.load(f).shape[0] for f in self.tgt_files[0]]
+        self.in_chan = [np.load(f, mmap_mode='r').shape[0]
+                        for f in self.in_files[0]]
+        self.tgt_chan = [np.load(f, mmap_mode='r').shape[0]
+                         for f in self.tgt_files[0]]

-        self.size = np.load(self.in_files[0][0]).shape[1:]
+        self.size = np.load(self.in_files[0][0], mmap_mode='r').shape[1:]
         self.size = np.asarray(self.size)
         self.ndim = len(self.size)
@@ -80,16 +82,35 @@ class FieldDataset(Dataset):
         self.augment = augment
         if self.ndim == 1 and self.augment:
             raise ValueError('cannot augment 1D fields')
+        self.aug_shift = np.broadcast_to(aug_shift, (self.ndim,))
         self.aug_add = aug_add
         self.aug_mul = aug_mul

         if crop is None:
             self.crop = self.size
-            self.reps = np.ones_like(self.size)
         else:
-            self.crop = np.broadcast_to(crop, self.size.shape)
-            self.reps = self.size // self.crop
-        self.ncrop = int(np.prod(self.reps))
+            self.crop = np.broadcast_to(crop, (self.ndim,))
+
+        if crop_start is None:
+            crop_start = np.zeros_like(self.size)
+        else:
+            crop_start = np.broadcast_to(crop_start, (self.ndim,))
+
+        if crop_stop is None:
+            crop_stop = self.size
+        else:
+            crop_stop = np.broadcast_to(crop_stop, (self.ndim,))
+
+        if crop_step is None:
+            crop_step = self.crop
+        else:
+            crop_step = np.broadcast_to(crop_step, (self.ndim,))
+
+        self.anchors = np.stack(np.mgrid[tuple(
+            slice(crop_start[d], crop_stop[d], crop_step[d])
+            for d in range(self.ndim)
+        )], axis=-1).reshape(-1, self.ndim)
+        self.ncrop = len(self.anchors)

         assert isinstance(pad, int), 'only support symmetric padding for now'
         self.pad = np.broadcast_to(pad, (self.ndim, 2))
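The `np.mgrid` construction above enumerates one anchor (crop lower corner) per crop over a regular grid. A standalone sketch of the same idea, with made-up sizes, to show what the anchors look like:

```python
import numpy as np

size = np.array([8, 8])        # hypothetical 2D field size
crop_start = np.array([0, 0])
crop_stop = size
crop_step = np.array([4, 4])   # here the crop size doubles as the step

# one anchor per crop, laid out on a regular grid
anchors = np.stack(np.mgrid[tuple(
    slice(crop_start[d], crop_stop[d], crop_step[d])
    for d in range(len(size))
)], axis=-1).reshape(-1, len(size))

print(anchors)  # [[0 0] [0 4] [4 0] [4 4]] -> ncrop = 4
```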
@@ -98,52 +119,25 @@ class FieldDataset(Dataset):
                 'only support integer upsampling'
         self.scale_factor = scale_factor

-        if cache:
-            self.get_fields = lru_cache(maxsize=cache_maxsize)(self.get_fields)
-
-        if div_data:
-            self.samples = []
-
-            # first add full fields when num_fields > num_GPU
-            for i in range(rank, self.nfile // world_size * world_size,
-                           world_size):
-                self.samples.append(
-                    range(i * self.ncrop, (i + 1) * self.ncrop)
-                )
-
-            # then split the rest into fractions of fields
-            # drop the last incomplete batch of samples
-            frac_start = self.nfile // world_size * world_size * self.ncrop
-            frac_samples = self.nfile % world_size * self.ncrop // world_size
-            self.samples.append(range(frac_start + rank * frac_samples,
-                                      frac_start + (rank + 1) * frac_samples))
-
-            self.samples = np.concatenate(self.samples)
-        else:
-            self.samples = np.arange(self.nfile * self.ncrop)
-        self.nsample = len(self.samples)
-
-        self.rank = rank
-
-    def get_fields(self, idx):
-        in_fields = [np.load(f) for f in self.in_files[idx]]
-        tgt_fields = [np.load(f) for f in self.tgt_files[idx]]
-        return in_fields, tgt_fields
-
     def __len__(self):
-        return self.nsample
+        return self.nfile * self.ncrop

     def __getitem__(self, idx):
-        idx = self.samples[idx]
+        ifile, icrop = divmod(idx, self.ncrop)

-        in_fields, tgt_fields = self.get_fields(idx // self.ncrop)
+        in_fields = [np.load(f, mmap_mode='r') for f in self.in_files[ifile]]
+        tgt_fields = [np.load(f, mmap_mode='r') for f in self.tgt_files[ifile]]

-        start = np.unravel_index(idx % self.ncrop, self.reps) * self.crop
+        anchor = self.anchors[icrop]

-        in_fields = crop(in_fields, start, self.crop, self.pad)
-        tgt_fields = crop(tgt_fields, start * self.scale_factor,
+        for d, shift in enumerate(self.aug_shift):
+            if shift is not None:
+                anchor[d] += torch.randint(shift, (1,))
+
+        in_fields = crop(in_fields, anchor, self.crop, self.pad, self.size)
+        tgt_fields = crop(tgt_fields, anchor * self.scale_factor,
                           self.crop * self.scale_factor,
-                          np.zeros_like(self.pad))
+                          np.zeros_like(self.pad), self.size)

         in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]
         tgt_fields = [torch.from_numpy(f).to(torch.float32) for f in tgt_fields]
@@ -176,12 +170,21 @@ class FieldDataset(Dataset):
         return in_fields, tgt_fields


-def crop(fields, start, crop, pad):
+def crop(fields, anchor, crop, pad, size):
+    ndim = len(size)
+    assert all(len(x) == ndim for x in [anchor, crop, pad, size]), 'inconsistent ndim'
+
     new_fields = []
     for x in fields:
-        for d, (i, c, (p0, p1)) in enumerate(zip(start, crop, pad)):
-            begin, end = i - p0, i + c + p1
-            x = x.take(range(begin, end), axis=1 + d, mode='wrap')
+        ind = [slice(None)]
+        for d, (a, c, (p0, p1), s) in enumerate(zip(anchor, crop, pad, size)):
+            i = np.arange(a - p0, a + c + p1)
+            i %= s
+            i = i.reshape((-1,) + (1,) * (ndim - d - 1))
+            ind.append(i)
+
+        x = x[tuple(ind)]
+        x.setflags(write=True)  # workaround numpy bug before 1.18
+
         new_fields.append(x)
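The rewritten `crop` builds an open index grid per dimension and wraps it with a modulo, which is what realizes the periodic padding. A self-contained sketch of the same indexing trick on a toy 2-channel 2D array (all values arbitrary):

```python
import numpy as np

x = np.arange(2 * 4 * 4).reshape(2, 4, 4)   # (channel, 4, 4) toy field
anchor, crop, pad, size = (3, 3), (2, 2), ((1, 1), (1, 1)), (4, 4)

ind = [slice(None)]  # keep the channel axis whole
for d, (a, c, (p0, p1), s) in enumerate(zip(anchor, crop, pad, size)):
    i = np.arange(a - p0, a + c + p1) % s    # periodic wrap, here [2 3 0 1]
    ind.append(i.reshape((-1,) + (1,) * (len(size) - d - 1)))

patch = x[tuple(ind)]
print(patch.shape)  # (2, 4, 4): crop 2 plus padding 1+1 per spatial dimension
```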
@@ -9,14 +9,18 @@ from matplotlib.colors import Normalize, LogNorm, SymLogNorm
 from matplotlib.cm import ScalarMappable


-def fig3d(*fields, size=64, title=None, cmap=None, norm=None):
+def plt_slices(*fields, size=64, title=None, cmap=None, norm=None):
+    """Plot slices of fields of more than 2 spatial dimensions.
+    """
     fields = [field.detach().cpu().numpy() if isinstance(field, torch.Tensor)
               else field for field in fields]

     assert all(isinstance(field, np.ndarray) for field in fields)
+    assert all(field.ndim == fields[0].ndim for field in fields)

     nc = max(field.shape[0] for field in fields)
     nf = len(fields)
+    nd = fields[0].ndim - 1

     if title is not None:
         assert len(title) == nf
@@ -73,8 +77,8 @@ def fig3d(*fields, size=64, title=None, cmap=None, norm=None):
             norm_ = norm

         for c in range(field.shape[0]):
-            axes[c, f].pcolormesh(field[c, 0, :size, :size],
-                                  cmap=cmap_, norm=norm_)
+            s = (c,) + (0,) * (nd - 2) + (slice(64),) * 2
+            axes[c, f].pcolormesh(field[s], cmap=cmap_, norm=norm_)

             axes[c, f].set_aspect('equal')
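The index tuple `s` picks one channel, fixes all but the last two spatial axes at zero, and keeps a 2D window, so the same plotting code handles both 2D and 3D fields. A small illustration with an assumed 3D field and window size:

```python
import numpy as np

field = np.zeros((3, 16, 16, 16))    # (channel, z, y, x), hypothetical shape
c, nd, size = 1, field.ndim - 1, 8   # made-up channel and window size

s = (c,) + (0,) * (nd - 2) + (slice(size),) * 2
print(field[s].shape)  # (8, 8): channel 1, first slice along z, an 8x8 window
```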
@@ -3,7 +3,10 @@ from .vnet import VNet, VNetFat
 from .pyramid import PyramidNet
 from .patchgan import PatchGAN, PatchGAN42

-from .conv import narrow_like
+from .narrow import narrow_by, narrow_cast, narrow_like
+from .resample import resample, Resampler
+
+from .lag2eul import Lag2Eul

 from .lag2eul import Lag2Eul
@@ -19,7 +19,6 @@ def adv_criterion_wrapper(module):
     """Wrap an adversarial criterion to:
     * also take lists of Tensors as target, used to split the input Tensor
       along the batch dimension
-    * enable min reduction on input
     * expand target shape as that of input
     * return a list of losses, one for each pair of input and target Tensors
     """
@@ -34,19 +33,10 @@ def adv_criterion_wrapper(module):
             input = self.split_input(input, target)
             assert len(input) == len(target)

-            if self.reduction == 'min':
-                input = [torch.min(i).unsqueeze(0) for i in input]
-
             target = [t.expand_as(i) for i, t in zip(input, target)]

-            if self.reduction == 'min':
-                self.reduction = 'mean'  # average over batches
-                loss = [super(new_module, self).forward(i, t)
-                        for i, t in zip(input, target)]
-                self.reduction = 'min'
-            else:
-                loss = [super(new_module, self).forward(i, t)
-                        for i, t in zip(input, target)]
+            loss = [super(new_module, self).forward(i, t)
+                    for i, t in zip(input, target)]

             return loss
@@ -1,6 +1,7 @@
 import torch
 import torch.nn as nn

+from .narrow import narrow_like
 from .swish import Swish

@@ -114,16 +115,3 @@ class ResBlock(ConvBlock):
         x = self.act(x)

         return x
-
-
-def narrow_like(a, b):
-    """Narrow a to be like b.
-
-    Try to be symmetric but cut more on the right for odd difference,
-    consistent with the downsampling.
-    """
-    for d in range(2, a.dim()):
-        width = a.shape[d] - b.shape[d]
-        half_width = width // 2
-        a = a.narrow(d, half_width, a.shape[d] - width)
-    return a
@@ -1,17 +1,14 @@
-from math import log
 import torch


 class InstanceNoise:
-    """Instance noise, with a heuristic annealing schedule
+    """Instance noise, with a linear decaying schedule
     """
     def __init__(self, init_std, batches):
+        assert init_std >= 0, 'Noise std cannot be negative'
         self.init_std = init_std
-        self.anneal = 1
-        self.ln2 = log(2)
+        self._std = init_std
         self.batches = batches

-    def std(self, adv_loss):
-        self.anneal -= adv_loss / self.ln2 / self.batches
-        std = self.anneal * self.init_std
-        std = std if std > 0 else 0
-        return std
+    def std(self):
+        self._std -= self.init_std / self.batches
+        return max(self._std, 0)
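The schedule now decays linearly from `init_std` to zero over `batches` calls, independent of the adversarial loss. A short usage sketch with made-up numbers, assuming the class above is importable from `map2map.models`:

```python
from map2map.models import InstanceNoise

noise = InstanceNoise(init_std=1.0, batches=4)

# each call steps the std down by init_std / batches, floored at 0
print([round(noise.std(), 2) for _ in range(6)])  # [0.75, 0.5, 0.25, 0.0, 0.0, 0.0]
```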
map2map/models/narrow.py (new file, 43 lines)
@@ -0,0 +1,43 @@
+import torch
+import torch.nn as nn
+
+
+def narrow_by(a, c):
+    """Narrow a by size c symmetrically on all edges.
+    """
+    for d in range(2, a.dim()):
+        a = a.narrow(d, c, a.shape[d] - 2 * c)
+    return a
+
+
+def narrow_cast(*tensors):
+    """Narrow each tensor to the minimum length in each dimension.
+
+    Try to be symmetric but cut more on the right for odd difference
+    """
+    dim_max = max(a.dim() for a in tensors)
+
+    len_min = {d: min(a.shape[d] for a in tensors) for d in range(2, dim_max)}
+
+    casted_tensors = []
+    for a in tensors:
+        for d in range(2, dim_max):
+            width = a.shape[d] - len_min[d]
+            half_width = width // 2
+            a = a.narrow(d, half_width, a.shape[d] - width)
+
+        casted_tensors.append(a)
+
+    return casted_tensors
+
+
+def narrow_like(a, b):
+    """Narrow a to be like b.
+
+    Try to be symmetric but cut more on the right for odd difference
+    """
+    for d in range(2, a.dim()):
+        width = a.shape[d] - b.shape[d]
+        half_width = width // 2
+        a = a.narrow(d, half_width, a.shape[d] - width)
+    return a
map2map/models/resample.py (new file, 46 lines)
@@ -0,0 +1,46 @@
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .narrow import narrow_by
+
+
+def resample(x, scale_factor, narrow=True):
+    modes = {1: 'linear', 2: 'bilinear', 3: 'trilinear'}
+    ndim = x.dim() - 2
+    mode = modes[ndim]
+
+    x = F.interpolate(x, scale_factor=scale_factor,
+                      mode=mode, align_corners=False)
+
+    if scale_factor > 1 and narrow == True:
+        edges = round(scale_factor) // 2
+        edges = max(edges, 1)
+        x = narrow_by(x, edges)
+
+    return x
+
+
+class Resampler(nn.Module):
+    """Resampling, upsampling or downsampling.
+
+    By default discard the inaccurate edges when upsampling.
+    """
+    def __init__(self, ndim, scale_factor, narrow=True):
+        super().__init__()
+
+        modes = {1: 'linear', 2: 'bilinear', 3: 'trilinear'}
+        self.mode = modes[ndim]
+
+        self.scale_factor = scale_factor
+        self.narrow = narrow
+
+    def forward(self, x):
+        x = F.interpolate(x, scale_factor=self.scale_factor,
+                          mode=self.mode, align_corners=False)
+
+        if self.scale_factor > 1 and self.narrow == True:
+            edges = round(self.scale_factor) // 2
+            edges = max(edges, 1)
+            x = narrow_by(x, edges)
+
+        return x
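A usage sketch for the new `resample`/`Resampler`, on a random tensor; with the default `narrow=True`, one inaccurate edge pixel per side (at the output resolution) is discarded after 2x upsampling. Shapes are arbitrary and the `map2map` import path is assumed:

```python
import torch
from map2map.models import resample, Resampler

x = torch.randn(1, 3, 16, 16, 16)      # (batch, channel, z, y, x)

up = resample(x, 2, narrow=False)
print(up.shape)                        # torch.Size([1, 3, 32, 32, 32])

up_trimmed = Resampler(ndim=3, scale_factor=2)(x)
print(up_trimmed.shape)                # torch.Size([1, 3, 30, 30, 30])
```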
@@ -1,7 +1,8 @@
 import torch
 import torch.nn as nn

-from .conv import ConvBlock, narrow_like
+from .conv import ConvBlock
+from .narrow import narrow_like


 class UNet(nn.Module):
@@ -1,7 +1,8 @@
 import torch
 import torch.nn as nn

-from .conv import ConvBlock, ResBlock, narrow_like
+from .conv import ConvBlock, ResBlock
+from .narrow import narrow_like


 class VNet(nn.Module):
@@ -1,12 +1,12 @@
+import sys
 from pprint import pprint
 import numpy as np
 import torch
-import sys
 from torch.utils.data import DataLoader

 from .data import FieldDataset
 from . import models
-from .models import narrow_like
+from .models import narrow_cast
 from .utils import import_attr, load_model_state_dict

@@ -21,12 +21,15 @@ def test(args):
         tgt_norms=args.tgt_norms,
         callback_at=args.callback_at,
         augment=False,
+        aug_shift=None,
         aug_add=None,
         aug_mul=None,
         crop=args.crop,
+        crop_start=args.crop_start,
+        crop_stop=args.crop_stop,
+        crop_step=args.crop_step,
         pad=args.pad,
         scale_factor=args.scale_factor,
-        cache=args.cache,
     )
     test_loader = DataLoader(
         test_dataset,
@@ -54,12 +57,7 @@ def test(args):
     with torch.no_grad():
         for i, (input, target) in enumerate(test_loader):
             output = model(input)
-            if args.pad > 0:  # FIXME
-                output = narrow_like(output, target)
-                input = narrow_like(input, target)
-            else:
-                target = narrow_like(target, output)
-                input = narrow_like(input, output)
+            input, output, target = narrow_cast(input, output, target)

             loss = criterion(output, target)
map2map/train.py
@@ -15,9 +15,9 @@ from torch.utils.data import DataLoader
 from torch.utils.tensorboard import SummaryWriter

 from .data import FieldDataset, GroupedRandomSampler
-from .data.figures import fig3d
+from .data.figures import plt_slices
 from . import models
-from .models import (narrow_like, Lag2Eul,
+from .models import (narrow_like, resample, Lag2Eul,
                      adv_model_wrapper, adv_criterion_wrapper,
                      add_spectral_norm, rm_spectral_norm,
                      InstanceNoise)
@@ -65,28 +65,17 @@ def gpu_worker(local_rank, node, args):
         tgt_norms=args.tgt_norms,
         callback_at=args.callback_at,
         augment=args.augment,
+        aug_shift=args.aug_shift,
         aug_add=args.aug_add,
         aug_mul=args.aug_mul,
         crop=args.crop,
+        crop_start=args.crop_start,
+        crop_stop=args.crop_stop,
+        crop_step=args.crop_step,
         pad=args.pad,
         scale_factor=args.scale_factor,
-        cache=args.cache,
-        cache_maxsize=args.cache_maxsize,
-        div_data=args.div_data,
-        rank=rank,
-        world_size=args.world_size,
     )
-    if args.div_data:
-        train_sampler = GroupedRandomSampler(
-            train_dataset,
-            group_size=None if args.cache_maxsize is None else
-                       args.cache_maxsize * train_dataset.ncrop,
-        )
-    else:
-        try:
-            train_sampler = DistributedSampler(train_dataset, shuffle=True)
-        except TypeError:
-            train_sampler = DistributedSampler(train_dataset)  # old pytorch
+    train_sampler = DistributedSampler(train_dataset, shuffle=True)
     train_loader = DataLoader(
         train_dataset,
         batch_size=args.batches,
@@ -104,24 +93,17 @@ def gpu_worker(local_rank, node, args):
         tgt_norms=args.tgt_norms,
         callback_at=args.callback_at,
         augment=False,
+        aug_shift=None,
         aug_add=None,
         aug_mul=None,
         crop=args.crop,
+        crop_start=args.crop_start,
+        crop_stop=args.crop_stop,
+        crop_step=args.crop_step,
         pad=args.pad,
         scale_factor=args.scale_factor,
-        cache=args.cache,
-        cache_maxsize=None if args.cache_maxsize is None else 1,
-        div_data=args.div_data,
-        rank=rank,
-        world_size=args.world_size,
     )
-    if args.div_data:
-        val_sampler = None
-    else:
-        try:
-            val_sampler = DistributedSampler(val_dataset, shuffle=False)
-        except TypeError:
-            val_sampler = DistributedSampler(val_dataset)  # old pytorch
+    val_sampler = DistributedSampler(val_dataset, shuffle=False)
     val_loader = DataLoader(
         val_dataset,
         batch_size=args.batches,
@@ -170,7 +152,7 @@ def gpu_worker(local_rank, node, args):

     adv_criterion = import_attr(args.adv_criterion, nn.__name__, args.callback_at)
     adv_criterion = adv_criterion_wrapper(adv_criterion)
-    adv_criterion = adv_criterion(reduction='min' if args.min_reduction else 'mean')
+    adv_criterion = adv_criterion()
     adv_criterion.to(device)

     adv_optimizer = import_attr(args.optimizer, optim.__name__, args.callback_at)
@@ -246,8 +228,7 @@ def gpu_worker(local_rank, node, args):
                                        args.instance_noise_batches)

    for epoch in range(start_epoch, args.epochs):
-        if not args.div_data:
-            train_sampler.set_epoch(epoch)
+        train_sampler.set_epoch(epoch)

        train_loss = train(epoch, train_loader,
            model, dis2den, criterion, optimizer, scheduler,
@@ -267,10 +248,7 @@ def gpu_worker(local_rank, node, args):
            adv_scheduler.step(epoch_loss[0])

        if rank == 0:
-            try:
-                logger.flush()
-            except AttributeError:
-                logger.close()  # old pytorch
+            logger.flush()

            if ((min_loss is None or epoch_loss[0] < min_loss[0])
                    and epoch >= args.adv_start):
@@ -293,12 +271,6 @@ def gpu_worker(local_rank, node, args):
                os.symlink(state_file, tmp_link)  # workaround to overwrite
                os.rename(tmp_link, ckpt_link)

-    if args.cache:
-        print('rank {} train data: {}'.format(
-            rank, train_dataset.get_fields.cache_info()))
-        print('rank {} val data: {}'.format(
-            rank, val_dataset.get_fields.cache_info()))
-
    dist.destroy_process_group()

@@ -327,12 +299,15 @@ def train(epoch, loader, model, dis2den, criterion, optimizer, scheduler,
        target = target.to(device, non_blocking=True)

        output = model(input)
+        if epoch == 0 and i == 0 and rank == 0:
+            print('input.shape =', input.shape)
+            print('output.shape =', output.shape)
+            print('target.shape =', target.shape, flush=True)

-        target = narrow_like(target, output)  # FIXME pad
-        if hasattr(model, 'scale_factor') and model.scale_factor != 1:
-            input = F.interpolate(input,
-                scale_factor=model.scale_factor, mode='nearest')
-        input = narrow_like(input, output)
+        if (hasattr(model.module, 'scale_factor')
+                and model.module.scale_factor != 1):
+            input = resample(input, model.module.scale_factor, narrow=False)
+        input, output, target = narrow_cast(input, output, target)

        output, target = dis2den(output, target)

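The training and validation loops now share one alignment pattern: bring the input up to the output resolution with `resample(..., narrow=False)`, then trim input, output, and target to a common size with `narrow_cast` before the loss. A hedged sketch of that flow with dummy tensors and a made-up scale factor, assuming the `map2map` package is importable:

```python
import torch
from map2map.models import resample, narrow_cast

scale_factor = 2                           # hypothetical super-resolution factor
input = torch.randn(1, 3, 20, 20, 20)
output = torch.randn(1, 3, 36, 36, 36)     # pretend model output, edges consumed
target = torch.randn(1, 3, 40, 40, 40)

input = resample(input, scale_factor, narrow=False)   # -> (1, 3, 40, 40, 40)
input, output, target = narrow_cast(input, output, target)
print(input.shape, output.shape, target.shape)        # all (1, 3, 36, 36, 36)
```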
@@ -340,13 +315,11 @@ def train(epoch, loader, model, dis2den, criterion, optimizer, scheduler,
        epoch_loss[0] += loss.item()

        if args.adv and epoch >= args.adv_start:
-            try:
-                noise_std = args.instance_noise.std(adv_loss)
-            except NameError:
-                noise_std = args.instance_noise.std(0)
+            noise_std = args.instance_noise.std()
            if noise_std > 0:
                noise = noise_std * torch.randn_like(output)
                output = output + noise.detach()
+                noise = noise_std * torch.randn_like(target)
                target = target + noise.detach()
                del noise

@@ -405,21 +378,19 @@ def train(epoch, loader, model, dis2den, criterion, optimizer, scheduler,
                             if '.weight' in n)
                grads = [grads[0], grads[-1]]
                grads = [g.detach().norm().item() for g in grads]
-                logger.add_scalars('grad', {
-                    'first': grads[0],
-                    'last': grads[-1],
-                }, global_step=batch)
+                logger.add_scalar('grad/first', grads[0], global_step=batch)
+                logger.add_scalar('grad/last', grads[-1], global_step=batch)
                if args.adv and epoch >= args.adv_start:
                    grads = list(p.grad for n, p in adv_model.named_parameters()
                                 if '.weight' in n)
                    grads = [grads[0], grads[-1]]
                    grads = [g.detach().norm().item() for g in grads]
-                    logger.add_scalars('grad/adv', {
-                        'first': grads[0],
-                        'last': grads[-1],
-                    }, global_step=batch)
+                    logger.add_scalars('grad/adv/first', grads[0],
+                                       global_step=batch)
+                    logger.add_scalars('grad/adv/last', grads[-1],
+                                       global_step=batch)

-                if args.adv and epoch >= args.adv_start:
+                if args.adv and epoch >= args.adv_start and noise_std > 0:
                    logger.add_scalar('instance_noise', noise_std,
                                      global_step=batch)
@@ -440,7 +411,7 @@ def train(epoch, loader, model, dis2den, criterion, optimizer, scheduler,
                skip_chan = 0
                if args.adv and epoch >= args.adv_start and args.cgan:
                    skip_chan = sum(args.in_chan)
-                logger.add_figure('fig/epoch/train', fig3d(
+                logger.add_figure('fig/epoch/train', plt_slices(
                    input[-1],
                    output[-1, skip_chan:],
                    target[-1, skip_chan:],
@@ -471,11 +442,10 @@ def validate(epoch, loader, model, dis2den, criterion, adv_model, adv_criterion,

            output = model(input)

-            target = narrow_like(target, output)  # FIXME pad
-            if hasattr(model, 'scale_factor') and model.scale_factor != 1:
-                input = F.interpolate(input,
-                    scale_factor=model.scale_factor, mode='nearest')
-            input = narrow_like(input, output)
+            if (hasattr(model.module, 'scale_factor')
+                    and model.module.scale_factor != 1):
+                input = resample(input, model.module.scale_factor, narrow=False)
+            input, output, target = narrow_cast(input, output, target)

            output, target = dis2den(output, target)

@@ -517,7 +487,7 @@ def validate(epoch, loader, model, dis2den, criterion, adv_model, adv_criterion,
                skip_chan = 0
                if args.adv and epoch >= args.adv_start and args.cgan:
                    skip_chan = sum(args.in_chan)
-                logger.add_figure('fig/epoch/val', fig3d(
+                logger.add_figure('fig/epoch/val', plt_slices(
                    input[-1],
                    output[-1, skip_chan:],
                    target[-1, skip_chan:],
@@ -38,8 +38,7 @@ srun m2m.py train \
    --in-norms cosmology.dis --tgt-norms torch.log1p --augment --crop 128 --pad 20 \
    --model UNet \
    --lr 0.0001 --batches 1 --loader-workers 0 \
-    --epochs 1024 --seed $RANDOM \
-    --cache --div-data
+    --epochs 1024 --seed $RANDOM


 date
@@ -38,8 +38,7 @@ m2m.py test \
    --in-norms cosmology.dis --tgt-norms cosmology.dis --crop 256 --pad 20 \
    --model VNet \
    --load-state best_model.pt \
-    --batches 1 --loader-workers 0 \
-    --cache
+    --batches 1 --loader-workers 0


 date
@@ -39,8 +39,7 @@ srun m2m.py train \
    --in-norms cosmology.dis --tgt-norms cosmology.dis --augment --crop 128 --pad 20 \
    --model VNet --adv-model UNet --cgan \
    --lr 0.0001 --adv-lr 0.0004 --batches 1 --loader-workers 0 \
-    --epochs 1024 --seed $RANDOM \
-    --cache --div-data
+    --epochs 1024 --seed $RANDOM


 date
@@ -39,10 +39,7 @@ srun m2m.py train \
    --in-norms cosmology.dis,cosmology.vel --tgt-norms cosmology.dis,cosmology.vel --augment --crop 88 --pad 20 --scale-factor 2 \
    --model VNet --adv-model PatchGAN --cgan \
    --lr 0.0001 --adv-lr 0.0004 --batches 1 --loader-workers 0 \
-    --epochs 1024 --seed $RANDOM \
-    --cache --div-data
-    # --val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files_1,$data_root_dir/$in_dir/$val_dirs/$in_files_2" \
-    # --val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files_1,$data_root_dir/$tgt_dir/$val_dirs/$tgt_files_2" \
+    --epochs 1024 --seed $RANDOM


 date
@@ -38,8 +38,7 @@ m2m.py test \
    --in-norms cosmology.vel --tgt-norms cosmology.vel --crop 256 --pad 20 \
    --model VNet \
    --load-state best_model.pt \
-    --batches 1 --loader-workers 0 \
-    --cache
+    --batches 1 --loader-workers 0


 date
@@ -39,8 +39,7 @@ srun m2m.py train \
    --in-norms cosmology.vel --tgt-norms cosmology.vel --augment --crop 128 --pad 20 \
    --model VNet --adv-model UNet --cgan \
    --lr 0.0001 --adv-lr 0.0004 --batches 1 --loader-workers 0 \
-    --epochs 1024 --seed $RANDOM \
-    --cache --div-data
+    --epochs 1024 --seed $RANDOM


 date