Add cropping anchors controlled by start, stop, step

Yin Li 2020-05-04 23:30:59 -04:00
parent 819e77cd86
commit 8a95d69818
5 changed files with 66 additions and 24 deletions

View File

@@ -33,19 +33,22 @@ For all command line options look at `map2map/args.py` or do `m2m.py -h`.
 Put each field in one npy file.
-Structure your data to start with the channel axis and then the spatial
-dimensions.
-For example a 2D vector field of size `64^2` should have shape `(2, 64,
-64)`.
+Structure your data to start with the channel axis and then the spatial
+dimensions, e.g. `(2, 64, 64)` for a 2D vector field of size `64^2` and
+`(1, 32, 32, 32)` for a 3D scalar field of size `32^3`.
 Specify the data path with
 [glob patterns](https://docs.python.org/3/library/glob.html).
 During training, pairs of input and target fields are loaded.
 Both input and target data can consist of multiple fields, which are
 then concatenated along the channel axis.
 
+#### Data cropping
+
 If the size of a pair of input and target fields is too large to fit in
-a GPU, we can crop part of them to form pairs of samples (see `--crop`).
-Each field can be cropped multiple times, along each dimension,
-controlled by the spacing between two adjacent crops (see `--step`).
+a GPU, we can crop part of them to form pairs of samples.
+Each field can be cropped multiple times, along each dimension.
+See `--crop`, `--crop-start`, `--crop-stop`, and `--crop-step`.
 The total sample size is the number of input and target pairs multiplied
 by the number of cropped samples per pair.
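
To make the sample-count arithmetic concrete, here is a small sketch (not part of the commit; the sizes are made up) computing the number of crops per field pair under the new start/stop/step scheme:

```python
import numpy as np

# Made-up example: a 3D field of size 128^3 cropped to 32^3, using the
# defaults described above: start at the origin, stop at the opposite
# corner, step equal to the crop size.
size = np.array([128, 128, 128])
crop = np.array([32, 32, 32])
start = np.zeros_like(size)   # default --crop-start
stop = size                   # default --crop-stop
step = crop                   # default --crop-step

# One anchor per step in [start, stop) along each dimension,
# exactly like Python's range(start, stop, step).
ncrop_per_dim = [len(range(start[d], stop[d], step[d])) for d in range(3)]
ncrop = int(np.prod(ncrop_per_dim))  # 4 * 4 * 4

print(ncrop)  # 64 crops per pair; total samples = number of pairs * 64
```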

View File

@@ -41,14 +41,21 @@ def add_common_args(parser):
             'of target normalization functions from .data.norms')
     parser.add_argument('--crop', type=int,
             help='size to crop the input and target data')
+    parser.add_argument('--crop-start', type=int,
+            help='starting point of the first crop. Default is the origin')
+    parser.add_argument('--crop-stop', type=int,
+            help='stopping point of the last crop. Default is the corner '
+                'opposite to the origin')
+    parser.add_argument('--crop-step', type=int,
+            help='spacing between crops. Default is the crop size')
     parser.add_argument('--pad', default=0, type=int,
             help='size to pad the input data beyond the crop size, assuming '
                 'periodic boundary condition')
     parser.add_argument('--scale-factor', default=1, type=int,
-            help='input upsampling factor for super-resolution purpose, in '
-                'which case crop and pad will be taken at the original resolution')
-    parser.add_argument('--model', required=True, type=str,
+            help='upsampling factor for super-resolution, in which case '
+                'crop and pad are sizes of the input resolution')
+    parser.add_argument('--model', type=str, required=True,
             help='model from .models')
     parser.add_argument('--criterion', default='MSELoss', type=str,
             help='model criterion from torch.nn')
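
As a quick check of how the new flags parse, a minimal standalone sketch (option names match the diff above; the sample values are made up):

```python
import argparse

# Minimal stand-in for add_common_args, reduced to the new crop options.
parser = argparse.ArgumentParser()
parser.add_argument('--crop', type=int)
parser.add_argument('--crop-start', type=int)
parser.add_argument('--crop-stop', type=int)
parser.add_argument('--crop-step', type=int)

# argparse turns '--crop-start' into args.crop_start, which is how
# test.py and train.py below access these values.
args = parser.parse_args(
    '--crop 32 --crop-start 0 --crop-stop 128 --crop-step 16'.split())
print(args.crop, args.crop_start, args.crop_stop, args.crop_step)
# 32 0 128 16
```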
@@ -124,7 +131,7 @@ def add_train_args(parser):
     parser.add_argument('--optimizer', default='Adam', type=str,
             help='optimizer from torch.optim')
-    parser.add_argument('--lr', default=0.001, type=float,
+    parser.add_argument('--lr', type=float, required=True,
             help='initial learning rate')
     # parser.add_argument('--momentum', default=0.9, type=float,
     #         help='momentum')

View File

@@ -14,6 +14,7 @@ class FieldDataset(Dataset):
     `in_patterns` is a list of glob patterns for the input field files.
     For example, `in_patterns=['/train/field1_*.npy', '/train/field2_*.npy']`.
+    Each pattern in the list is a new field.
     Likewise `tgt_patterns` is for target fields.
     Input and target fields are matched by sorting the globbed files.
@@ -25,8 +26,11 @@
     Additive and multiplicative augmentation are also possible, but with all fields
     added or multiplied by the same factor.
 
-    Input and target fields can be cropped.
-    Input fields can be padded assuming periodic boundary condition.
+    Input and target fields can be cropped, to return multiple slices of size
+    `crop` from each field.
+    The crop anchors are controlled by `crop_start`, `crop_stop`, and `crop_step`.
+    Input (but not target) fields can be padded beyond the crop size assuming
+    periodic boundary condition.
 
     Setting integer `scale_factor` greater than 1 will crop target bigger than
     the input for super-resolution, in which case `crop` and `pad` are sizes of
@@ -41,7 +45,8 @@
     def __init__(self, in_patterns, tgt_patterns,
                  in_norms=None, tgt_norms=None, callback_at=None,
                  augment=False, aug_add=None, aug_mul=None,
-                 crop=None, pad=0, scale_factor=1,
+                 crop=None, crop_start=None, crop_stop=None, crop_step=None,
+                 pad=0, scale_factor=1,
                  cache=False, cache_maxsize=None, div_data=False,
                  rank=None, world_size=None):
         in_file_lists = [sorted(glob(p)) for p in in_patterns]
@@ -54,7 +59,7 @@
                 'number of input and target fields do not match'
         self.nfile = len(self.in_files)
 
-        assert self.nfile > 0, 'file not found'
+        assert self.nfile > 0, 'file not found for {}'.format(in_patterns)
 
         self.in_chan = [np.load(f).shape[0] for f in self.in_files[0]]
         self.tgt_chan = [np.load(f).shape[0] for f in self.tgt_files[0]]
@@ -85,11 +90,29 @@
         if crop is None:
             self.crop = self.size
-            self.reps = np.ones_like(self.size)
         else:
-            self.crop = np.broadcast_to(crop, self.size.shape)
-            self.reps = self.size // self.crop
-        self.ncrop = int(np.prod(self.reps))
+            self.crop = np.broadcast_to(crop, (self.ndim,))
+
+        if crop_start is None:
+            crop_start = np.zeros_like(self.size)
+        else:
+            crop_start = np.broadcast_to(crop_start, (self.ndim,))
+
+        if crop_stop is None:
+            crop_stop = self.size
+        else:
+            crop_stop = np.broadcast_to(crop_stop, (self.ndim,))
+
+        if crop_step is None:
+            crop_step = self.crop
+        else:
+            crop_step = np.broadcast_to(crop_step, (self.ndim,))
+
+        self.anchors = np.stack(np.mgrid[tuple(
+            slice(crop_start[d], crop_stop[d], crop_step[d])
+            for d in range(self.ndim)
+        )], axis=-1).reshape(-1, self.ndim)
+        self.ncrop = len(self.anchors)
 
         assert isinstance(pad, int), 'only support symmetric padding for now'
         self.pad = np.broadcast_to(pad, (self.ndim, 2))
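
To see what the `np.mgrid` expression produces, a small standalone sketch using the same construction as the added code, with made-up 2D numbers:

```python
import numpy as np

# Made-up 2D example: an 8x8 field, 4x4 crops, step 4, anchors from the origin.
ndim = 2
crop_start = np.array([0, 0])
crop_stop = np.array([8, 8])
crop_step = np.array([4, 4])

# A grid of crop anchors: one row per anchor, one column per dimension.
anchors = np.stack(np.mgrid[tuple(
    slice(crop_start[d], crop_stop[d], crop_step[d])
    for d in range(ndim)
)], axis=-1).reshape(-1, ndim)

print(anchors)
# [[0 0]
#  [0 4]
#  [4 0]
#  [4 4]]
print(len(anchors))  # ncrop = 4
```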
@@ -138,10 +161,10 @@
         in_fields, tgt_fields = self.get_fields(idx // self.ncrop)
 
-        start = np.unravel_index(idx % self.ncrop, self.reps) * self.crop
+        anchor = self.anchors[idx % self.ncrop]
 
-        in_fields = crop(in_fields, start, self.crop, self.pad)
-        tgt_fields = crop(tgt_fields, start * self.scale_factor,
+        in_fields = crop(in_fields, anchor, self.crop, self.pad)
+        tgt_fields = crop(tgt_fields, anchor * self.scale_factor,
                           self.crop * self.scale_factor,
                           np.zeros_like(self.pad))
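
The flat dataset index thus encodes both the file pair and the crop anchor; a tiny sketch of the mapping (hypothetical counts, with ncrop = 4 as in the example above):

```python
# Hypothetical: 3 file pairs, ncrop = 4 anchors each -> 12 samples total.
ncrop = 4
for idx in (0, 5, 11):
    pair, anchor_id = idx // ncrop, idx % ncrop
    print(idx, '-> pair', pair, 'anchor', anchor_id)
# 0 -> pair 0 anchor 0
# 5 -> pair 1 anchor 1
# 11 -> pair 2 anchor 3
```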
@@ -176,11 +199,11 @@
         return in_fields, tgt_fields
 
 
-def crop(fields, start, crop, pad):
+def crop(fields, anchor, crop, pad):
     new_fields = []
     for x in fields:
-        for d, (i, c, (p0, p1)) in enumerate(zip(start, crop, pad)):
-            begin, end = i - p0, i + c + p1
+        for d, (a, c, (p0, p1)) in enumerate(zip(anchor, crop, pad)):
+            begin, end = a - p0, a + c + p1
             x = x.take(range(begin, end), axis=1 + d, mode='wrap')
         new_fields.append(x)
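
The `mode='wrap'` in `np.take` is what implements the periodic padding: indices past either edge wrap around to the other side. A small sketch of that behavior on a toy 1D channel (values made up):

```python
import numpy as np

# Toy field with 1 channel and 8 spatial points; axis 0 is the channel axis.
x = np.arange(8).reshape(1, 8)

# Crop of size 4 at anchor 0 with pad (1, 1): indices -1..4 wrap around,
# so the periodic neighbor 7 appears before 0.
anchor, c, (p0, p1) = 0, 4, (1, 1)
begin, end = anchor - p0, anchor + c + p1
y = x.take(range(begin, end), axis=1, mode='wrap')
print(y)  # [[7 0 1 2 3 4]]
```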

View File

@@ -24,6 +24,9 @@ def test(args):
         aug_add=None,
         aug_mul=None,
         crop=args.crop,
+        crop_start=args.crop_start,
+        crop_stop=args.crop_stop,
+        crop_step=args.crop_step,
         pad=args.pad,
         scale_factor=args.scale_factor,
         cache=args.cache,

View File

@@ -68,6 +68,9 @@ def gpu_worker(local_rank, node, args):
         aug_add=args.aug_add,
         aug_mul=args.aug_mul,
         crop=args.crop,
+        crop_start=args.crop_start,
+        crop_stop=args.crop_stop,
+        crop_step=args.crop_step,
         pad=args.pad,
         scale_factor=args.scale_factor,
         cache=args.cache,
@@ -107,6 +110,9 @@ def gpu_worker(local_rank, node, args):
         aug_add=None,
         aug_mul=None,
         crop=args.crop,
+        crop_start=args.crop_start,
+        crop_stop=args.crop_stop,
+        crop_step=args.crop_step,
         pad=args.pad,
         scale_factor=args.scale_factor,
         cache=args.cache,