Change crop, flip, perm to in-place
This commit is contained in:
parent
156011be5f
commit
e40ea52190
@ -139,11 +139,11 @@ class FieldDataset(Dataset):
|
|||||||
if shift is not None:
|
if shift is not None:
|
||||||
anchor[d] += torch.randint(int(shift), (1,))
|
anchor[d] += torch.randint(int(shift), (1,))
|
||||||
|
|
||||||
in_fields = crop(in_fields, anchor, self.crop, self.in_pad, self.size)
|
crop(in_fields, anchor, self.crop, self.in_pad, self.size)
|
||||||
tgt_fields = crop(tgt_fields, anchor * self.scale_factor,
|
crop(tgt_fields, anchor * self.scale_factor,
|
||||||
self.crop * self.scale_factor,
|
self.crop * self.scale_factor,
|
||||||
self.tgt_pad,
|
self.tgt_pad,
|
||||||
self.size * self.scale_factor)
|
self.size * self.scale_factor)
|
||||||
|
|
||||||
in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]
|
in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]
|
||||||
tgt_fields = [torch.from_numpy(f).to(torch.float32) for f in tgt_fields]
|
tgt_fields = [torch.from_numpy(f).to(torch.float32) for f in tgt_fields]
|
||||||
@ -158,11 +158,11 @@ class FieldDataset(Dataset):
|
|||||||
norm(x)
|
norm(x)
|
||||||
|
|
||||||
if self.augment:
|
if self.augment:
|
||||||
in_fields, flip_axes = flip(in_fields, None, self.ndim)
|
flip_axes = flip(in_fields, None, self.ndim)
|
||||||
tgt_fields, flip_axes = flip(tgt_fields, flip_axes, self.ndim)
|
flip_axes = flip(tgt_fields, flip_axes, self.ndim)
|
||||||
|
|
||||||
in_fields, perm_axes = perm(in_fields, None, self.ndim)
|
perm_axes = perm(in_fields, None, self.ndim)
|
||||||
tgt_fields, perm_axes = perm(tgt_fields, perm_axes, self.ndim)
|
perm_axes = perm(tgt_fields, perm_axes, self.ndim)
|
||||||
|
|
||||||
if self.aug_add is not None:
|
if self.aug_add is not None:
|
||||||
add_fac = add(in_fields, None, self.aug_add)
|
add_fac = add(in_fields, None, self.aug_add)
|
||||||
@ -180,22 +180,22 @@ class FieldDataset(Dataset):
|
|||||||
|
|
||||||
def crop(fields, anchor, crop, pad, size):
|
def crop(fields, anchor, crop, pad, size):
|
||||||
ndim = len(size)
|
ndim = len(size)
|
||||||
assert all(len(x) == ndim for x in [anchor, crop, pad, size]), 'inconsistent ndim'
|
assert all(len(x) == ndim for x in [anchor, crop, pad]), 'ndim mismatch'
|
||||||
|
|
||||||
new_fields = []
|
ind = [slice(None)]
|
||||||
for x in fields:
|
for d, (a, c, (p0, p1), s) in enumerate(zip(anchor, crop, pad, size)):
|
||||||
ind = [slice(None)]
|
i = np.arange(a - p0, a + c + p1)
|
||||||
for d, (a, c, (p0, p1), s) in enumerate(zip(anchor, crop, pad, size)):
|
i %= s
|
||||||
i = np.arange(a - p0, a + c + p1)
|
i = i.reshape((-1,) + (1,) * (ndim - d - 1))
|
||||||
i %= s
|
ind.append(i)
|
||||||
i = i.reshape((-1,) + (1,) * (ndim - d - 1))
|
ind = tuple(ind)
|
||||||
ind.append(i)
|
|
||||||
|
|
||||||
x = x[tuple(ind)]
|
for i, x in enumerate(fields):
|
||||||
|
x = x[ind]
|
||||||
|
|
||||||
new_fields.append(x)
|
fields[i] = x
|
||||||
|
|
||||||
return new_fields
|
return ind
|
||||||
|
|
||||||
|
|
||||||
def flip(fields, axes, ndim):
|
def flip(fields, axes, ndim):
|
||||||
@ -205,17 +205,16 @@ def flip(fields, axes, ndim):
|
|||||||
axes = torch.randint(2, (ndim,), dtype=torch.bool)
|
axes = torch.randint(2, (ndim,), dtype=torch.bool)
|
||||||
axes = torch.arange(ndim)[axes]
|
axes = torch.arange(ndim)[axes]
|
||||||
|
|
||||||
new_fields = []
|
for i, x in enumerate(fields):
|
||||||
for x in fields:
|
|
||||||
if x.shape[0] == ndim: # flip vector components
|
if x.shape[0] == ndim: # flip vector components
|
||||||
x[axes] = - x[axes]
|
x[axes] = - x[axes]
|
||||||
|
|
||||||
shifted_axes = (1 + axes).tolist()
|
shifted_axes = (1 + axes).tolist()
|
||||||
x = torch.flip(x, shifted_axes)
|
x = torch.flip(x, shifted_axes)
|
||||||
|
|
||||||
new_fields.append(x)
|
fields[i] = x
|
||||||
|
|
||||||
return new_fields, axes
|
return axes
|
||||||
|
|
||||||
|
|
||||||
def perm(fields, axes, ndim):
|
def perm(fields, axes, ndim):
|
||||||
@ -224,17 +223,16 @@ def perm(fields, axes, ndim):
|
|||||||
if axes is None:
|
if axes is None:
|
||||||
axes = torch.randperm(ndim)
|
axes = torch.randperm(ndim)
|
||||||
|
|
||||||
new_fields = []
|
for i, x in enumerate(fields):
|
||||||
for x in fields:
|
|
||||||
if x.shape[0] == ndim: # permutate vector components
|
if x.shape[0] == ndim: # permutate vector components
|
||||||
x = x[axes]
|
x = x[axes]
|
||||||
|
|
||||||
shifted_axes = [0] + (1 + axes).tolist()
|
shifted_axes = [0] + (1 + axes).tolist()
|
||||||
x = x.permute(shifted_axes)
|
x = x.permute(shifted_axes)
|
||||||
|
|
||||||
new_fields.append(x)
|
fields[i] = x
|
||||||
|
|
||||||
return new_fields, axes
|
return axes
|
||||||
|
|
||||||
|
|
||||||
def add(fields, fac, std):
|
def add(fields, fac, std):
|
||||||
|
Loading…
Reference in New Issue
Block a user