Add super-resolution

Yin Li 2020-01-22 19:47:27 -05:00
parent ba42bc6a55
commit 84a369d4ed
3 changed files with 80 additions and 6 deletions

View File

@@ -25,6 +25,9 @@ def add_common_args(parser):
     parser.add_argument('--pad', default=0, type=int,
             help='size to pad the input data beyond the crop size, assuming '
             'periodic boundary condition')
+    parser.add_argument('--scale-factor', default=1, type=int,
+            help='input upsampling factor for super-resolution purpose, in '
+            'which case crop and pad will be taken at the original resolution')
     parser.add_argument('--model', required=True, type=str,
             help='model from .models')
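
(For example, with the settings used in scripts/srsgan.slurm below, --crop 88 --pad 20 --scale-factor 2, the input is cropped and padded at its original resolution to 88 + 2 × 20 = 128 cells per side, then upsampled by a factor of 2 to a 256-cell patch before entering the model.)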

View File

@@ -1,6 +1,7 @@
 from glob import glob
 import numpy as np
 import torch
+import torch.nn.functional as F
 from torch.utils.data import Dataset

 from .norms import import_norm
@@ -22,11 +23,15 @@ class FieldDataset(Dataset):
     Input and target fields can be cropped.
     Input fields can be padded assuming periodic boundary condition.
+
+    Input can be upsampled by `scale_factor` for super-resolution purpose,
+    in which case `crop`, `pad`, and other spatial attributes will be taken
+    at the original resolution.

     `cache` enables data caching.
     `div_data` enables data division, useful when combined with caching.
     """
-    def __init__(self, in_patterns, tgt_patterns,
-            in_norms=None, tgt_norms=None, augment=False, crop=None, pad=0,
+    def __init__(self, in_patterns, tgt_patterns, in_norms=None, tgt_norms=None,
+            augment=False, crop=None, pad=0, scale_factor=1,
             cache=False, div_data=False, rank=None, world_size=None,
             **kwargs):
         in_file_lists = [sorted(glob(p)) for p in in_patterns]
@@ -79,6 +84,10 @@ class FieldDataset(Dataset):
         assert isinstance(pad, int), 'only support symmetric padding for now'
         self.pad = np.broadcast_to(pad, (self.ndim, 2))

+        assert isinstance(scale_factor, int) and scale_factor >= 1, \
+                "only support integer upsampling"
+        self.scale_factor = scale_factor
+
         self.cache = cache
         if self.cache:
             self.in_fields = {}
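
For illustration, constructing the dataset with the new argument might look like the following sketch; the glob patterns are hypothetical, and crop, pad, and scale_factor mirror the slurm script below.

    # hypothetical usage of FieldDataset with the new scale_factor argument;
    # the glob patterns are illustrative, not paths from this repository
    dataset = FieldDataset(
        in_patterns=['low-resl/*/disp.npy'],    # low-resolution inputs
        tgt_patterns=['high-resl/*/disp.npy'],  # high-resolution targets
        crop=88, pad=20,  # taken at the input's original resolution
        scale_factor=2,   # upsample inputs 2x after cropping and padding
    )
    in_fields, tgt_fields = dataset[0]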
@@ -104,7 +113,7 @@ class FieldDataset(Dataset):
         in_fields = [np.load(f) for f in self.in_files[idx]]
         tgt_fields = [np.load(f) for f in self.tgt_files[idx]]

-        in_fields = crop(in_fields, start, self.crop, self.pad)
+        in_fields = crop(in_fields, start, self.crop, self.pad, self.scale_factor)
         tgt_fields = crop(tgt_fields, start, self.crop, np.zeros_like(self.pad))

         in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]
@@ -131,16 +140,26 @@ class FieldDataset(Dataset):
         in_fields = torch.cat(in_fields, dim=0)
         tgt_fields = torch.cat(tgt_fields, dim=0)

-        #print(in_fields.shape, tgt_fields.shape)
-
         return in_fields, tgt_fields


-def crop(fields, start, crop, pad):
+def crop(fields, start, crop, pad, scale_factor=1):
     new_fields = []
     for x in fields:
         for d, (i, N, (p0, p1)) in enumerate(zip(start, crop, pad)):
-            x = x.take(range(i - p0, i + N + p1), axis=1 + d, mode='wrap')
+            start, stop = i - p0, i + N + p1
+            # add buffer for linear interpolation
+            if scale_factor > 1:
+                start, stop = start - 1, stop + 1
+            x = x.take(range(start, stop), axis=1 + d, mode='wrap')
+
+        if scale_factor > 1:
+            x = F.interpolate(x, scale_factor=scale_factor, mode='trilinear')
+            # remove buffer
+            for d, (N, (p0, p1)) in enumerate(zip(crop, pad)):
+                start, stop = scale_factor, N + p0 + p1 - scale_factor
+                x = x.take(range(start, stop), axis=1 + d)

         new_fields.append(x)
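
To see the buffered crop-and-upsample logic in isolation, here is a self-contained sketch (an editorial example, not the repository code): crop with a 1-cell periodic buffer so the trilinear interpolation has valid neighbors at the patch edges, upsample, then trim scale_factor cells from each side.

    import numpy as np
    import torch
    import torch.nn.functional as F

    def crop_and_upsample(field, anchor, size, scale_factor=2):
        # field: (channel, D, H, W) array; anchor/size: crop origin/length per dim
        for d, (i, N) in enumerate(zip(anchor, size)):
            # crop with a 1-cell buffer on each side, wrapping periodically
            field = field.take(range(i - 1, i + N + 1), axis=1 + d, mode='wrap')
        x = torch.from_numpy(field).to(torch.float32).unsqueeze(0)  # batch dim
        x = F.interpolate(x, scale_factor=scale_factor,
                          mode='trilinear', align_corners=False)
        x = x.squeeze(0)
        # trim the interpolation buffer, now scale_factor cells per side
        for d, N in enumerate(size):
            x = x.narrow(1 + d, scale_factor, N * scale_factor)
        return x

    field = np.random.randn(3, 16, 16, 16).astype(np.float32)
    patch = crop_and_upsample(field, anchor=(4, 4, 4), size=(8, 8, 8))
    print(patch.shape)  # torch.Size([3, 16, 16, 16])

Note that F.interpolate expects batch and channel dimensions (a 5D tensor for trilinear mode), hence the unsqueeze/squeeze pair in this sketch.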

scripts/srsgan.slurm (new file, 52 lines)

@@ -0,0 +1,52 @@
#!/bin/bash
#SBATCH --job-name=srsgan
#SBATCH --output=%x-%j.out
#SBATCH --partition=rtx
#SBATCH --gres=gpu:4
#SBATCH --exclusive
#SBATCH --nodes=1
#SBATCH --time=7-00:00:00
hostname; pwd; date

#module load gcc python3
source $HOME/anaconda/bin/activate torch

export MASTER_ADDR=$HOSTNAME
export MASTER_PORT=60606

data_root_dir="/scratch1/06431/yueyingn/dmo-50MPC-train"

in_dir="low-resl"
tgt_dir="high-resl"
train_dirs="set[0-3]/output/PART_004"
val_dirs="set4/output/PART_004"

in_files_1="disp.npy"
in_files_2="vel.npy"
tgt_files_1="disp.npy"
tgt_files_2="vel.npy"

srun m2m.py train \
    --train-in-patterns "$data_root_dir/$in_dir/$train_dirs/$in_files_1,$data_root_dir/$in_dir/$train_dirs/$in_files_2" \
    --train-tgt-patterns "$data_root_dir/$tgt_dir/$train_dirs/$tgt_files_1,$data_root_dir/$tgt_dir/$train_dirs/$tgt_files_2" \
    --val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files_1,$data_root_dir/$in_dir/$val_dirs/$in_files_2" \
    --val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files_1,$data_root_dir/$tgt_dir/$val_dirs/$tgt_files_2" \
    --in-norms cosmology.dis,cosmology.vel --tgt-norms cosmology.dis,cosmology.vel --augment --crop 88 --pad 20 --scale-factor 2 \
    --model VNet --adv-model PatchGAN --cgan \
    --lr 0.0001 --adv-lr 0.0004 --batches 1 --loader-workers 0 \
    --epochs 128 --seed $RANDOM \
    --cache --div-data
#    --load-state checkpoint.pth \
date
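
The job would then be submitted with sbatch scripts/srsgan.slurm; uncommenting the --load-state line presumably resumes training from a saved checkpoint.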