Remove input upsampling

This commit is contained in:
Yin Li 2020-04-30 21:11:43 -04:00
parent a88f27a3a1
commit b50c7b6350

View File

@ -23,17 +23,18 @@ class FieldDataset(Dataset):
Input and target fields can be cropped. Input and target fields can be cropped.
Input fields can be padded assuming periodic boundary condition. Input fields can be padded assuming periodic boundary condition.
Input can be upsampled by `scale_factor` for super-resolution purpose, Setting integer `scale_factor` greater than 1 will crop target bigger than
in which case `crop`, `pad`, and other spatial attributes will be taken the input for super-resolution, in which case `crop` and `pad` are sizes of
at the original resolution. the input resolution.
Noise channels can be concatenated to the input. Noise channels can be concatenated to the input.
`cache` enables data caching. `cache` enables data caching.
`div_data` enables data division, useful when combined with caching. `div_data` enables data division, useful when combined with caching.
""" """
def __init__(self, in_patterns, tgt_patterns, in_norms=None, tgt_norms=None, def __init__(self, in_patterns, tgt_patterns,
augment=False, crop=None, pad=0, scale_factor=1, noise_chan=0, in_norms=None, tgt_norms=None,
augment=False, crop=None, pad=0, scale_factor=1,
cache=False, div_data=False, rank=None, world_size=None): cache=False, div_data=False, rank=None, world_size=None):
in_file_lists = [sorted(glob(p)) for p in in_patterns] in_file_lists = [sorted(glob(p)) for p in in_patterns]
self.in_files = list(zip(* in_file_lists)) self.in_files = list(zip(* in_file_lists))
@ -89,10 +90,6 @@ class FieldDataset(Dataset):
"only support integer upsampling" "only support integer upsampling"
self.scale_factor = scale_factor self.scale_factor = scale_factor
assert isinstance(noise_chan, int) and noise_chan >= 0, \
"only support integer noise channels"
self.noise_chan = noise_chan
self.cache = cache self.cache = cache
if self.cache: if self.cache:
self.in_fields = {} self.in_fields = {}
@ -118,9 +115,10 @@ class FieldDataset(Dataset):
in_fields = [np.load(f) for f in self.in_files[idx]] in_fields = [np.load(f) for f in self.in_files[idx]]
tgt_fields = [np.load(f) for f in self.tgt_files[idx]] tgt_fields = [np.load(f) for f in self.tgt_files[idx]]
in_fields = crop(in_fields, start, self.crop, self.pad, self.scale_factor) in_fields = crop(in_fields, start, self.crop, self.pad)
tgt_fields = crop(tgt_fields, start * self.scale_factor, tgt_fields = crop(tgt_fields, start * self.scale_factor,
self.crop * self.scale_factor, np.zeros_like(self.pad)) self.crop * self.scale_factor,
np.zeros_like(self.pad))
in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields] in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]
tgt_fields = [torch.from_numpy(f).to(torch.float32) for f in tgt_fields] tgt_fields = [torch.from_numpy(f).to(torch.float32) for f in tgt_fields]
@ -155,25 +153,13 @@ class FieldDataset(Dataset):
return in_fields, tgt_fields return in_fields, tgt_fields
def crop(fields, start, crop, pad, scale_factor=1): def crop(fields, start, crop, pad):
new_fields = [] new_fields = []
for x in fields: for x in fields:
for d, (i, N, (p0, p1)) in enumerate(zip(start, crop, pad)): for d, (i, c, (p0, p1)) in enumerate(zip(start, crop, pad)):
begin, end = i - p0, i + N + p1 begin, end = i - p0, i + c + p1
x = x.take(range(begin, end), axis=1 + d, mode='wrap') x = x.take(range(begin, end), axis=1 + d, mode='wrap')
if scale_factor > 1:
x = torch.from_numpy(x).unsqueeze(0)
x = F.interpolate(x, scale_factor=scale_factor, mode='nearest')
x = x.numpy().squeeze(0)
# remove excess padding
for d, (N, (p0, p1)) in enumerate(zip(crop, pad)):
begin = (scale_factor - 1) * p0
end = scale_factor * (N + p0) + p1
x = x.take(range(begin, end), axis=1 + d)
new_fields.append(x) new_fields.append(x)
return new_fields return new_fields