Change crop, flip, perm to in-place

This commit is contained in:
Yin Li 2020-09-22 22:55:55 -04:00
parent 156011be5f
commit e40ea52190

View File

@ -139,11 +139,11 @@ class FieldDataset(Dataset):
if shift is not None: if shift is not None:
anchor[d] += torch.randint(int(shift), (1,)) anchor[d] += torch.randint(int(shift), (1,))
in_fields = crop(in_fields, anchor, self.crop, self.in_pad, self.size) crop(in_fields, anchor, self.crop, self.in_pad, self.size)
tgt_fields = crop(tgt_fields, anchor * self.scale_factor, crop(tgt_fields, anchor * self.scale_factor,
self.crop * self.scale_factor, self.crop * self.scale_factor,
self.tgt_pad, self.tgt_pad,
self.size * self.scale_factor) self.size * self.scale_factor)
in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields] in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]
tgt_fields = [torch.from_numpy(f).to(torch.float32) for f in tgt_fields] tgt_fields = [torch.from_numpy(f).to(torch.float32) for f in tgt_fields]
@ -158,11 +158,11 @@ class FieldDataset(Dataset):
norm(x) norm(x)
if self.augment: if self.augment:
in_fields, flip_axes = flip(in_fields, None, self.ndim) flip_axes = flip(in_fields, None, self.ndim)
tgt_fields, flip_axes = flip(tgt_fields, flip_axes, self.ndim) flip_axes = flip(tgt_fields, flip_axes, self.ndim)
in_fields, perm_axes = perm(in_fields, None, self.ndim) perm_axes = perm(in_fields, None, self.ndim)
tgt_fields, perm_axes = perm(tgt_fields, perm_axes, self.ndim) perm_axes = perm(tgt_fields, perm_axes, self.ndim)
if self.aug_add is not None: if self.aug_add is not None:
add_fac = add(in_fields, None, self.aug_add) add_fac = add(in_fields, None, self.aug_add)
@ -180,22 +180,22 @@ class FieldDataset(Dataset):
def crop(fields, anchor, crop, pad, size): def crop(fields, anchor, crop, pad, size):
ndim = len(size) ndim = len(size)
assert all(len(x) == ndim for x in [anchor, crop, pad, size]), 'inconsistent ndim' assert all(len(x) == ndim for x in [anchor, crop, pad]), 'ndim mismatch'
new_fields = [] ind = [slice(None)]
for x in fields: for d, (a, c, (p0, p1), s) in enumerate(zip(anchor, crop, pad, size)):
ind = [slice(None)] i = np.arange(a - p0, a + c + p1)
for d, (a, c, (p0, p1), s) in enumerate(zip(anchor, crop, pad, size)): i %= s
i = np.arange(a - p0, a + c + p1) i = i.reshape((-1,) + (1,) * (ndim - d - 1))
i %= s ind.append(i)
i = i.reshape((-1,) + (1,) * (ndim - d - 1)) ind = tuple(ind)
ind.append(i)
x = x[tuple(ind)] for i, x in enumerate(fields):
x = x[ind]
new_fields.append(x) fields[i] = x
return new_fields return ind
def flip(fields, axes, ndim): def flip(fields, axes, ndim):
@ -205,17 +205,16 @@ def flip(fields, axes, ndim):
axes = torch.randint(2, (ndim,), dtype=torch.bool) axes = torch.randint(2, (ndim,), dtype=torch.bool)
axes = torch.arange(ndim)[axes] axes = torch.arange(ndim)[axes]
new_fields = [] for i, x in enumerate(fields):
for x in fields:
if x.shape[0] == ndim: # flip vector components if x.shape[0] == ndim: # flip vector components
x[axes] = - x[axes] x[axes] = - x[axes]
shifted_axes = (1 + axes).tolist() shifted_axes = (1 + axes).tolist()
x = torch.flip(x, shifted_axes) x = torch.flip(x, shifted_axes)
new_fields.append(x) fields[i] = x
return new_fields, axes return axes
def perm(fields, axes, ndim): def perm(fields, axes, ndim):
@ -224,17 +223,16 @@ def perm(fields, axes, ndim):
if axes is None: if axes is None:
axes = torch.randperm(ndim) axes = torch.randperm(ndim)
new_fields = [] for i, x in enumerate(fields):
for x in fields:
if x.shape[0] == ndim: # permutate vector components if x.shape[0] == ndim: # permutate vector components
x = x[axes] x = x[axes]
shifted_axes = [0] + (1 + axes).tolist() shifted_axes = [0] + (1 + axes).tolist()
x = x.permute(shifted_axes) x = x.permute(shifted_axes)
new_fields.append(x) fields[i] = x
return new_fields, axes return axes
def add(fields, fac, std): def add(fields, fac, std):