Remove size argument from crop

This commit is contained in:
Yin Li 2021-04-12 12:53:50 -04:00
parent 002e249925
commit 08c805e47a

View File

@ -162,11 +162,9 @@ class FieldDataset(Dataset):
if shift is not None: if shift is not None:
anchor[d] += torch.randint(int(shift), (1,)) anchor[d] += torch.randint(int(shift), (1,))
crop(in_fields, anchor, self.crop, self.in_pad, self.size) crop(in_fields, anchor, self.crop, self.in_pad)
crop(tgt_fields, anchor * self.scale_factor, crop(tgt_fields, anchor * self.scale_factor,
self.crop * self.scale_factor, self.crop * self.scale_factor, self.tgt_pad)
self.tgt_pad,
self.size * self.scale_factor)
style = torch.from_numpy(style).to(torch.float32) style = torch.from_numpy(style).to(torch.float32)
in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields] in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]
@ -283,6 +281,7 @@ class FieldDataset(Dataset):
def fill(field, patch, anchor): def fill(field, patch, anchor):
ndim = len(anchor) ndim = len(anchor)
assert field.ndim == patch.ndim == 1 + ndim, 'ndim mismatch'
ind = [slice(None)] ind = [slice(None)]
for d, (p, a, s) in enumerate(zip( for d, (p, a, s) in enumerate(zip(
@ -296,9 +295,11 @@ def fill(field, patch, anchor):
field[ind] = patch field[ind] = patch
def crop(fields, anchor, crop, pad, size): def crop(fields, anchor, crop, pad):
assert all(x.shape == fields[0].shape for x in fields), 'shape mismatch'
size = fields[0].shape[1:]
ndim = len(size) ndim = len(size)
assert all(len(x) == ndim for x in [anchor, crop, pad]), 'ndim mismatch' assert ndim == len(anchor) == len(crop) == len(pad), 'ndim mismatch'
ind = [slice(None)] ind = [slice(None)]
for d, (a, c, (p0, p1), s) in enumerate(zip(anchor, crop, pad, size)): for d, (a, c, (p0, p1), s) in enumerate(zip(anchor, crop, pad, size)):