Add super-resolution

This commit is contained in:
Yin Li 2020-01-22 19:47:27 -05:00
parent ba42bc6a55
commit 84a369d4ed
3 changed files with 80 additions and 6 deletions

View File

@ -25,6 +25,9 @@ def add_common_args(parser):
parser.add_argument('--pad', default=0, type=int,
help='size to pad the input data beyond the crop size, assuming '
'periodic boundary condition')
parser.add_argument('--scale-factor', default=1, type=int,
help='input upsampling factor for super-resolution purpose, in '
'which case crop and pad will be taken at the original resolution')
parser.add_argument('--model', required=True, type=str,
help='model from .models')

View File

@ -1,6 +1,7 @@
from glob import glob
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from .norms import import_norm
@ -22,11 +23,15 @@ class FieldDataset(Dataset):
Input and target fields can be cropped.
Input fields can be padded assuming periodic boundary condition.
Input can be upsampled by `scale_factor` for super-resolution purpose,
in which case `crop`, `pad`, and other spatial attributes will be taken
at the original resolution.
`cache` enables data caching.
`div_data` enables data division, useful when combined with caching.
"""
def __init__(self, in_patterns, tgt_patterns,
in_norms=None, tgt_norms=None, augment=False, crop=None, pad=0,
def __init__(self, in_patterns, tgt_patterns, in_norms=None, tgt_norms=None,
augment=False, crop=None, pad=0, scale_factor=1,
cache=False, div_data=False, rank=None, world_size=None,
**kwargs):
in_file_lists = [sorted(glob(p)) for p in in_patterns]
@ -79,6 +84,10 @@ class FieldDataset(Dataset):
assert isinstance(pad, int), 'only support symmetric padding for now'
self.pad = np.broadcast_to(pad, (self.ndim, 2))
assert isinstance(scale_factor, int) and scale_factor >= 1, \
"only support integer upsampling"
self.scale_factor = scale_factor
self.cache = cache
if self.cache:
self.in_fields = {}
@ -104,7 +113,7 @@ class FieldDataset(Dataset):
in_fields = [np.load(f) for f in self.in_files[idx]]
tgt_fields = [np.load(f) for f in self.tgt_files[idx]]
in_fields = crop(in_fields, start, self.crop, self.pad)
in_fields = crop(in_fields, start, self.crop, self.pad, self.scale_factor)
tgt_fields = crop(tgt_fields, start, self.crop, np.zeros_like(self.pad))
in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]
@ -131,16 +140,26 @@ class FieldDataset(Dataset):
in_fields = torch.cat(in_fields, dim=0)
tgt_fields = torch.cat(tgt_fields, dim=0)
#print(in_fields.shape, tgt_fields.shape)
return in_fields, tgt_fields
def crop(fields, start, crop, pad):
def crop(fields, start, crop, pad, scale_factor=1):
new_fields = []
for x in fields:
for d, (i, N, (p0, p1)) in enumerate(zip(start, crop, pad)):
x = x.take(range(i - p0, i + N + p1), axis=1 + d, mode='wrap')
start, stop = i - p0, i + N + p1
# add buffer for linear interpolation
if scale_factor > 1:
start, stop = start - 1, stop + 1
x = x.take(range(start, stop), axis=1 + d, mode='wrap')
if scale_factor > 1:
x = F.interpolate(x, scale_factor=scale_factor, mode='trilinear')
# remove buffer
for d, (N, (p0, p1)) in enumerate(zip(crop, pad)):
start, stop = scale_factor, N + p0 + p1 - scale_factor
x = x.take(range(start, stop), axis=1 + d)
new_fields.append(x)

52
scripts/srsgan.slurm Normal file
View File

@ -0,0 +1,52 @@
#!/bin/bash
#SBATCH --job-name=srsgan
#SBATCH --output=%x-%j.out
#SBATCH --partition=rtx
#SBATCH --gres=gpu:4
#SBATCH --exclusive
#SBATCH --nodes=1
#SBATCH --time=7-00:00:00
hostname; pwd; date
#module load gcc python3
source $HOME/anaconda/bin/activate torch
export MASTER_ADDR=$HOSTNAME
export MASTER_PORT=60606
data_root_dir="/scratch1/06431/yueyingn/dmo-50MPC-train"
in_dir="low-resl"
tgt_dir="high-resl"
train_dirs="set[0-3]/output/PART_004"
val_dirs="set4/output/PART_004"
in_files_1="disp.npy"
in_files_2="vel.npy"
tgt_files_1="disp.npy"
tgt_files_2="vel.npy"
srun m2m.py train \
--train-in-patterns "$data_root_dir/$in_dir/$train_dirs/$in_files_1,$data_root_dir/$in_dir/$train_dirs/$in_files_2" \
--train-tgt-patterns "$data_root_dir/$tgt_dir/$train_dirs/$tgt_files_1,$data_root_dir/$tgt_dir/$train_dirs/$tgt_files_2" \
--val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files_1,$data_root_dir/$in_dir/$val_dirs/$in_files_2" \
--val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files_1,$data_root_dir/$tgt_dir/$val_dirs/$tgt_files_2" \
--in-norms cosmology.dis,cosmology.vel --tgt-norms cosmology.dis,cosmology.vel --augment --crop 88 --pad 20 --scale-factor 2 \
--model VNet --adv-model PatchGAN --cgan \
--lr 0.0001 --adv-lr 0.0004 --batches 1 --loader-workers 0 \
--epochs 128 --seed $RANDOM \
--cache --div-data
# --load-state checkpoint.pth \
date