Add cropping anchors controlled by start, stop, step

Yin Li 2020-05-04 23:30:59 -04:00
parent 819e77cd86
commit 8a95d69818
5 changed files with 66 additions and 24 deletions


@@ -33,19 +33,22 @@ For all command line options look at `map2map/args.py` or do `m2m.py -h`.
 Put each field in one npy file.
 Structure your data to start with the channel axis and then the spatial
-dimensions.
-For example a 2D vector field of size `64^2` should have shape `(2, 64,
-64)`.
+dimensions, e.g. `(2, 64, 64)` for a 2D vector field of size `64^2` and
+`(1, 32, 32, 32)` for a 3D scalar field of size `32^3`.
 Specify the data path with
 [glob patterns](https://docs.python.org/3/library/glob.html).
 During training, pairs of input and target fields are loaded.
 Both input and target data can consist of multiple fields, which are
 then concatenated along the channel axis.
 
+#### Data cropping
+
 If the size of a pair of input and target fields is too large to fit in
-a GPU, we can crop part of them to form pairs of samples (see `--crop`).
-Each field can be cropped multiple times, along each dimension,
-controlled by the spacing between two adjacent crops (see `--step`).
+a GPU, we can crop part of them to form pairs of samples.
+Each field can be cropped multiple times, along each dimension.
+See `--crop`, `--crop-start`, `--crop-stop`, and `--crop-step`.
 The total sample size is the number of input and target pairs multiplied
 by the number of cropped samples per pair.
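For a concrete sense of these flags (all numbers below are invented for illustration), the anchors along each dimension follow Python's `range(start, stop, step)` semantics, so the crop count and total sample size work out like this:

```python
# Invented example: 3D fields of size 64^3, with
# --crop 32 --crop-start 0 --crop-stop 64 --crop-step 16
start, stop, step = 0, 64, 16

anchors_1d = list(range(start, stop, step))  # [0, 16, 32, 48]
ncrop = len(anchors_1d) ** 3                 # 64 crops per field pair

npairs = 10            # assumed number of input/target pairs
print(npairs * ncrop)  # total sample size: 640
```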


@@ -41,14 +41,21 @@ def add_common_args(parser):
             'of target normalization functions from .data.norms')
     parser.add_argument('--crop', type=int,
             help='size to crop the input and target data')
+    parser.add_argument('--crop-start', type=int,
+            help='starting point of the first crop. Default is the origin')
+    parser.add_argument('--crop-stop', type=int,
+            help='stopping point of the last crop. Default is the corner '
+            'opposite to the origin')
+    parser.add_argument('--crop-step', type=int,
+            help='spacing between crops. Default is the crop size')
     parser.add_argument('--pad', default=0, type=int,
             help='size to pad the input data beyond the crop size, assuming '
             'periodic boundary condition')
     parser.add_argument('--scale-factor', default=1, type=int,
-            help='input upsampling factor for super-resolution purpose, in '
-            'which case crop and pad will be taken at the original resolution')
+            help='upsampling factor for super-resolution, in which case '
+            'crop and pad are sizes of the input resolution')
-    parser.add_argument('--model', required=True, type=str,
+    parser.add_argument('--model', type=str, required=True,
             help='model from .models')
     parser.add_argument('--criterion', default='MSELoss', type=str,
             help='model criterion from torch.nn')
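As a quick sanity check of the flags defined above, here is a self-contained parsing sketch (note that argparse maps dashes to underscores in attribute names):

```python
from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument('--crop', type=int)
parser.add_argument('--crop-start', type=int)
parser.add_argument('--crop-stop', type=int)
parser.add_argument('--crop-step', type=int)

# e.g. 32^3 crops anchored every 16 voxels, start/stop left at defaults
args = parser.parse_args('--crop 32 --crop-step 16'.split())
print(args.crop, args.crop_start, args.crop_step)  # 32 None 16
```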
@@ -124,7 +131,7 @@ def add_train_args(parser):
     parser.add_argument('--optimizer', default='Adam', type=str,
             help='optimizer from torch.optim')
-    parser.add_argument('--lr', default=0.001, type=float,
+    parser.add_argument('--lr', type=float, required=True,
             help='initial learning rate')
     # parser.add_argument('--momentum', default=0.9, type=float,
     #         help='momentum')


@@ -14,6 +14,7 @@ class FieldDataset(Dataset):
     `in_patterns` is a list of glob patterns for the input field files.
     For example, `in_patterns=['/train/field1_*.npy', '/train/field2_*.npy']`.
     Each pattern in the list is a new field.
+    Likewise `tgt_patterns` is for target fields.
 
     Input and target fields are matched by sorting the globbed files.
@@ -25,8 +26,11 @@ class FieldDataset(Dataset):
     Additive and multiplicative augmentation are also possible, but with all fields
     added or multiplied by the same factor.
 
-    Input and target fields can be cropped.
-    Input fields can be padded assuming periodic boundary condition.
+    Input and target fields can be cropped, to return multiple slices of size
+    `crop` from each field.
+    The crop anchors are controlled by `crop_start`, `crop_stop`, and `crop_step`.
+    Input (but not target) fields can be padded beyond the crop size assuming
+    periodic boundary condition.
 
     Setting integer `scale_factor` greater than 1 will crop target bigger than
     the input for super-resolution, in which case `crop` and `pad` are sizes of
@@ -41,7 +45,8 @@
     def __init__(self, in_patterns, tgt_patterns,
                  in_norms=None, tgt_norms=None, callback_at=None,
                  augment=False, aug_add=None, aug_mul=None,
-                 crop=None, pad=0, scale_factor=1,
+                 crop=None, crop_start=None, crop_stop=None, crop_step=None,
+                 pad=0, scale_factor=1,
                  cache=False, cache_maxsize=None, div_data=False,
                  rank=None, world_size=None):
         in_file_lists = [sorted(glob(p)) for p in in_patterns]
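Putting the new keywords together, a hypothetical instantiation could look like this; the import path, glob patterns, and sizes are assumptions for illustration, not taken from the repo:

```python
from map2map.data import FieldDataset  # assumed import path

dataset = FieldDataset(
    in_patterns=['/train/lores_*.npy'],  # hypothetical file patterns
    tgt_patterns=['/train/hires_*.npy'],
    crop=32,         # 32^3 crops of the input
    crop_step=16,    # anchors every 16 voxels, so adjacent crops overlap
    pad=3,           # periodic padding of the input only
    scale_factor=2,  # target crops are then 64^3
)
```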
@@ -54,7 +59,7 @@
                 'number of input and target fields do not match'
         self.nfile = len(self.in_files)
 
-        assert self.nfile > 0, 'file not found'
+        assert self.nfile > 0, 'file not found for {}'.format(in_patterns)
 
         self.in_chan = [np.load(f).shape[0] for f in self.in_files[0]]
         self.tgt_chan = [np.load(f).shape[0] for f in self.tgt_files[0]]
@@ -85,11 +90,29 @@
         if crop is None:
             self.crop = self.size
-            self.reps = np.ones_like(self.size)
         else:
-            self.crop = np.broadcast_to(crop, self.size.shape)
-            self.reps = self.size // self.crop
-        self.ncrop = int(np.prod(self.reps))
+            self.crop = np.broadcast_to(crop, (self.ndim,))
+
+        if crop_start is None:
+            crop_start = np.zeros_like(self.size)
+        else:
+            crop_start = np.broadcast_to(crop_start, (self.ndim,))
+
+        if crop_stop is None:
+            crop_stop = self.size
+        else:
+            crop_stop = np.broadcast_to(crop_stop, (self.ndim,))
+
+        if crop_step is None:
+            crop_step = self.crop
+        else:
+            crop_step = np.broadcast_to(crop_step, (self.ndim,))
+
+        self.anchors = np.stack(np.mgrid[tuple(
+            slice(crop_start[d], crop_stop[d], crop_step[d])
+            for d in range(self.ndim)
+        )], axis=-1).reshape(-1, self.ndim)
+        self.ncrop = len(self.anchors)
 
         assert isinstance(pad, int), 'only support symmetric padding for now'
         self.pad = np.broadcast_to(pad, (self.ndim, 2))
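To see what the `np.mgrid` expression yields, here is a standalone sketch of the same construction with made-up numbers (a 2D field of size 8 per side, step 4):

```python
import numpy as np

ndim = 2
crop_start = np.array([0, 0])
crop_stop = np.array([8, 8])
crop_step = np.array([4, 4])

# Cartesian product of per-dimension anchor positions,
# as in FieldDataset.__init__ above
anchors = np.stack(np.mgrid[tuple(
    slice(crop_start[d], crop_stop[d], crop_step[d])
    for d in range(ndim)
)], axis=-1).reshape(-1, ndim)

print(anchors)
# [[0 0]
#  [0 4]
#  [4 0]
#  [4 4]]
```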
@@ -138,10 +161,10 @@
         in_fields, tgt_fields = self.get_fields(idx // self.ncrop)
 
-        start = np.unravel_index(idx % self.ncrop, self.reps) * self.crop
+        anchor = self.anchors[idx % self.ncrop]
 
-        in_fields = crop(in_fields, start, self.crop, self.pad)
-        tgt_fields = crop(tgt_fields, start * self.scale_factor,
+        in_fields = crop(in_fields, anchor, self.crop, self.pad)
+        tgt_fields = crop(tgt_fields, anchor * self.scale_factor,
                           self.crop * self.scale_factor,
                           np.zeros_like(self.pad))
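Each flat dataset index thus selects both a file pair and an anchor; a minimal sketch of the decomposition, with invented counts:

```python
ncrop = 64   # crops per field pair, e.g. 4 anchors per dimension in 3D
nfile = 10   # number of input/target file pairs; dataset length is 640

idx = 130
ifile, icrop = idx // ncrop, idx % ncrop
print(ifile, icrop)  # 2 2, i.e. the third file pair, third anchor
```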
@@ -176,11 +199,11 @@
         return in_fields, tgt_fields
 
 
-def crop(fields, start, crop, pad):
+def crop(fields, anchor, crop, pad):
     new_fields = []
     for x in fields:
-        for d, (i, c, (p0, p1)) in enumerate(zip(start, crop, pad)):
-            begin, end = i - p0, i + c + p1
+        for d, (a, c, (p0, p1)) in enumerate(zip(anchor, crop, pad)):
+            begin, end = a - p0, a + c + p1
             x = x.take(range(begin, end), axis=1 + d, mode='wrap')
         new_fields.append(x)
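`np.take` with `mode='wrap'` is what realizes the periodic boundary condition here: indices that run past either edge wrap around. A minimal sketch on a toy 1D field (one channel, spatial size 4):

```python
import numpy as np

x = np.arange(4).reshape(1, 4)  # one channel, spatial size 4

# a crop of size 2 anchored at 0, padded by 1 on each side:
# indices -1..2 wrap periodically along the spatial axis
begin, end = 0 - 1, 0 + 2 + 1
y = x.take(range(begin, end), axis=1, mode='wrap')
print(y)  # [[3 0 1 2]]
```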


@@ -24,6 +24,9 @@ def test(args):
         aug_add=None,
         aug_mul=None,
         crop=args.crop,
+        crop_start=args.crop_start,
+        crop_stop=args.crop_stop,
+        crop_step=args.crop_step,
         pad=args.pad,
         scale_factor=args.scale_factor,
         cache=args.cache,


@@ -68,6 +68,9 @@ def gpu_worker(local_rank, node, args):
         aug_add=args.aug_add,
         aug_mul=args.aug_mul,
         crop=args.crop,
+        crop_start=args.crop_start,
+        crop_stop=args.crop_stop,
+        crop_step=args.crop_step,
         pad=args.pad,
         scale_factor=args.scale_factor,
         cache=args.cache,
@@ -107,6 +110,9 @@ def gpu_worker(local_rank, node, args):
         aug_add=None,
         aug_mul=None,
         crop=args.crop,
+        crop_start=args.crop_start,
+        crop_stop=args.crop_stop,
+        crop_step=args.crop_step,
         pad=args.pad,
         scale_factor=args.scale_factor,
         cache=args.cache,