Add super-resolution
This commit is contained in:
parent
ba42bc6a55
commit
84a369d4ed
3 changed files with 80 additions and 6 deletions
|
@ -25,6 +25,9 @@ def add_common_args(parser):
|
|||
parser.add_argument('--pad', default=0, type=int,
|
||||
help='size to pad the input data beyond the crop size, assuming '
|
||||
'periodic boundary condition')
|
||||
parser.add_argument('--scale-factor', default=1, type=int,
|
||||
help='input upsampling factor for super-resolution purpose, in '
|
||||
'which case crop and pad will be taken at the original resolution')
|
||||
|
||||
parser.add_argument('--model', required=True, type=str,
|
||||
help='model from .models')
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from glob import glob
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from .norms import import_norm
|
||||
|
@ -22,11 +23,15 @@ class FieldDataset(Dataset):
|
|||
Input and target fields can be cropped.
|
||||
Input fields can be padded assuming periodic boundary condition.
|
||||
|
||||
Input can be upsampled by `scale_factor` for super-resolution purpose,
|
||||
in which case `crop`, `pad`, and other spatial attributes will be taken
|
||||
at the original resolution.
|
||||
|
||||
`cache` enables data caching.
|
||||
`div_data` enables data division, useful when combined with caching.
|
||||
"""
|
||||
def __init__(self, in_patterns, tgt_patterns,
|
||||
in_norms=None, tgt_norms=None, augment=False, crop=None, pad=0,
|
||||
def __init__(self, in_patterns, tgt_patterns, in_norms=None, tgt_norms=None,
|
||||
augment=False, crop=None, pad=0, scale_factor=1,
|
||||
cache=False, div_data=False, rank=None, world_size=None,
|
||||
**kwargs):
|
||||
in_file_lists = [sorted(glob(p)) for p in in_patterns]
|
||||
|
@ -79,6 +84,10 @@ class FieldDataset(Dataset):
|
|||
assert isinstance(pad, int), 'only support symmetric padding for now'
|
||||
self.pad = np.broadcast_to(pad, (self.ndim, 2))
|
||||
|
||||
assert isinstance(scale_factor, int) and scale_factor >= 1, \
|
||||
"only support integer upsampling"
|
||||
self.scale_factor = scale_factor
|
||||
|
||||
self.cache = cache
|
||||
if self.cache:
|
||||
self.in_fields = {}
|
||||
|
@ -104,7 +113,7 @@ class FieldDataset(Dataset):
|
|||
in_fields = [np.load(f) for f in self.in_files[idx]]
|
||||
tgt_fields = [np.load(f) for f in self.tgt_files[idx]]
|
||||
|
||||
in_fields = crop(in_fields, start, self.crop, self.pad)
|
||||
in_fields = crop(in_fields, start, self.crop, self.pad, self.scale_factor)
|
||||
tgt_fields = crop(tgt_fields, start, self.crop, np.zeros_like(self.pad))
|
||||
|
||||
in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]
|
||||
|
@ -131,16 +140,26 @@ class FieldDataset(Dataset):
|
|||
|
||||
in_fields = torch.cat(in_fields, dim=0)
|
||||
tgt_fields = torch.cat(tgt_fields, dim=0)
|
||||
#print(in_fields.shape, tgt_fields.shape)
|
||||
|
||||
return in_fields, tgt_fields
|
||||
|
||||
|
||||
def crop(fields, start, crop, pad):
|
||||
def crop(fields, start, crop, pad, scale_factor=1):
|
||||
new_fields = []
|
||||
for x in fields:
|
||||
for d, (i, N, (p0, p1)) in enumerate(zip(start, crop, pad)):
|
||||
x = x.take(range(i - p0, i + N + p1), axis=1 + d, mode='wrap')
|
||||
start, stop = i - p0, i + N + p1
|
||||
# add buffer for linear interpolation
|
||||
if scale_factor > 1:
|
||||
start, stop = start - 1, stop + 1
|
||||
x = x.take(range(start, stop), axis=1 + d, mode='wrap')
|
||||
|
||||
if scale_factor > 1:
|
||||
x = F.interpolate(x, scale_factor=scale_factor, mode='trilinear')
|
||||
# remove buffer
|
||||
for d, (N, (p0, p1)) in enumerate(zip(crop, pad)):
|
||||
start, stop = scale_factor, N + p0 + p1 - scale_factor
|
||||
x = x.take(range(start, stop), axis=1 + d)
|
||||
|
||||
new_fields.append(x)
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue