Remove input upsampling
This commit is contained in:
parent
a88f27a3a1
commit
b50c7b6350
@ -23,18 +23,19 @@ class FieldDataset(Dataset):
|
||||
Input and target fields can be cropped.
|
||||
Input fields can be padded assuming periodic boundary condition.
|
||||
|
||||
Input can be upsampled by `scale_factor` for super-resolution purpose,
|
||||
in which case `crop`, `pad`, and other spatial attributes will be taken
|
||||
at the original resolution.
|
||||
Setting integer `scale_factor` greater than 1 will crop target bigger than
|
||||
the input for super-resolution, in which case `crop` and `pad` are sizes of
|
||||
the input resolution.
|
||||
|
||||
Noise channels can be concatenated to the input.
|
||||
|
||||
`cache` enables data caching.
|
||||
`div_data` enables data division, useful when combined with caching.
|
||||
"""
|
||||
def __init__(self, in_patterns, tgt_patterns, in_norms=None, tgt_norms=None,
|
||||
augment=False, crop=None, pad=0, scale_factor=1, noise_chan=0,
|
||||
cache=False, div_data=False, rank=None, world_size=None):
|
||||
def __init__(self, in_patterns, tgt_patterns,
|
||||
in_norms=None, tgt_norms=None,
|
||||
augment=False, crop=None, pad=0, scale_factor=1,
|
||||
cache=False, div_data=False, rank=None, world_size=None):
|
||||
in_file_lists = [sorted(glob(p)) for p in in_patterns]
|
||||
self.in_files = list(zip(* in_file_lists))
|
||||
|
||||
@ -89,10 +90,6 @@ class FieldDataset(Dataset):
|
||||
"only support integer upsampling"
|
||||
self.scale_factor = scale_factor
|
||||
|
||||
assert isinstance(noise_chan, int) and noise_chan >= 0, \
|
||||
"only support integer noise channels"
|
||||
self.noise_chan = noise_chan
|
||||
|
||||
self.cache = cache
|
||||
if self.cache:
|
||||
self.in_fields = {}
|
||||
@ -118,9 +115,10 @@ class FieldDataset(Dataset):
|
||||
in_fields = [np.load(f) for f in self.in_files[idx]]
|
||||
tgt_fields = [np.load(f) for f in self.tgt_files[idx]]
|
||||
|
||||
in_fields = crop(in_fields, start, self.crop, self.pad, self.scale_factor)
|
||||
in_fields = crop(in_fields, start, self.crop, self.pad)
|
||||
tgt_fields = crop(tgt_fields, start * self.scale_factor,
|
||||
self.crop * self.scale_factor, np.zeros_like(self.pad))
|
||||
self.crop * self.scale_factor,
|
||||
np.zeros_like(self.pad))
|
||||
|
||||
in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]
|
||||
tgt_fields = [torch.from_numpy(f).to(torch.float32) for f in tgt_fields]
|
||||
@ -155,25 +153,13 @@ class FieldDataset(Dataset):
|
||||
return in_fields, tgt_fields
|
||||
|
||||
|
||||
def crop(fields, start, crop, pad, scale_factor=1):
|
||||
def crop(fields, start, crop, pad):
|
||||
new_fields = []
|
||||
for x in fields:
|
||||
for d, (i, N, (p0, p1)) in enumerate(zip(start, crop, pad)):
|
||||
begin, end = i - p0, i + N + p1
|
||||
|
||||
for d, (i, c, (p0, p1)) in enumerate(zip(start, crop, pad)):
|
||||
begin, end = i - p0, i + c + p1
|
||||
x = x.take(range(begin, end), axis=1 + d, mode='wrap')
|
||||
|
||||
if scale_factor > 1:
|
||||
x = torch.from_numpy(x).unsqueeze(0)
|
||||
x = F.interpolate(x, scale_factor=scale_factor, mode='nearest')
|
||||
x = x.numpy().squeeze(0)
|
||||
|
||||
# remove excess padding
|
||||
for d, (N, (p0, p1)) in enumerate(zip(crop, pad)):
|
||||
begin = (scale_factor - 1) * p0
|
||||
end = scale_factor * (N + p0) + p1
|
||||
x = x.take(range(begin, end), axis=1 + d)
|
||||
|
||||
new_fields.append(x)
|
||||
|
||||
return new_fields
|
||||
|
Loading…
Reference in New Issue
Block a user