Add super-resolution

This commit is contained in:
Yin Li 2020-01-22 19:47:27 -05:00
parent ba42bc6a55
commit 84a369d4ed
3 changed files with 80 additions and 6 deletions

View file

@ -25,6 +25,9 @@ def add_common_args(parser):
parser.add_argument('--pad', default=0, type=int,
help='size to pad the input data beyond the crop size, assuming '
'periodic boundary condition')
parser.add_argument('--scale-factor', default=1, type=int,
help='input upsampling factor for super-resolution purpose, in '
'which case crop and pad will be taken at the original resolution')
parser.add_argument('--model', required=True, type=str,
help='model from .models')

View file

@ -1,6 +1,7 @@
from glob import glob
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from .norms import import_norm
@ -22,11 +23,15 @@ class FieldDataset(Dataset):
Input and target fields can be cropped.
Input fields can be padded assuming periodic boundary condition.
Input can be upsampled by `scale_factor` for super-resolution purpose,
in which case `crop`, `pad`, and other spatial attributes will be taken
at the original resolution.
`cache` enables data caching.
`div_data` enables data division, useful when combined with caching.
"""
def __init__(self, in_patterns, tgt_patterns,
in_norms=None, tgt_norms=None, augment=False, crop=None, pad=0,
def __init__(self, in_patterns, tgt_patterns, in_norms=None, tgt_norms=None,
augment=False, crop=None, pad=0, scale_factor=1,
cache=False, div_data=False, rank=None, world_size=None,
**kwargs):
in_file_lists = [sorted(glob(p)) for p in in_patterns]
@ -79,6 +84,10 @@ class FieldDataset(Dataset):
assert isinstance(pad, int), 'only support symmetric padding for now'
self.pad = np.broadcast_to(pad, (self.ndim, 2))
assert isinstance(scale_factor, int) and scale_factor >= 1, \
"only support integer upsampling"
self.scale_factor = scale_factor
self.cache = cache
if self.cache:
self.in_fields = {}
@ -104,7 +113,7 @@ class FieldDataset(Dataset):
in_fields = [np.load(f) for f in self.in_files[idx]]
tgt_fields = [np.load(f) for f in self.tgt_files[idx]]
in_fields = crop(in_fields, start, self.crop, self.pad)
in_fields = crop(in_fields, start, self.crop, self.pad, self.scale_factor)
tgt_fields = crop(tgt_fields, start, self.crop, np.zeros_like(self.pad))
in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]
@ -131,16 +140,26 @@ class FieldDataset(Dataset):
in_fields = torch.cat(in_fields, dim=0)
tgt_fields = torch.cat(tgt_fields, dim=0)
#print(in_fields.shape, tgt_fields.shape)
return in_fields, tgt_fields
def crop(fields, start, crop, pad):
def crop(fields, start, crop, pad, scale_factor=1):
new_fields = []
for x in fields:
for d, (i, N, (p0, p1)) in enumerate(zip(start, crop, pad)):
x = x.take(range(i - p0, i + N + p1), axis=1 + d, mode='wrap')
start, stop = i - p0, i + N + p1
# add buffer for linear interpolation
if scale_factor > 1:
start, stop = start - 1, stop + 1
x = x.take(range(start, stop), axis=1 + d, mode='wrap')
if scale_factor > 1:
x = F.interpolate(x, scale_factor=scale_factor, mode='trilinear')
# remove buffer
for d, (N, (p0, p1)) in enumerate(zip(crop, pad)):
start, stop = scale_factor, N + p0 + p1 - scale_factor
x = x.take(range(start, stop), axis=1 + d)
new_fields.append(x)