From 84a369d4ed9012c2efd82c8c5e7fa3848dea09d2 Mon Sep 17 00:00:00 2001 From: Yin Li Date: Wed, 22 Jan 2020 19:47:27 -0500 Subject: [PATCH] Add super-resolution --- map2map/args.py | 3 +++ map2map/data/fields.py | 31 ++++++++++++++++++++----- scripts/srsgan.slurm | 52 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 80 insertions(+), 6 deletions(-) create mode 100644 scripts/srsgan.slurm diff --git a/map2map/args.py b/map2map/args.py index 05cbd6f..31c1625 100644 --- a/map2map/args.py +++ b/map2map/args.py @@ -25,6 +25,9 @@ def add_common_args(parser): parser.add_argument('--pad', default=0, type=int, help='size to pad the input data beyond the crop size, assuming ' 'periodic boundary condition') + parser.add_argument('--scale-factor', default=1, type=int, + help='input upsampling factor for super-resolution purpose, in ' + 'which case crop and pad will be taken at the original resolution') parser.add_argument('--model', required=True, type=str, help='model from .models') diff --git a/map2map/data/fields.py b/map2map/data/fields.py index 2c92792..e63dc9f 100644 --- a/map2map/data/fields.py +++ b/map2map/data/fields.py @@ -1,6 +1,7 @@ from glob import glob import numpy as np import torch +import torch.nn.functional as F from torch.utils.data import Dataset from .norms import import_norm @@ -22,11 +23,15 @@ class FieldDataset(Dataset): Input and target fields can be cropped. Input fields can be padded assuming periodic boundary condition. + Input can be upsampled by `scale_factor` for super-resolution purpose, + in which case `crop`, `pad`, and other spatial attributes will be taken + at the original resolution. + `cache` enables data caching. `div_data` enables data division, useful when combined with caching. """ - def __init__(self, in_patterns, tgt_patterns, - in_norms=None, tgt_norms=None, augment=False, crop=None, pad=0, + def __init__(self, in_patterns, tgt_patterns, in_norms=None, tgt_norms=None, + augment=False, crop=None, pad=0, scale_factor=1, cache=False, div_data=False, rank=None, world_size=None, **kwargs): in_file_lists = [sorted(glob(p)) for p in in_patterns] @@ -79,6 +84,10 @@ class FieldDataset(Dataset): assert isinstance(pad, int), 'only support symmetric padding for now' self.pad = np.broadcast_to(pad, (self.ndim, 2)) + assert isinstance(scale_factor, int) and scale_factor >= 1, \ + "only support integer upsampling" + self.scale_factor = scale_factor + self.cache = cache if self.cache: self.in_fields = {} @@ -104,7 +113,7 @@ class FieldDataset(Dataset): in_fields = [np.load(f) for f in self.in_files[idx]] tgt_fields = [np.load(f) for f in self.tgt_files[idx]] - in_fields = crop(in_fields, start, self.crop, self.pad) + in_fields = crop(in_fields, start, self.crop, self.pad, self.scale_factor) tgt_fields = crop(tgt_fields, start, self.crop, np.zeros_like(self.pad)) in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields] @@ -131,16 +140,26 @@ class FieldDataset(Dataset): in_fields = torch.cat(in_fields, dim=0) tgt_fields = torch.cat(tgt_fields, dim=0) - #print(in_fields.shape, tgt_fields.shape) return in_fields, tgt_fields -def crop(fields, start, crop, pad): +def crop(fields, start, crop, pad, scale_factor=1): new_fields = [] for x in fields: for d, (i, N, (p0, p1)) in enumerate(zip(start, crop, pad)): - x = x.take(range(i - p0, i + N + p1), axis=1 + d, mode='wrap') + start, stop = i - p0, i + N + p1 + # add buffer for linear interpolation + if scale_factor > 1: + start, stop = start - 1, stop + 1 + x = x.take(range(start, stop), axis=1 + d, mode='wrap') + + if scale_factor > 1: + x = F.interpolate(x, scale_factor=scale_factor, mode='trilinear') + # remove buffer + for d, (N, (p0, p1)) in enumerate(zip(crop, pad)): + start, stop = scale_factor, N + p0 + p1 - scale_factor + x = x.take(range(start, stop), axis=1 + d) new_fields.append(x) diff --git a/scripts/srsgan.slurm b/scripts/srsgan.slurm new file mode 100644 index 0000000..0d9e265 --- /dev/null +++ b/scripts/srsgan.slurm @@ -0,0 +1,52 @@ +#!/bin/bash + +#SBATCH --job-name=srsgan +#SBATCH --output=%x-%j.out + +#SBATCH --partition=rtx +#SBATCH --gres=gpu:4 + +#SBATCH --exclusive +#SBATCH --nodes=1 +#SBATCH --time=7-00:00:00 + + +hostname; pwd; date + + +#module load gcc python3 +source $HOME/anaconda/bin/activate torch + + +export MASTER_ADDR=$HOSTNAME +export MASTER_PORT=60606 + + +data_root_dir="/scratch1/06431/yueyingn/dmo-50MPC-train" + +in_dir="low-resl" +tgt_dir="high-resl" + +train_dirs="set[0-3]/output/PART_004" +val_dirs="set4/output/PART_004" + +in_files_1="disp.npy" +in_files_2="vel.npy" +tgt_files_1="disp.npy" +tgt_files_2="vel.npy" + + +srun m2m.py train \ + --train-in-patterns "$data_root_dir/$in_dir/$train_dirs/$in_files_1,$data_root_dir/$in_dir/$train_dirs/$in_files_2" \ + --train-tgt-patterns "$data_root_dir/$tgt_dir/$train_dirs/$tgt_files_1,$data_root_dir/$tgt_dir/$train_dirs/$tgt_files_2" \ + --val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files_1,$data_root_dir/$in_dir/$val_dirs/$in_files_2" \ + --val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files_1,$data_root_dir/$tgt_dir/$val_dirs/$tgt_files_2" \ + --in-norms cosmology.dis,cosmology.vel --tgt-norms cosmology.dis,cosmology.vel --augment --crop 88 --pad 20 --scale-factor 2 \ + --model VNet --adv-model PatchGAN --cgan \ + --lr 0.0001 --adv-lr 0.0004 --batches 1 --loader-workers 0 \ + --epochs 128 --seed $RANDOM \ + --cache --div-data +# --load-state checkpoint.pth \ + + +date