Add cropping anchors controlled by start, stop, step
parent 819e77cd86
commit 8a95d69818
README.md
@@ -33,19 +33,22 @@ For all command line options look at `map2map/args.py` or do `m2m.py -h`.
 
 Put each field in one npy file.
 Structure your data to start with the channel axis and then the spatial
-dimensions.
-For example a 2D vector field of size `64^2` should have shape `(2, 64,
-64)`.
+dimensions, e.g. `(2, 64, 64)` for a 2D vector field of size `64^2` and
+`(1, 32, 32, 32)` for a 3D scalar field of size `32^3`.
+
 Specify the data path with
 [glob patterns](https://docs.python.org/3/library/glob.html).
 
 During training, pairs of input and target fields are loaded.
 Both input and target data can consist of multiple fields, which are
 then concatenated along the channel axis.
 
+
+#### Data cropping
+
 If the size of a pair of input and target fields is too large to fit in
-a GPU, we can crop part of them to form pairs of samples (see `--crop`).
-Each field can be cropped multiple times, along each dimension,
-controlled by the spacing between two adjacent crops (see `--step`).
+a GPU, we can crop part of them to form pairs of samples.
+Each field can be cropped multiple times, along each dimension.
+See `--crop`, `--crop-start`, `--crop-stop`, and `--crop-step`.
 The total sample size is the number of input and target pairs multiplied
 by the number of cropped samples per pair.
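
As a quick sanity check on the sample count (a sketch, not part of the commit; the sizes are hypothetical):

```python
import numpy as np

# Hypothetical: a 128^3 field, 32^3 crops, default step equal to the crop
# size, start at the origin, stop at the opposite corner.
size, crop = np.full(3, 128), np.full(3, 32)
start, stop, step = np.zeros(3, dtype=int), size, crop

ncrop_per_dim = -(-(stop - start) // step)  # ceil division: 4 per dimension
print(int(np.prod(ncrop_per_dim)))          # 64 cropped samples per pair
```

With, say, 10 pairs of input and target fields, the total sample size would then be 10 * 64 = 640.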
map2map/args.py
@@ -41,14 +41,21 @@ def add_common_args(parser):
             'of target normalization functions from .data.norms')
     parser.add_argument('--crop', type=int,
             help='size to crop the input and target data')
+    parser.add_argument('--crop-start', type=int,
+            help='starting point of the first crop. Default is the origin')
+    parser.add_argument('--crop-stop', type=int,
+            help='stopping point of the last crop. Default is the corner '
+            'opposite to the origin')
+    parser.add_argument('--crop-step', type=int,
+            help='spacing between crops. Default is the crop size')
     parser.add_argument('--pad', default=0, type=int,
             help='size to pad the input data beyond the crop size, assuming '
             'periodic boundary condition')
     parser.add_argument('--scale-factor', default=1, type=int,
-            help='input upsampling factor for super-resolution purpose, in '
-            'which case crop and pad will be taken at the original resolution')
+            help='upsampling factor for super-resolution, in which case '
+            'crop and pad are sizes of the input resolution')
 
-    parser.add_argument('--model', required=True, type=str,
+    parser.add_argument('--model', type=str, required=True,
             help='model from .models')
     parser.add_argument('--criterion', default='MSELoss', type=str,
             help='model criterion from torch.nn')
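
A sketch of how the new flags behave under argparse (hypothetical values; not part of the commit): dashes become underscores, and unset options stay `None` so the documented defaults apply downstream.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--crop', type=int)
parser.add_argument('--crop-start', type=int)
parser.add_argument('--crop-stop', type=int)
parser.add_argument('--crop-step', type=int)

# Hypothetical command line: 32^d crops anchored every 16 voxels.
args = parser.parse_args('--crop 32 --crop-step 16'.split())
print(args.crop, args.crop_start, args.crop_stop, args.crop_step)
# 32 None None 16 -> start/stop fall back to origin/opposite corner
```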
@@ -124,7 +131,7 @@ def add_train_args(parser):
 
     parser.add_argument('--optimizer', default='Adam', type=str,
             help='optimizer from torch.optim')
-    parser.add_argument('--lr', default=0.001, type=float,
+    parser.add_argument('--lr', type=float, required=True,
             help='initial learning rate')
     # parser.add_argument('--momentum', default=0.9, type=float,
     #         help='momentum')
map2map/data/fields.py
@@ -14,6 +14,7 @@ class FieldDataset(Dataset):
 
     `in_patterns` is a list of glob patterns for the input field files.
     For example, `in_patterns=['/train/field1_*.npy', '/train/field2_*.npy']`.
+    Each pattern in the list is a new field.
     Likewise `tgt_patterns` is for target fields.
     Input and target fields are matched by sorting the globbed files.
 
@@ -25,8 +26,11 @@ class FieldDataset(Dataset):
     Additive and multiplicative augmentation are also possible, but with all fields
     added or multiplied by the same factor.
 
-    Input and target fields can be cropped.
-    Input fields can be padded assuming periodic boundary condition.
+    Input and target fields can be cropped, to return multiple slices of size
+    `crop` from each field.
+    The crop anchors are controlled by `crop_start`, `crop_stop`, and `crop_step`.
+    Input (but not target) fields can be padded beyond the crop size assuming
+    periodic boundary condition.
 
     Setting integer `scale_factor` greater than 1 will crop target bigger than
     the input for super-resolution, in which case `crop` and `pad` are sizes of
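
The sizes the docstring implies, sketched with hypothetical numbers (not part of the commit):

```python
# crop and pad are at the input resolution; the matching target slice is
# scaled up by scale_factor and is never padded.
crop, pad, scale_factor = 32, 4, 2

in_size = crop + 2 * pad        # 40 per dim: padded input fed to the model
tgt_size = crop * scale_factor  # 64 per dim: unpadded target slice
print(in_size, tgt_size)
```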
@@ -41,7 +45,8 @@ class FieldDataset(Dataset):
     def __init__(self, in_patterns, tgt_patterns,
                  in_norms=None, tgt_norms=None, callback_at=None,
                  augment=False, aug_add=None, aug_mul=None,
-                 crop=None, pad=0, scale_factor=1,
+                 crop=None, crop_start=None, crop_stop=None, crop_step=None,
+                 pad=0, scale_factor=1,
                  cache=False, cache_maxsize=None, div_data=False,
                  rank=None, world_size=None):
         in_file_lists = [sorted(glob(p)) for p in in_patterns]
@@ -54,7 +59,7 @@ class FieldDataset(Dataset):
             'number of input and target fields do not match'
         self.nfile = len(self.in_files)
 
-        assert self.nfile > 0, 'file not found'
+        assert self.nfile > 0, 'file not found for {}'.format(in_patterns)
 
         self.in_chan = [np.load(f).shape[0] for f in self.in_files[0]]
         self.tgt_chan = [np.load(f).shape[0] for f in self.tgt_files[0]]
@@ -85,11 +90,29 @@ class FieldDataset(Dataset):
 
         if crop is None:
             self.crop = self.size
-            self.reps = np.ones_like(self.size)
         else:
-            self.crop = np.broadcast_to(crop, self.size.shape)
-            self.reps = self.size // self.crop
-        self.ncrop = int(np.prod(self.reps))
+            self.crop = np.broadcast_to(crop, (self.ndim,))
+
+        if crop_start is None:
+            crop_start = np.zeros_like(self.size)
+        else:
+            crop_start = np.broadcast_to(crop_start, (self.ndim,))
+
+        if crop_stop is None:
+            crop_stop = self.size
+        else:
+            crop_stop = np.broadcast_to(crop_stop, (self.ndim,))
+
+        if crop_step is None:
+            crop_step = self.crop
+        else:
+            crop_step = np.broadcast_to(crop_step, (self.ndim,))
+
+        self.anchors = np.stack(np.mgrid[tuple(
+            slice(crop_start[d], crop_stop[d], crop_step[d])
+            for d in range(self.ndim)
+        )], axis=-1).reshape(-1, self.ndim)
+        self.ncrop = len(self.anchors)
 
         assert isinstance(pad, int), 'only support symmetric padding for now'
         self.pad = np.broadcast_to(pad, (self.ndim, 2))
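
A sketch (not part of the commit) of what the `np.mgrid` expression above produces, with hypothetical numbers showing overlapping crops when the step is smaller than the crop:

```python
import numpy as np

# Hypothetical 2D field of size 64^2, crop 32, anchors every 16 voxels.
ndim = 2
crop_start = np.zeros(ndim, dtype=int)
crop_stop = np.full(ndim, 64)
crop_step = np.full(ndim, 16)

anchors = np.stack(np.mgrid[tuple(
    slice(crop_start[d], crop_stop[d], crop_step[d]) for d in range(ndim)
)], axis=-1).reshape(-1, ndim)

print(anchors.shape)  # (16, 2): 4 anchors per dim, overlapping since 16 < 32
print(anchors[:3])    # [[ 0  0] [ 0 16] [ 0 32]]
```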
@@ -138,10 +161,10 @@ class FieldDataset(Dataset):
 
         in_fields, tgt_fields = self.get_fields(idx // self.ncrop)
 
-        start = np.unravel_index(idx % self.ncrop, self.reps) * self.crop
+        anchor = self.anchors[idx % self.ncrop]
 
-        in_fields = crop(in_fields, start, self.crop, self.pad)
-        tgt_fields = crop(tgt_fields, start * self.scale_factor,
+        in_fields = crop(in_fields, anchor, self.crop, self.pad)
+        tgt_fields = crop(tgt_fields, anchor * self.scale_factor,
                           self.crop * self.scale_factor,
                           np.zeros_like(self.pad))
 
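
The flat index decomposition above splits each dataset index into a pair of fields and one of its crops; a sketch with hypothetical counts (not part of the commit):

```python
# Hypothetical: 64 crops per pair; dataset length is nfile * ncrop.
ncrop = 64
for idx in (0, 63, 64, 130):
    print(idx, '->', 'pair', idx // ncrop, 'crop', idx % ncrop)
# 0 -> pair 0 crop 0
# 63 -> pair 0 crop 63
# 64 -> pair 1 crop 0
# 130 -> pair 2 crop 2
```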
@@ -176,11 +199,11 @@ class FieldDataset(Dataset):
         return in_fields, tgt_fields
 
 
-def crop(fields, start, crop, pad):
+def crop(fields, anchor, crop, pad):
     new_fields = []
     for x in fields:
-        for d, (i, c, (p0, p1)) in enumerate(zip(start, crop, pad)):
-            begin, end = i - p0, i + c + p1
+        for d, (a, c, (p0, p1)) in enumerate(zip(anchor, crop, pad)):
+            begin, end = a - p0, a + c + p1
             x = x.take(range(begin, end), axis=1 + d, mode='wrap')
 
         new_fields.append(x)
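
A sketch (not part of the commit) of the periodic behavior of `np.take(..., mode='wrap')` that the `crop` helper relies on, with a hypothetical 1D field:

```python
import numpy as np

x = np.arange(8).reshape(1, 8)       # 1 channel, 8 pixels, values 0..7
anchor, c, (p0, p1) = 6, 4, (1, 1)   # crop 4 at anchor 6, pad 1 on each side

begin, end = anchor - p0, anchor + c + p1
print(x.take(range(begin, end), axis=1, mode='wrap'))
# [[5 6 7 0 1 2]] -> indices past the edge wrap around periodically
```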
map2map/test.py
@@ -24,6 +24,9 @@ def test(args):
         aug_add=None,
         aug_mul=None,
         crop=args.crop,
+        crop_start=args.crop_start,
+        crop_stop=args.crop_stop,
+        crop_step=args.crop_step,
         pad=args.pad,
         scale_factor=args.scale_factor,
         cache=args.cache,
map2map/train.py
@@ -68,6 +68,9 @@ def gpu_worker(local_rank, node, args):
         aug_add=args.aug_add,
         aug_mul=args.aug_mul,
         crop=args.crop,
+        crop_start=args.crop_start,
+        crop_stop=args.crop_stop,
+        crop_step=args.crop_step,
         pad=args.pad,
         scale_factor=args.scale_factor,
         cache=args.cache,
@@ -107,6 +110,9 @@ def gpu_worker(local_rank, node, args):
         aug_add=None,
         aug_mul=None,
         crop=args.crop,
+        crop_start=args.crop_start,
+        crop_stop=args.crop_stop,
+        crop_step=args.crop_step,
         pad=args.pad,
         scale_factor=args.scale_factor,
         cache=args.cache,