Add data caching, and new pad and crop features

This commit is contained in:
Yin Li 2019-12-17 12:00:13 -05:00
parent d03bcb59a1
commit 843fe09a92
6 changed files with 113 additions and 68 deletions

View File

@ -28,12 +28,12 @@ def add_common_args(parser):
parser.add_argument('--loader-workers', default=0, type=int,
help='number of data loading workers, per GPU in training or '
'in total in testing')
parser.add_argument('--pad-or-crop', default=0, type=int_tuple,
help='pad (>0) or crop (<0) the input data; '
'can be a int or a 6-tuple (by a comma-sep. list); '
'can be asymmetric to align the data with downsample '
'and upsample convolutions; '
'padding assumes periodic boundary condition')
parser.add_argument('--cache', action='store_true',
help='enable caching in field datasets')
parser.add_argument('--crop', type=int,
help='size to crop the input and target data')
parser.add_argument('--pad', default=0, type=int,
help='pad the input data assuming periodic boundary condition')
def add_train_args(parser):
@ -80,11 +80,11 @@ def str_list(s):
return s.split(',')
def int_tuple(t):
t = t.split(',')
t = tuple(int(i) for i in t)
if len(t) == 1:
t = t[0]
elif len(t) != 6:
raise ValueError('pad or crop size must be int or 6-tuple')
return t
#def int_tuple(t):
# t = t.split(',')
# t = tuple(int(i) for i in t)
# if len(t) == 1:
# t = t[0]
# elif len(t) != 6:
# raise ValueError('size must be int or 6-tuple')
# return t

View File

@ -14,15 +14,15 @@ class FieldDataset(Dataset):
Likewise `tgt_patterns` is for target fields.
Input and target samples of all fields are matched by sorting the globbed files.
Input fields can be padded (>0) or cropped (<0) with `pad_or_crop`.
Padding assumes periodic boundary condition.
Input and target fields can be cached, and they can be cropped.
Input fields can be padded assuming periodic boundary condition.
Data augmentations are supported for scalar and vector fields.
`norms` can be a list of callables to normalize each field.
"""
def __init__(self, in_patterns, tgt_patterns, pad_or_crop=0, augment=False,
norms=None):
def __init__(self, in_patterns, tgt_patterns, cache=False, crop=None, pad=0,
augment=False, norms=None):
in_file_lists = [sorted(glob(p)) for p in in_patterns]
self.in_files = list(zip(* in_file_lists))
@ -35,47 +35,80 @@ class FieldDataset(Dataset):
self.in_channels = sum(np.load(f).shape[0] for f in self.in_files[0])
self.tgt_channels = sum(np.load(f).shape[0] for f in self.tgt_files[0])
if isinstance(pad_or_crop, int):
pad_or_crop = (pad_or_crop,) * 6
assert isinstance(pad_or_crop, tuple) and len(pad_or_crop) == 6, \
'pad or crop size must be int or 6-tuple'
self.pad_or_crop = np.array((0,) * 2 + pad_or_crop).reshape(4, 2)
self.size = np.load(self.in_files[0][0]).shape[-3:]
self.size = np.asarray(self.size)
self.ndim = len(self.size)
self.cache = cache
if self.cache:
self.in_fields = []
self.tgt_fields = []
for idx in range(len(self.in_files)):
self.in_fields.append([np.load(f) for f in self.in_files[idx]])
self.tgt_fields.append([np.load(f) for f in self.tgt_files[idx]])
if crop is None:
self.crop = self.size
self.reps = np.ones_like(self.size)
else:
self.crop = np.broadcast_to(crop, self.size.shape)
self.reps = self.size // self.crop
self.tot_reps = int(np.prod(self.reps))
assert isinstance(pad, int), 'only support symmetric padding for now'
self.pad = np.broadcast_to(pad, (self.ndim, 2))
self.augment = augment
if self.ndim == 1 and self.augment:
raise ValueError('cannot augment 1D fields')
if norms is not None:
if norms is not None: # FIXME: in_norms, tgt_norms
assert len(in_patterns) == len(norms), \
'numbers of normalization callables and input fields do not match'
norms = [import_norm(norm) for norm in norms if isinstance(norm, str)]
self.norms = norms
def __len__(self):
return len(self.in_files) * self.tot_reps
@property
def channels(self):
return self.in_channels, self.tgt_channels
def __len__(self):
return len(self.in_files)
def __getitem__(self, idx):
idx, sub_idx = idx // self.tot_reps, idx % self.tot_reps
start = np.unravel_index(sub_idx, self.reps) * self.crop
#print('==================================================')
#print(f'idx = {idx}, sub_idx = {sub_idx}, start = {start}')
#print(f'self.reps = {self.reps}, self.tot_reps = {self.tot_reps}')
#print(f'self.crop = {self.crop}, self.size = {self.size}')
#print(f'self.ndim = {self.ndim}, self.channels = {self.channels}')
#print(f'self.pad = {self.pad}')
if self.cache:
in_fields = self.in_fields[idx]
tgt_fields = self.tgt_fields[idx]
else:
in_fields = [np.load(f) for f in self.in_files[idx]]
tgt_fields = [np.load(f) for f in self.tgt_files[idx]]
padcrop(in_fields, self.pad_or_crop) # with numpy
in_fields = crop(in_fields, start, self.crop, self.pad)
tgt_fields = crop(tgt_fields, start, self.crop, np.zeros_like(self.pad))
in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]
tgt_fields = [torch.from_numpy(f).to(torch.float32) for f in tgt_fields]
if self.augment:
flip_axes = torch.randint(2, (3,), dtype=torch.bool)
flip_axes = torch.arange(3)[flip_axes]
flip_axes = torch.randint(2, (self.ndim,), dtype=torch.bool)
flip_axes = torch.arange(self.ndim)[flip_axes]
flip3d(in_fields, flip_axes)
flip3d(tgt_fields, flip_axes)
in_fields = flip(in_fields, flip_axes, self.ndim)
tgt_fields = flip(tgt_fields, flip_axes, self.ndim)
perm_axes = torch.randperm(3)
perm_axes = torch.randperm(self.ndim)
perm3d(in_fields, perm_axes)
perm3d(tgt_fields, perm_axes)
in_fields = perm(in_fields, perm_axes, self.ndim)
tgt_fields = perm(tgt_fields, perm_axes, self.ndim)
if self.norms is not None:
for norm, ifield, tfield in zip(self.norms, in_fields, tgt_fields):
@ -84,43 +117,49 @@ class FieldDataset(Dataset):
in_fields = torch.cat(in_fields, dim=0)
tgt_fields = torch.cat(tgt_fields, dim=0)
#print(in_fields.shape, tgt_fields.shape)
return in_fields, tgt_fields
def padcrop(fields, width):
for i, x in enumerate(fields):
if (width >= 0).all():
x = np.pad(x, width, mode='wrap')
elif (width <= 0).all():
x = x[...,
-width[0, 0] : width[0, 1],
-width[1, 0] : width[1, 1],
-width[2, 0] : width[2, 1],
]
else:
raise NotImplementedError('mixed pad-and-crop not supported')
def crop(fields, start, crop, pad):
new_fields = []
for x in fields:
for d, (i, N, (p0, p1)) in enumerate(zip(start, crop, pad)):
x = x.take(range(i - p0, i + N + p1), axis=1 + d, mode='wrap')
fields[i] = x
new_fields.append(x)
return new_fields
def flip3d(fields, axes):
for i, x in enumerate(fields):
if x.size(0) == 3: # flip vector components
def flip(fields, axes, ndim):
assert ndim > 1, 'flipping is ambiguous for 1D vectors'
new_fields = []
for x in fields:
if x.size(0) == ndim: # flip vector components
x[axes] = - x[axes]
axes = (1 + axes).tolist()
x = torch.flip(x, axes)
fields[i] = x
new_fields.append(x)
return new_fields
def perm3d(fields, axes):
for i, x in enumerate(fields):
if x.size(0) == 3: # permutate vector components
def perm(fields, axes, ndim):
assert ndim > 1, 'permutation is not necessary for 1D fields'
new_fields = []
for x in fields:
if x.size(0) == ndim: # permutate vector components
x = x[axes]
axes = [0] + (1 + axes).tolist()
x = x.permute(axes)
fields[i] = x
new_fields.append(x)
return new_fields

View File

@ -11,9 +11,11 @@ def test(args):
test_dataset = FieldDataset(
in_patterns=args.test_in_patterns,
tgt_patterns=args.test_tgt_patterns,
cache=args.cache,
crop=args.crop,
pad=args.pad,
augment=False,
norms=args.norms,
pad_or_crop=args.pad_or_crop,
)
test_loader = DataLoader(
test_dataset,
@ -44,7 +46,7 @@ def test(args):
with torch.no_grad():
for i, (input, target) in enumerate(test_loader):
output = model(input)
if args.pad_or_crop > 0: # FIXME
if args.pad > 0: # FIXME
output = narrow_like(output, target)
input = narrow_like(input, target)
else:

View File

@ -48,9 +48,11 @@ def gpu_worker(local_rank, args):
train_dataset = FieldDataset(
in_patterns=args.train_in_patterns,
tgt_patterns=args.train_tgt_patterns,
cache=args.cache,
crop=args.crop,
pad=args.pad,
augment=args.augment,
norms=args.norms,
pad_or_crop=args.pad_or_crop,
)
#train_sampler = DistributedSampler(train_dataset, shuffle=True)
train_sampler = DistributedSampler(train_dataset)
@ -66,9 +68,11 @@ def gpu_worker(local_rank, args):
val_dataset = FieldDataset(
in_patterns=args.val_in_patterns,
tgt_patterns=args.val_tgt_patterns,
cache=args.cache,
crop=args.crop,
pad=args.pad,
augment=False,
norms=args.norms,
pad_or_crop=args.pad_or_crop,
)
#val_sampler = DistributedSampler(val_dataset, shuffle=False)
val_sampler = DistributedSampler(val_dataset)

View File

@ -29,7 +29,7 @@ tgt_dir="nonlin"
test_dirs="0" # FIXME
files="dis/128x???.npy"
files="dis/512x???.npy"
in_files="$files"
tgt_files="$files"
@ -37,8 +37,8 @@ tgt_files="$files"
m2m.py test \
--test-in-patterns "$data_root_dir/$in_dir/$test_dirs/$in_files" \
--test-tgt-patterns "$data_root_dir/$tgt_dir/$test_dirs/$tgt_files" \
--norms cosmology.dis --model VNet \
--batches 1 --loader-workers 0 --pad-or-crop 40 \
--norms cosmology.dis --model VNet --cache --crop 128 --pad 50 \
--batches 1 --loader-workers 0 \
--load-state best_model.pth

View File

@ -29,7 +29,7 @@ tgt_dir="nonlin"
test_dirs="0" # FIXME
files="vel/128x???.npy"
files="vel/512x???.npy"
in_files="$files"
tgt_files="$files"
@ -37,8 +37,8 @@ tgt_files="$files"
m2m.py test \
--test-in-patterns "$data_root_dir/$in_dir/$test_dirs/$in_files" \
--test-tgt-patterns "$data_root_dir/$tgt_dir/$test_dirs/$tgt_files" \
--norms cosmology.vel --model VNet \
--batches 1 --loader-workers 0 --pad-or-crop 40 \
--norms cosmology.vel --model VNet --cache --crop 128 --pad 50 \
--batches 1 --loader-workers 0 \
--load-state best_model.pth