Remove size argument from crop
This commit is contained in:
parent
002e249925
commit
08c805e47a
@ -162,11 +162,9 @@ class FieldDataset(Dataset):
|
|||||||
if shift is not None:
|
if shift is not None:
|
||||||
anchor[d] += torch.randint(int(shift), (1,))
|
anchor[d] += torch.randint(int(shift), (1,))
|
||||||
|
|
||||||
crop(in_fields, anchor, self.crop, self.in_pad, self.size)
|
crop(in_fields, anchor, self.crop, self.in_pad)
|
||||||
crop(tgt_fields, anchor * self.scale_factor,
|
crop(tgt_fields, anchor * self.scale_factor,
|
||||||
self.crop * self.scale_factor,
|
self.crop * self.scale_factor, self.tgt_pad)
|
||||||
self.tgt_pad,
|
|
||||||
self.size * self.scale_factor)
|
|
||||||
|
|
||||||
style = torch.from_numpy(style).to(torch.float32)
|
style = torch.from_numpy(style).to(torch.float32)
|
||||||
in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]
|
in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]
|
||||||
@ -283,6 +281,7 @@ class FieldDataset(Dataset):
|
|||||||
|
|
||||||
def fill(field, patch, anchor):
|
def fill(field, patch, anchor):
|
||||||
ndim = len(anchor)
|
ndim = len(anchor)
|
||||||
|
assert field.ndim == patch.ndim == 1 + ndim, 'ndim mismatch'
|
||||||
|
|
||||||
ind = [slice(None)]
|
ind = [slice(None)]
|
||||||
for d, (p, a, s) in enumerate(zip(
|
for d, (p, a, s) in enumerate(zip(
|
||||||
@ -296,9 +295,11 @@ def fill(field, patch, anchor):
|
|||||||
field[ind] = patch
|
field[ind] = patch
|
||||||
|
|
||||||
|
|
||||||
def crop(fields, anchor, crop, pad, size):
|
def crop(fields, anchor, crop, pad):
|
||||||
|
assert all(x.shape == fields[0].shape for x in fields), 'shape mismatch'
|
||||||
|
size = fields[0].shape[1:]
|
||||||
ndim = len(size)
|
ndim = len(size)
|
||||||
assert all(len(x) == ndim for x in [anchor, crop, pad]), 'ndim mismatch'
|
assert ndim == len(anchor) == len(crop) == len(pad), 'ndim mismatch'
|
||||||
|
|
||||||
ind = [slice(None)]
|
ind = [slice(None)]
|
||||||
for d, (a, c, (p0, p1), s) in enumerate(zip(anchor, crop, pad, size)):
|
for d, (a, c, (p0, p1), s) in enumerate(zip(anchor, crop, pad, size)):
|
||||||
|
Loading…
Reference in New Issue
Block a user