Add data caching, and new pad and crop features

commit 843fe09a92
parent d03bcb59a1

@@ -28,12 +28,12 @@ def add_common_args(parser):
     parser.add_argument('--loader-workers', default=0, type=int,
             help='number of data loading workers, per GPU in training or '
                 'in total in testing')
-    parser.add_argument('--pad-or-crop', default=0, type=int_tuple,
-            help='pad (>0) or crop (<0) the input data; '
-                'can be a int or a 6-tuple (by a comma-sep. list); '
-                'can be asymmetric to align the data with downsample '
-                'and upsample convolutions; '
-                'padding assumes periodic boundary condition')
+    parser.add_argument('--cache', action='store_true',
+            help='enable caching in field datasets')
+    parser.add_argument('--crop', type=int,
+            help='size to crop the input and target data')
+    parser.add_argument('--pad', default=0, type=int,
+            help='pad the input data assuming periodic boundary condition')


 def add_train_args(parser):
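
Note: the single --pad-or-crop tuple argument is split into three flags. A
minimal sketch of how they parse, using only standard argparse (the bare
parser here is illustrative, not the project's add_common_args):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--cache', action='store_true')
    parser.add_argument('--crop', type=int)
    parser.add_argument('--pad', default=0, type=int)

    args = parser.parse_args('--cache --crop 128 --pad 50'.split())
    assert args.cache and args.crop == 128 and args.pad == 50
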
@@ -80,11 +80,11 @@ def str_list(s):
     return s.split(',')


-def int_tuple(t):
-    t = t.split(',')
-    t = tuple(int(i) for i in t)
-    if len(t) == 1:
-        t = t[0]
-    elif len(t) != 6:
-        raise ValueError('pad or crop size must be int or 6-tuple')
-    return t
+#def int_tuple(t):
+#    t = t.split(',')
+#    t = tuple(int(i) for i in t)
+#    if len(t) == 1:
+#        t = t[0]
+#    elif len(t) != 6:
+#        raise ValueError('size must be int or 6-tuple')
+#    return t

@@ -14,15 +14,15 @@ class FieldDataset(Dataset):
     Likewise `tgt_patterns` is for target fields.
     Input and target samples of all fields are matched by sorting the globbed files.

-    Input fields can be padded (>0) or cropped (<0) with `pad_or_crop`.
-    Padding assumes periodic boundary condition.
+    Input and target fields can be cached, and they can be cropped.
+    Input fields can be padded assuming periodic boundary condition.

     Data augmentations are supported for scalar and vector fields.

     `norms` can be a list of callables to normalize each field.
     """
-    def __init__(self, in_patterns, tgt_patterns, pad_or_crop=0, augment=False,
-            norms=None):
+    def __init__(self, in_patterns, tgt_patterns, cache=False, crop=None, pad=0,
+            augment=False, norms=None):
         in_file_lists = [sorted(glob(p)) for p in in_patterns]
         self.in_files = list(zip(* in_file_lists))

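
Note: with the new signature, constructing the dataset might look like the
following sketch (the glob patterns and norm name are placeholders, and
FieldDataset is assumed importable from the package's data module):

    dataset = FieldDataset(
        in_patterns=['lin/*/dis.npy'],      # hypothetical input pattern
        tgt_patterns=['nonlin/*/dis.npy'],  # hypothetical target pattern
        cache=True,   # np.load every field once, keep it in memory
        crop=128,     # cut 128^3 sub-volumes out of each field
        pad=50,       # pad inputs by 50 per side, periodically
        augment=False,
        norms=['cosmology.dis'],
    )
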
@@ -35,47 +35,80 @@ class FieldDataset(Dataset):
         self.in_channels = sum(np.load(f).shape[0] for f in self.in_files[0])
         self.tgt_channels = sum(np.load(f).shape[0] for f in self.tgt_files[0])

-        if isinstance(pad_or_crop, int):
-            pad_or_crop = (pad_or_crop,) * 6
-        assert isinstance(pad_or_crop, tuple) and len(pad_or_crop) == 6, \
-                'pad or crop size must be int or 6-tuple'
-        self.pad_or_crop = np.array((0,) * 2 + pad_or_crop).reshape(4, 2)
+        self.size = np.load(self.in_files[0][0]).shape[-3:]
+        self.size = np.asarray(self.size)
+        self.ndim = len(self.size)
+
+        self.cache = cache
+        if self.cache:
+            self.in_fields = []
+            self.tgt_fields = []
+            for idx in range(len(self.in_files)):
+                self.in_fields.append([np.load(f) for f in self.in_files[idx]])
+                self.tgt_fields.append([np.load(f) for f in self.tgt_files[idx]])
+
+        if crop is None:
+            self.crop = self.size
+            self.reps = np.ones_like(self.size)
+        else:
+            self.crop = np.broadcast_to(crop, self.size.shape)
+            self.reps = self.size // self.crop
+        self.tot_reps = int(np.prod(self.reps))
+
+        assert isinstance(pad, int), 'only support symmetric padding for now'
+        self.pad = np.broadcast_to(pad, (self.ndim, 2))

         self.augment = augment
+        if self.ndim == 1 and self.augment:
+            raise ValueError('cannot augment 1D fields')

-        if norms is not None:
+        if norms is not None:  # FIXME: in_norms, tgt_norms
             assert len(in_patterns) == len(norms), \
                     'numbers of normalization callables and input fields do not match'
             norms = [import_norm(norm) for norm in norms if isinstance(norm, str)]
             self.norms = norms

+    def __len__(self):
+        return len(self.in_files) * self.tot_reps
+
     @property
     def channels(self):
         return self.in_channels, self.tgt_channels

-    def __len__(self):
-        return len(self.in_files)
-
     def __getitem__(self, idx):
-        in_fields = [np.load(f) for f in self.in_files[idx]]
-        tgt_fields = [np.load(f) for f in self.tgt_files[idx]]
+        idx, sub_idx = idx // self.tot_reps, idx % self.tot_reps
+        start = np.unravel_index(sub_idx, self.reps) * self.crop
+        #print('==================================================')
+        #print(f'idx = {idx}, sub_idx = {sub_idx}, start = {start}')
+        #print(f'self.reps = {self.reps}, self.tot_reps = {self.tot_reps}')
+        #print(f'self.crop = {self.crop}, self.size = {self.size}')
+        #print(f'self.ndim = {self.ndim}, self.channels = {self.channels}')
+        #print(f'self.pad = {self.pad}')

-        padcrop(in_fields, self.pad_or_crop) # with numpy
+        if self.cache:
+            in_fields = self.in_fields[idx]
+            tgt_fields = self.tgt_fields[idx]
+        else:
+            in_fields = [np.load(f) for f in self.in_files[idx]]
+            tgt_fields = [np.load(f) for f in self.tgt_files[idx]]
+
+        in_fields = crop(in_fields, start, self.crop, self.pad)
+        tgt_fields = crop(tgt_fields, start, self.crop, np.zeros_like(self.pad))

         in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]
         tgt_fields = [torch.from_numpy(f).to(torch.float32) for f in tgt_fields]

         if self.augment:
-            flip_axes = torch.randint(2, (3,), dtype=torch.bool)
-            flip_axes = torch.arange(3)[flip_axes]
+            flip_axes = torch.randint(2, (self.ndim,), dtype=torch.bool)
+            flip_axes = torch.arange(self.ndim)[flip_axes]

-            flip3d(in_fields, flip_axes)
-            flip3d(tgt_fields, flip_axes)
+            in_fields = flip(in_fields, flip_axes, self.ndim)
+            tgt_fields = flip(tgt_fields, flip_axes, self.ndim)

-            perm_axes = torch.randperm(3)
+            perm_axes = torch.randperm(self.ndim)

-            perm3d(in_fields, perm_axes)
-            perm3d(tgt_fields, perm_axes)
+            in_fields = perm(in_fields, perm_axes, self.ndim)
+            tgt_fields = perm(tgt_fields, perm_axes, self.ndim)

         if self.norms is not None:
             for norm, ifield, tfield in zip(self.norms, in_fields, tgt_fields):
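
Note: __len__ and __getitem__ now expose each field file as tot_reps items,
one per crop position. A worked example with illustrative numbers (512^3
fields, crop 128; the tuple returned by np.unravel_index broadcasts against
the crop array):

    import numpy as np

    size = np.asarray((512, 512, 512))
    crop = np.broadcast_to(128, size.shape)
    reps = size // crop                   # array([4, 4, 4])
    tot_reps = int(np.prod(reps))         # 64 sub-volumes per field

    idx = 70                              # flat dataset index
    file_idx, sub_idx = idx // tot_reps, idx % tot_reps
    start = np.unravel_index(sub_idx, reps) * crop
    print(file_idx, sub_idx, start)       # 1 6 [  0 128 256]
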
@@ -84,43 +117,49 @@ class FieldDataset(Dataset):

         in_fields = torch.cat(in_fields, dim=0)
         tgt_fields = torch.cat(tgt_fields, dim=0)
+        #print(in_fields.shape, tgt_fields.shape)

         return in_fields, tgt_fields


-def padcrop(fields, width):
-    for i, x in enumerate(fields):
-        if (width >= 0).all():
-            x = np.pad(x, width, mode='wrap')
-        elif (width <= 0).all():
-            x = x[...,
-                    -width[0, 0] : width[0, 1],
-                    -width[1, 0] : width[1, 1],
-                    -width[2, 0] : width[2, 1],
-            ]
-        else:
-            raise NotImplementedError('mixed pad-and-crop not supported')
-
-        fields[i] = x
+def crop(fields, start, crop, pad):
+    new_fields = []
+    for x in fields:
+        for d, (i, N, (p0, p1)) in enumerate(zip(start, crop, pad)):
+            x = x.take(range(i - p0, i + N + p1), axis=1 + d, mode='wrap')
+
+        new_fields.append(x)
+
+    return new_fields


-def flip3d(fields, axes):
-    for i, x in enumerate(fields):
-        if x.size(0) == 3: # flip vector components
+def flip(fields, axes, ndim):
+    assert ndim > 1, 'flipping is ambiguous for 1D vectors'
+
+    new_fields = []
+    for x in fields:
+        if x.size(0) == ndim: # flip vector components
             x[axes] = - x[axes]

         axes = (1 + axes).tolist()
         x = torch.flip(x, axes)

-        fields[i] = x
+        new_fields.append(x)
+
+    return new_fields


-def perm3d(fields, axes):
-    for i, x in enumerate(fields):
-        if x.size(0) == 3: # permutate vector components
+def perm(fields, axes, ndim):
+    assert ndim > 1, 'permutation is not necessary for 1D fields'
+
+    new_fields = []
+    for x in fields:
+        if x.size(0) == ndim: # permutate vector components
            x = x[axes]

         axes = [0] + (1 + axes).tolist()
         x = x.permute(axes)

-        fields[i] = x
+        new_fields.append(x)
+
+    return new_fields
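
Note: the new crop helper pads and crops in one pass. numpy.ndarray.take
with mode='wrap' maps out-of-range indices back in modulo the axis length,
which is exactly the periodic boundary condition. A small standalone check
(shapes are illustrative):

    import numpy as np

    x = np.arange(8).reshape(1, 8)   # 1 channel, one length-8 axis
    i, N, (p0, p1) = 2, 4, (1, 1)    # start 2, crop 4, pad 1 per side
    y = x.take(range(i - p0, i + N + p1), axis=1, mode='wrap')
    print(y)                         # [[1 2 3 4 5 6]]

    # wrapping at the boundary realizes the periodicity:
    print(x.take(range(-1, 3), axis=1, mode='wrap'))  # [[7 0 1 2]]
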
@@ -11,9 +11,11 @@ def test(args):
     test_dataset = FieldDataset(
         in_patterns=args.test_in_patterns,
         tgt_patterns=args.test_tgt_patterns,
+        cache=args.cache,
+        crop=args.crop,
+        pad=args.pad,
         augment=False,
         norms=args.norms,
-        pad_or_crop=args.pad_or_crop,
     )
     test_loader = DataLoader(
         test_dataset,
@@ -44,7 +46,7 @@ def test(args):
     with torch.no_grad():
         for i, (input, target) in enumerate(test_loader):
             output = model(input)
-            if args.pad_or_crop > 0: # FIXME
+            if args.pad > 0: # FIXME
                 output = narrow_like(output, target)
                 input = narrow_like(input, target)
             else:
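
Note: with --pad > 0 the model input is larger than the target, so the
output (and the input) has to be narrowed back before comparison. A sketch
of the size bookkeeping; narrow_like_sketch is a stand-in that assumes the
project's narrow_like center-crops its first argument to the spatial size
of its second:

    import torch

    def narrow_like_sketch(a, b):
        for d in range(2, a.dim()):          # spatial dims only
            extra = a.size(d) - b.size(d)
            a = a.narrow(d, extra // 2, b.size(d))
        return a

    output = torch.zeros(1, 3, 228, 228, 228)  # from padded 128+2*50 input
    target = torch.zeros(1, 3, 128, 128, 128)
    print(narrow_like_sketch(output, target).shape)
    # torch.Size([1, 3, 128, 128, 128])
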
@@ -48,9 +48,11 @@ def gpu_worker(local_rank, args):
     train_dataset = FieldDataset(
         in_patterns=args.train_in_patterns,
         tgt_patterns=args.train_tgt_patterns,
+        cache=args.cache,
+        crop=args.crop,
+        pad=args.pad,
         augment=args.augment,
         norms=args.norms,
-        pad_or_crop=args.pad_or_crop,
     )
     #train_sampler = DistributedSampler(train_dataset, shuffle=True)
     train_sampler = DistributedSampler(train_dataset)
@@ -66,9 +68,11 @@ def gpu_worker(local_rank, args):
     val_dataset = FieldDataset(
         in_patterns=args.val_in_patterns,
         tgt_patterns=args.val_tgt_patterns,
+        cache=args.cache,
+        crop=args.crop,
+        pad=args.pad,
         augment=False,
         norms=args.norms,
-        pad_or_crop=args.pad_or_crop,
     )
     #val_sampler = DistributedSampler(val_dataset, shuffle=False)
     val_sampler = DistributedSampler(val_dataset)
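
Note: --cache trades memory for I/O; every input and target field stays
resident for the lifetime of the dataset. A rough estimate, assuming
float32 .npy files and a 3-component 512^3 field (numbers illustrative):

    nbytes = 3 * 512**3 * 4    # channels * voxels * sizeof(float32)
    print(f'{nbytes / 2**30:.1f} GiB per field copy')  # 1.5 GiB
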
@@ -29,7 +29,7 @@ tgt_dir="nonlin"

 test_dirs="0" # FIXME

-files="dis/128x???.npy"
+files="dis/512x???.npy"
 in_files="$files"
 tgt_files="$files"

@@ -37,8 +37,8 @@ tgt_files="$files"
 m2m.py test \
     --test-in-patterns "$data_root_dir/$in_dir/$test_dirs/$in_files" \
     --test-tgt-patterns "$data_root_dir/$tgt_dir/$test_dirs/$tgt_files" \
-    --norms cosmology.dis --model VNet \
-    --batches 1 --loader-workers 0 --pad-or-crop 40 \
+    --norms cosmology.dis --model VNet --cache --crop 128 --pad 50 \
+    --batches 1 --loader-workers 0 \
     --load-state best_model.pth
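
Note: the updated invocation tests on 512^3 fields, cropping 128^3 targets
with 50 voxels of periodic padding per input side. Back-of-envelope,
assuming symmetric padding:

    crop, pad, size = 128, 50, 512
    in_side = crop + 2 * pad     # 228 per side into the model
    reps = (size // crop) ** 3   # 64 sub-volumes per field
    print(in_side, reps)         # 228 64
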
@@ -29,7 +29,7 @@ tgt_dir="nonlin"

 test_dirs="0" # FIXME

-files="vel/128x???.npy"
+files="vel/512x???.npy"
 in_files="$files"
 tgt_files="$files"

@@ -37,8 +37,8 @@ tgt_files="$files"
 m2m.py test \
     --test-in-patterns "$data_root_dir/$in_dir/$test_dirs/$in_files" \
     --test-tgt-patterns "$data_root_dir/$tgt_dir/$test_dirs/$tgt_files" \
-    --norms cosmology.vel --model VNet \
-    --batches 1 --loader-workers 0 --pad-or-crop 40 \
+    --norms cosmology.vel --model VNet --cache --crop 128 --pad 50 \
+    --batches 1 --loader-workers 0 \
     --load-state best_model.pth