Add data caching and new pad and crop features

Yin Li 2019-12-17 12:00:13 -05:00
parent d03bcb59a1
commit 843fe09a92
6 changed files with 113 additions and 68 deletions

View File

@@ -28,12 +28,12 @@ def add_common_args(parser):
     parser.add_argument('--loader-workers', default=0, type=int,
             help='number of data loading workers, per GPU in training or '
                  'in total in testing')
-    parser.add_argument('--pad-or-crop', default=0, type=int_tuple,
-            help='pad (>0) or crop (<0) the input data; '
-                 'can be a int or a 6-tuple (by a comma-sep. list); '
-                 'can be asymmetric to align the data with downsample '
-                 'and upsample convolutions; '
-                 'padding assumes periodic boundary condition')
+    parser.add_argument('--cache', action='store_true',
+            help='enable caching in field datasets')
+    parser.add_argument('--crop', type=int,
+            help='size to crop the input and target data')
+    parser.add_argument('--pad', default=0, type=int,
+            help='pad the input data assuming periodic boundary condition')


 def add_train_args(parser):
@@ -80,11 +80,11 @@ def str_list(s):
     return s.split(',')


-def int_tuple(t):
-    t = t.split(',')
-    t = tuple(int(i) for i in t)
-    if len(t) == 1:
-        t = t[0]
-    elif len(t) != 6:
-        raise ValueError('pad or crop size must be int or 6-tuple')
-    return t
+#def int_tuple(t):
+#    t = t.split(',')
+#    t = tuple(int(i) for i in t)
+#    if len(t) == 1:
+#        t = t[0]
+#    elif len(t) != 6:
+#        raise ValueError('size must be int or 6-tuple')
+#    return t
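For reference, the simplified interface can be exercised on its own. A minimal standalone sketch (not code from this commit), showing that the three plain-typed flags replace the old 6-tuple option and its custom converter:

    import argparse

    # Rebuild just the three new options from the diff above.
    parser = argparse.ArgumentParser()
    parser.add_argument('--cache', action='store_true',
            help='enable caching in field datasets')
    parser.add_argument('--crop', type=int,
            help='size to crop the input and target data')
    parser.add_argument('--pad', default=0, type=int,
            help='pad the input data assuming periodic boundary condition')

    args = parser.parse_args(['--cache', '--crop', '128', '--pad', '50'])
    assert (args.cache, args.crop, args.pad) == (True, 128, 50)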

View File

@@ -14,15 +14,15 @@ class FieldDataset(Dataset):
     Likewise `tgt_patterns` is for target fields.
     Input and target samples of all fields are matched by sorting the globbed files.

-    Input fields can be padded (>0) or cropped (<0) with `pad_or_crop`.
-    Padding assumes periodic boundary condition.
+    Input and target fields can be cached, and they can be cropped.
+    Input fields can be padded assuming periodic boundary condition.

     Data augmentations are supported for scalar and vector fields.

     `norms` can be a list of callables to normalize each field.
     """
-    def __init__(self, in_patterns, tgt_patterns, pad_or_crop=0, augment=False,
-            norms=None):
+    def __init__(self, in_patterns, tgt_patterns, cache=False, crop=None, pad=0,
+            augment=False, norms=None):
         in_file_lists = [sorted(glob(p)) for p in in_patterns]
         self.in_files = list(zip(* in_file_lists))
@@ -35,47 +35,80 @@ class FieldDataset(Dataset):
         self.in_channels = sum(np.load(f).shape[0] for f in self.in_files[0])
         self.tgt_channels = sum(np.load(f).shape[0] for f in self.tgt_files[0])

-        if isinstance(pad_or_crop, int):
-            pad_or_crop = (pad_or_crop,) * 6
-        assert isinstance(pad_or_crop, tuple) and len(pad_or_crop) == 6, \
-                'pad or crop size must be int or 6-tuple'
-        self.pad_or_crop = np.array((0,) * 2 + pad_or_crop).reshape(4, 2)
+        self.size = np.load(self.in_files[0][0]).shape[-3:]
+        self.size = np.asarray(self.size)
+        self.ndim = len(self.size)
+
+        self.cache = cache
+        if self.cache:
+            self.in_fields = []
+            self.tgt_fields = []
+            for idx in range(len(self.in_files)):
+                self.in_fields.append([np.load(f) for f in self.in_files[idx]])
+                self.tgt_fields.append([np.load(f) for f in self.tgt_files[idx]])
+
+        if crop is None:
+            self.crop = self.size
+            self.reps = np.ones_like(self.size)
+        else:
+            self.crop = np.broadcast_to(crop, self.size.shape)
+            self.reps = self.size // self.crop
+        self.tot_reps = int(np.prod(self.reps))
+
+        assert isinstance(pad, int), 'only support symmetric padding for now'
+        self.pad = np.broadcast_to(pad, (self.ndim, 2))

         self.augment = augment
+        if self.ndim == 1 and self.augment:
+            raise ValueError('cannot augment 1D fields')

-        if norms is not None:
+        if norms is not None:  # FIXME: in_norms, tgt_norms
             assert len(in_patterns) == len(norms), \
                     'numbers of normalization callables and input fields do not match'
             norms = [import_norm(norm) for norm in norms if isinstance(norm, str)]
         self.norms = norms

+    def __len__(self):
+        return len(self.in_files) * self.tot_reps
+
     @property
     def channels(self):
         return self.in_channels, self.tgt_channels

-    def __len__(self):
-        return len(self.in_files)
-
     def __getitem__(self, idx):
-        in_fields = [np.load(f) for f in self.in_files[idx]]
-        tgt_fields = [np.load(f) for f in self.tgt_files[idx]]
+        idx, sub_idx = idx // self.tot_reps, idx % self.tot_reps
+        start = np.unravel_index(sub_idx, self.reps) * self.crop
+        #print('==================================================')
+        #print(f'idx = {idx}, sub_idx = {sub_idx}, start = {start}')
+        #print(f'self.reps = {self.reps}, self.tot_reps = {self.tot_reps}')
+        #print(f'self.crop = {self.crop}, self.size = {self.size}')
+        #print(f'self.ndim = {self.ndim}, self.channels = {self.channels}')
+        #print(f'self.pad = {self.pad}')

-        padcrop(in_fields, self.pad_or_crop)  # with numpy
+        if self.cache:
+            in_fields = self.in_fields[idx]
+            tgt_fields = self.tgt_fields[idx]
+        else:
+            in_fields = [np.load(f) for f in self.in_files[idx]]
+            tgt_fields = [np.load(f) for f in self.tgt_files[idx]]
+
+        in_fields = crop(in_fields, start, self.crop, self.pad)
+        tgt_fields = crop(tgt_fields, start, self.crop, np.zeros_like(self.pad))

         in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]
         tgt_fields = [torch.from_numpy(f).to(torch.float32) for f in tgt_fields]

         if self.augment:
-            flip_axes = torch.randint(2, (3,), dtype=torch.bool)
-            flip_axes = torch.arange(3)[flip_axes]
+            flip_axes = torch.randint(2, (self.ndim,), dtype=torch.bool)
+            flip_axes = torch.arange(self.ndim)[flip_axes]

-            flip3d(in_fields, flip_axes)
-            flip3d(tgt_fields, flip_axes)
+            in_fields = flip(in_fields, flip_axes, self.ndim)
+            tgt_fields = flip(tgt_fields, flip_axes, self.ndim)

-            perm_axes = torch.randperm(3)
+            perm_axes = torch.randperm(self.ndim)

-            perm3d(in_fields, perm_axes)
-            perm3d(tgt_fields, perm_axes)
+            in_fields = perm(in_fields, perm_axes, self.ndim)
+            tgt_fields = perm(tgt_fields, perm_axes, self.ndim)

         if self.norms is not None:
             for norm, ifield, tfield in zip(self.norms, in_fields, tgt_fields):
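The index arithmetic at the top of `__getitem__` splits each flat DataLoader index into a file index and a crop origin, with `np.unravel_index` enumerating sub-crops in row-major order. A standalone check of that decomposition (sizes illustrative, not from this commit):

    import numpy as np

    size = np.asarray((512, 512, 512))      # full field extent (illustrative)
    crop = np.broadcast_to(128, size.shape)
    reps = size // crop                     # (4, 4, 4) crops along each axis
    tot_reps = int(np.prod(reps))           # 64 sub-samples per file

    idx = 70                                # flat index from the DataLoader
    file_idx, sub_idx = idx // tot_reps, idx % tot_reps
    start = np.unravel_index(sub_idx, reps) * crop

    print(file_idx, sub_idx, start)         # 1 6 [  0 128 256]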
@@ -84,43 +117,49 @@ class FieldDataset(Dataset):
         in_fields = torch.cat(in_fields, dim=0)
         tgt_fields = torch.cat(tgt_fields, dim=0)

+        #print(in_fields.shape, tgt_fields.shape)
+
         return in_fields, tgt_fields


-def padcrop(fields, width):
-    for i, x in enumerate(fields):
-        if (width >= 0).all():
-            x = np.pad(x, width, mode='wrap')
-        elif (width <= 0).all():
-            x = x[...,
-                -width[0, 0] : width[0, 1],
-                -width[1, 0] : width[1, 1],
-                -width[2, 0] : width[2, 1],
-            ]
-        else:
-            raise NotImplementedError('mixed pad-and-crop not supported')
-
-        fields[i] = x
+def crop(fields, start, crop, pad):
+    new_fields = []
+    for x in fields:
+        for d, (i, N, (p0, p1)) in enumerate(zip(start, crop, pad)):
+            x = x.take(range(i - p0, i + N + p1), axis=1 + d, mode='wrap')
+
+        new_fields.append(x)
+
+    return new_fields


-def flip3d(fields, axes):
-    for i, x in enumerate(fields):
-        if x.size(0) == 3:  # flip vector components
+def flip(fields, axes, ndim):
+    assert ndim > 1, 'flipping is ambiguous for 1D vectors'
+
+    new_fields = []
+    for x in fields:
+        if x.size(0) == ndim:  # flip vector components
             x[axes] = - x[axes]

         axes = (1 + axes).tolist()
         x = torch.flip(x, axes)

-        fields[i] = x
+        new_fields.append(x)
+
+    return new_fields


-def perm3d(fields, axes):
-    for i, x in enumerate(fields):
-        if x.size(0) == 3:  # permutate vector components
+def perm(fields, axes, ndim):
+    assert ndim > 1, 'permutation is not necessary for 1D fields'
+
+    new_fields = []
+    for x in fields:
+        if x.size(0) == ndim:  # permutate vector components
             x = x[axes]

         axes = [0] + (1 + axes).tolist()
         x = x.permute(axes)

-        fields[i] = x
+        new_fields.append(x)
+
+    return new_fields
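Unlike the old `padcrop`, the new `crop` helper gets periodic padding for free from wrapped indexing: `ndarray.take(..., mode='wrap')` reduces out-of-range indices modulo the axis length. A one-dimensional demonstration (toy data, not from this commit):

    import numpy as np

    x = np.arange(8).reshape(1, 8)          # 1 channel, length-8 field
    i, N, (p0, p1) = 0, 4, (2, 2)           # crop 4 at origin, pad 2 per side
    y = x.take(range(i - p0, i + N + p1), axis=1, mode='wrap')
    print(y)                                # [[6 7 0 1 2 3 4 5]]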

View File

@@ -11,9 +11,11 @@ def test(args):
     test_dataset = FieldDataset(
         in_patterns=args.test_in_patterns,
         tgt_patterns=args.test_tgt_patterns,
+        cache=args.cache,
+        crop=args.crop,
+        pad=args.pad,
         augment=False,
         norms=args.norms,
-        pad_or_crop=args.pad_or_crop,
     )
     test_loader = DataLoader(
         test_dataset,
test_loader = DataLoader( test_loader = DataLoader(
test_dataset, test_dataset,
@@ -44,7 +46,7 @@ def test(args):
     with torch.no_grad():
         for i, (input, target) in enumerate(test_loader):
             output = model(input)
-            if args.pad_or_crop > 0:  # FIXME
+            if args.pad > 0:  # FIXME
                 output = narrow_like(output, target)
                 input = narrow_like(input, target)
             else:
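`narrow_like` is defined elsewhere in the repository; judging from these call sites, it trims the larger tensor's spatial dimensions to match the smaller one, since a padded input leaves the output bigger than the unpadded target. A rough stand-in under that assumption (hypothetical helper, not the repository's implementation):

    import torch

    def narrow_like_sketch(a, b):
        # Assumed behavior: center-trim spatial dims (2:) of `a` to match `b`;
        # even margins assumed for simplicity.
        for d in range(2, a.dim()):
            half = (a.shape[d] - b.shape[d]) // 2
            a = a.narrow(d, half, b.shape[d])
        return a

    output = torch.zeros(1, 1, 20, 20, 20)  # toy sizes
    target = torch.zeros(1, 1, 12, 12, 12)
    print(narrow_like_sketch(output, target).shape)  # [1, 1, 12, 12, 12]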

View File

@@ -48,9 +48,11 @@ def gpu_worker(local_rank, args):
     train_dataset = FieldDataset(
         in_patterns=args.train_in_patterns,
         tgt_patterns=args.train_tgt_patterns,
+        cache=args.cache,
+        crop=args.crop,
+        pad=args.pad,
         augment=args.augment,
         norms=args.norms,
-        pad_or_crop=args.pad_or_crop,
     )
     #train_sampler = DistributedSampler(train_dataset, shuffle=True)
     train_sampler = DistributedSampler(train_dataset)
@@ -66,9 +68,11 @@ def gpu_worker(local_rank, args):
     val_dataset = FieldDataset(
         in_patterns=args.val_in_patterns,
         tgt_patterns=args.val_tgt_patterns,
+        cache=args.cache,
+        crop=args.crop,
+        pad=args.pad,
         augment=False,
         norms=args.norms,
-        pad_or_crop=args.pad_or_crop,
     )
     #val_sampler = DistributedSampler(val_dataset, shuffle=False)
     val_sampler = DistributedSampler(val_dataset)
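A consequence of the new `__len__`: `DistributedSampler` now partitions sub-crops rather than whole files across ranks. A standalone count check (toy stand-in for FieldDataset; `num_replicas` and `rank` are passed explicitly so no process group is needed):

    import torch
    from torch.utils.data import Dataset, DistributedSampler

    class ToyDataset(Dataset):
        # Stands in for FieldDataset: 10 files x 64 sub-crops each.
        def __len__(self):
            return 10 * 64

        def __getitem__(self, idx):
            return torch.tensor(idx)

    sampler = DistributedSampler(ToyDataset(), num_replicas=4, rank=0)
    print(len(sampler))                     # 160 sub-crops per rank per epoch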

View File

@@ -29,7 +29,7 @@ tgt_dir="nonlin"
 test_dirs="0"  # FIXME

-files="dis/128x???.npy"
+files="dis/512x???.npy"
 in_files="$files"
 tgt_files="$files"
@@ -37,8 +37,8 @@ tgt_files="$files"
 m2m.py test \
     --test-in-patterns "$data_root_dir/$in_dir/$test_dirs/$in_files" \
     --test-tgt-patterns "$data_root_dir/$tgt_dir/$test_dirs/$tgt_files" \
-    --norms cosmology.dis --model VNet \
-    --batches 1 --loader-workers 0 --pad-or-crop 40 \
+    --norms cosmology.dis --model VNet --cache --crop 128 --pad 50 \
+    --batches 1 --loader-workers 0 \
     --load-state best_model.pth
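The updated settings trade one sample per file for many: a 512^3 field cropped to 128^3 yields 4^3 = 64 sub-samples, each padded to 228 per side on input. Back-of-envelope numbers (assuming 3-channel float32 fields; actual shapes come from the `.npy` files):

    reps = (512 // 128) ** 3                # 64 sub-crops per file
    extent = 128 + 2 * 50                   # 228 per side after periodic padding
    gib = 3 * 512**3 * 4 / 2**30            # --cache holds whole fields in RAM
    print(reps, extent, f'{gib:.1f} GiB')   # 64 228 1.5 GiB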

View File

@@ -29,7 +29,7 @@ tgt_dir="nonlin"
 test_dirs="0"  # FIXME

-files="vel/128x???.npy"
+files="vel/512x???.npy"
 in_files="$files"
 tgt_files="$files"
@@ -37,8 +37,8 @@ tgt_files="$files"
 m2m.py test \
     --test-in-patterns "$data_root_dir/$in_dir/$test_dirs/$in_files" \
     --test-tgt-patterns "$data_root_dir/$tgt_dir/$test_dirs/$tgt_files" \
-    --norms cosmology.vel --model VNet \
-    --batches 1 --loader-workers 0 --pad-or-crop 40 \
+    --norms cosmology.vel --model VNet --cache --crop 128 --pad 50 \
+    --batches 1 --loader-workers 0 \
     --load-state best_model.pth