Remove input upsampling
This commit is contained in:
parent
a88f27a3a1
commit
b50c7b6350
@ -23,17 +23,18 @@ class FieldDataset(Dataset):
|
|||||||
Input and target fields can be cropped.
|
Input and target fields can be cropped.
|
||||||
Input fields can be padded assuming periodic boundary condition.
|
Input fields can be padded assuming periodic boundary condition.
|
||||||
|
|
||||||
Input can be upsampled by `scale_factor` for super-resolution purpose,
|
Setting integer `scale_factor` greater than 1 will crop target bigger than
|
||||||
in which case `crop`, `pad`, and other spatial attributes will be taken
|
the input for super-resolution, in which case `crop` and `pad` are sizes of
|
||||||
at the original resolution.
|
the input resolution.
|
||||||
|
|
||||||
Noise channels can be concatenated to the input.
|
Noise channels can be concatenated to the input.
|
||||||
|
|
||||||
`cache` enables data caching.
|
`cache` enables data caching.
|
||||||
`div_data` enables data division, useful when combined with caching.
|
`div_data` enables data division, useful when combined with caching.
|
||||||
"""
|
"""
|
||||||
def __init__(self, in_patterns, tgt_patterns, in_norms=None, tgt_norms=None,
|
def __init__(self, in_patterns, tgt_patterns,
|
||||||
augment=False, crop=None, pad=0, scale_factor=1, noise_chan=0,
|
in_norms=None, tgt_norms=None,
|
||||||
|
augment=False, crop=None, pad=0, scale_factor=1,
|
||||||
cache=False, div_data=False, rank=None, world_size=None):
|
cache=False, div_data=False, rank=None, world_size=None):
|
||||||
in_file_lists = [sorted(glob(p)) for p in in_patterns]
|
in_file_lists = [sorted(glob(p)) for p in in_patterns]
|
||||||
self.in_files = list(zip(* in_file_lists))
|
self.in_files = list(zip(* in_file_lists))
|
||||||
@ -89,10 +90,6 @@ class FieldDataset(Dataset):
|
|||||||
"only support integer upsampling"
|
"only support integer upsampling"
|
||||||
self.scale_factor = scale_factor
|
self.scale_factor = scale_factor
|
||||||
|
|
||||||
assert isinstance(noise_chan, int) and noise_chan >= 0, \
|
|
||||||
"only support integer noise channels"
|
|
||||||
self.noise_chan = noise_chan
|
|
||||||
|
|
||||||
self.cache = cache
|
self.cache = cache
|
||||||
if self.cache:
|
if self.cache:
|
||||||
self.in_fields = {}
|
self.in_fields = {}
|
||||||
@ -118,9 +115,10 @@ class FieldDataset(Dataset):
|
|||||||
in_fields = [np.load(f) for f in self.in_files[idx]]
|
in_fields = [np.load(f) for f in self.in_files[idx]]
|
||||||
tgt_fields = [np.load(f) for f in self.tgt_files[idx]]
|
tgt_fields = [np.load(f) for f in self.tgt_files[idx]]
|
||||||
|
|
||||||
in_fields = crop(in_fields, start, self.crop, self.pad, self.scale_factor)
|
in_fields = crop(in_fields, start, self.crop, self.pad)
|
||||||
tgt_fields = crop(tgt_fields, start * self.scale_factor,
|
tgt_fields = crop(tgt_fields, start * self.scale_factor,
|
||||||
self.crop * self.scale_factor, np.zeros_like(self.pad))
|
self.crop * self.scale_factor,
|
||||||
|
np.zeros_like(self.pad))
|
||||||
|
|
||||||
in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]
|
in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]
|
||||||
tgt_fields = [torch.from_numpy(f).to(torch.float32) for f in tgt_fields]
|
tgt_fields = [torch.from_numpy(f).to(torch.float32) for f in tgt_fields]
|
||||||
@ -155,25 +153,13 @@ class FieldDataset(Dataset):
|
|||||||
return in_fields, tgt_fields
|
return in_fields, tgt_fields
|
||||||
|
|
||||||
|
|
||||||
def crop(fields, start, crop, pad, scale_factor=1):
|
def crop(fields, start, crop, pad):
|
||||||
new_fields = []
|
new_fields = []
|
||||||
for x in fields:
|
for x in fields:
|
||||||
for d, (i, N, (p0, p1)) in enumerate(zip(start, crop, pad)):
|
for d, (i, c, (p0, p1)) in enumerate(zip(start, crop, pad)):
|
||||||
begin, end = i - p0, i + N + p1
|
begin, end = i - p0, i + c + p1
|
||||||
|
|
||||||
x = x.take(range(begin, end), axis=1 + d, mode='wrap')
|
x = x.take(range(begin, end), axis=1 + d, mode='wrap')
|
||||||
|
|
||||||
if scale_factor > 1:
|
|
||||||
x = torch.from_numpy(x).unsqueeze(0)
|
|
||||||
x = F.interpolate(x, scale_factor=scale_factor, mode='nearest')
|
|
||||||
x = x.numpy().squeeze(0)
|
|
||||||
|
|
||||||
# remove excess padding
|
|
||||||
for d, (N, (p0, p1)) in enumerate(zip(crop, pad)):
|
|
||||||
begin = (scale_factor - 1) * p0
|
|
||||||
end = scale_factor * (N + p0) + p1
|
|
||||||
x = x.take(range(begin, end), axis=1 + d)
|
|
||||||
|
|
||||||
new_fields.append(x)
|
new_fields.append(x)
|
||||||
|
|
||||||
return new_fields
|
return new_fields
|
||||||
|
Loading…
Reference in New Issue
Block a user