Remove input upsampling

This commit is contained in:
Yin Li 2020-04-30 21:11:43 -04:00
parent a88f27a3a1
commit b50c7b6350

View File

@ -23,17 +23,18 @@ class FieldDataset(Dataset):
Input and target fields can be cropped.
Input fields can be padded assuming periodic boundary condition.
Input can be upsampled by `scale_factor` for super-resolution purpose,
in which case `crop`, `pad`, and other spatial attributes will be taken
at the original resolution.
Setting integer `scale_factor` greater than 1 will crop target bigger than
the input for super-resolution, in which case `crop` and `pad` are sizes of
the input resolution.
Noise channels can be concatenated to the input.
`cache` enables data caching.
`div_data` enables data division, useful when combined with caching.
"""
def __init__(self, in_patterns, tgt_patterns, in_norms=None, tgt_norms=None,
augment=False, crop=None, pad=0, scale_factor=1, noise_chan=0,
def __init__(self, in_patterns, tgt_patterns,
in_norms=None, tgt_norms=None,
augment=False, crop=None, pad=0, scale_factor=1,
cache=False, div_data=False, rank=None, world_size=None):
in_file_lists = [sorted(glob(p)) for p in in_patterns]
self.in_files = list(zip(* in_file_lists))
@ -89,10 +90,6 @@ class FieldDataset(Dataset):
"only support integer upsampling"
self.scale_factor = scale_factor
assert isinstance(noise_chan, int) and noise_chan >= 0, \
"only support integer noise channels"
self.noise_chan = noise_chan
self.cache = cache
if self.cache:
self.in_fields = {}
@ -118,9 +115,10 @@ class FieldDataset(Dataset):
in_fields = [np.load(f) for f in self.in_files[idx]]
tgt_fields = [np.load(f) for f in self.tgt_files[idx]]
in_fields = crop(in_fields, start, self.crop, self.pad, self.scale_factor)
in_fields = crop(in_fields, start, self.crop, self.pad)
tgt_fields = crop(tgt_fields, start * self.scale_factor,
self.crop * self.scale_factor, np.zeros_like(self.pad))
self.crop * self.scale_factor,
np.zeros_like(self.pad))
in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]
tgt_fields = [torch.from_numpy(f).to(torch.float32) for f in tgt_fields]
@ -155,25 +153,13 @@ class FieldDataset(Dataset):
return in_fields, tgt_fields
def crop(fields, start, crop, pad, scale_factor=1):
def crop(fields, start, crop, pad):
new_fields = []
for x in fields:
for d, (i, N, (p0, p1)) in enumerate(zip(start, crop, pad)):
begin, end = i - p0, i + N + p1
for d, (i, c, (p0, p1)) in enumerate(zip(start, crop, pad)):
begin, end = i - p0, i + c + p1
x = x.take(range(begin, end), axis=1 + d, mode='wrap')
if scale_factor > 1:
x = torch.from_numpy(x).unsqueeze(0)
x = F.interpolate(x, scale_factor=scale_factor, mode='nearest')
x = x.numpy().squeeze(0)
# remove excess padding
for d, (N, (p0, p1)) in enumerate(zip(crop, pad)):
begin = (scale_factor - 1) * p0
end = scale_factor * (N + p0) + p1
x = x.take(range(begin, end), axis=1 + d)
new_fields.append(x)
return new_fields