Add super-resolution
This commit is contained in:
parent
ba42bc6a55
commit
84a369d4ed
@ -25,6 +25,9 @@ def add_common_args(parser):
|
||||
parser.add_argument('--pad', default=0, type=int,
|
||||
help='size to pad the input data beyond the crop size, assuming '
|
||||
'periodic boundary condition')
|
||||
parser.add_argument('--scale-factor', default=1, type=int,
|
||||
help='input upsampling factor for super-resolution purpose, in '
|
||||
'which case crop and pad will be taken at the original resolution')
|
||||
|
||||
parser.add_argument('--model', required=True, type=str,
|
||||
help='model from .models')
|
||||
|
@ -1,6 +1,7 @@
|
||||
from glob import glob
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from .norms import import_norm
|
||||
@ -22,11 +23,15 @@ class FieldDataset(Dataset):
|
||||
Input and target fields can be cropped.
|
||||
Input fields can be padded assuming periodic boundary condition.
|
||||
|
||||
Input can be upsampled by `scale_factor` for super-resolution purpose,
|
||||
in which case `crop`, `pad`, and other spatial attributes will be taken
|
||||
at the original resolution.
|
||||
|
||||
`cache` enables data caching.
|
||||
`div_data` enables data division, useful when combined with caching.
|
||||
"""
|
||||
def __init__(self, in_patterns, tgt_patterns,
|
||||
in_norms=None, tgt_norms=None, augment=False, crop=None, pad=0,
|
||||
def __init__(self, in_patterns, tgt_patterns, in_norms=None, tgt_norms=None,
|
||||
augment=False, crop=None, pad=0, scale_factor=1,
|
||||
cache=False, div_data=False, rank=None, world_size=None,
|
||||
**kwargs):
|
||||
in_file_lists = [sorted(glob(p)) for p in in_patterns]
|
||||
@ -79,6 +84,10 @@ class FieldDataset(Dataset):
|
||||
assert isinstance(pad, int), 'only support symmetric padding for now'
|
||||
self.pad = np.broadcast_to(pad, (self.ndim, 2))
|
||||
|
||||
assert isinstance(scale_factor, int) and scale_factor >= 1, \
|
||||
"only support integer upsampling"
|
||||
self.scale_factor = scale_factor
|
||||
|
||||
self.cache = cache
|
||||
if self.cache:
|
||||
self.in_fields = {}
|
||||
@ -104,7 +113,7 @@ class FieldDataset(Dataset):
|
||||
in_fields = [np.load(f) for f in self.in_files[idx]]
|
||||
tgt_fields = [np.load(f) for f in self.tgt_files[idx]]
|
||||
|
||||
in_fields = crop(in_fields, start, self.crop, self.pad)
|
||||
in_fields = crop(in_fields, start, self.crop, self.pad, self.scale_factor)
|
||||
tgt_fields = crop(tgt_fields, start, self.crop, np.zeros_like(self.pad))
|
||||
|
||||
in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]
|
||||
@ -131,16 +140,26 @@ class FieldDataset(Dataset):
|
||||
|
||||
in_fields = torch.cat(in_fields, dim=0)
|
||||
tgt_fields = torch.cat(tgt_fields, dim=0)
|
||||
#print(in_fields.shape, tgt_fields.shape)
|
||||
|
||||
return in_fields, tgt_fields
|
||||
|
||||
|
||||
def crop(fields, start, crop, pad):
|
||||
def crop(fields, start, crop, pad, scale_factor=1):
|
||||
new_fields = []
|
||||
for x in fields:
|
||||
for d, (i, N, (p0, p1)) in enumerate(zip(start, crop, pad)):
|
||||
x = x.take(range(i - p0, i + N + p1), axis=1 + d, mode='wrap')
|
||||
start, stop = i - p0, i + N + p1
|
||||
# add buffer for linear interpolation
|
||||
if scale_factor > 1:
|
||||
start, stop = start - 1, stop + 1
|
||||
x = x.take(range(start, stop), axis=1 + d, mode='wrap')
|
||||
|
||||
if scale_factor > 1:
|
||||
x = F.interpolate(x, scale_factor=scale_factor, mode='trilinear')
|
||||
# remove buffer
|
||||
for d, (N, (p0, p1)) in enumerate(zip(crop, pad)):
|
||||
start, stop = scale_factor, N + p0 + p1 - scale_factor
|
||||
x = x.take(range(start, stop), axis=1 + d)
|
||||
|
||||
new_fields.append(x)
|
||||
|
||||
|
52
scripts/srsgan.slurm
Normal file
52
scripts/srsgan.slurm
Normal file
@ -0,0 +1,52 @@
|
||||
#!/bin/bash
|
||||
|
||||
#SBATCH --job-name=srsgan
|
||||
#SBATCH --output=%x-%j.out
|
||||
|
||||
#SBATCH --partition=rtx
|
||||
#SBATCH --gres=gpu:4
|
||||
|
||||
#SBATCH --exclusive
|
||||
#SBATCH --nodes=1
|
||||
#SBATCH --time=7-00:00:00
|
||||
|
||||
|
||||
hostname; pwd; date
|
||||
|
||||
|
||||
#module load gcc python3
|
||||
source $HOME/anaconda/bin/activate torch
|
||||
|
||||
|
||||
export MASTER_ADDR=$HOSTNAME
|
||||
export MASTER_PORT=60606
|
||||
|
||||
|
||||
data_root_dir="/scratch1/06431/yueyingn/dmo-50MPC-train"
|
||||
|
||||
in_dir="low-resl"
|
||||
tgt_dir="high-resl"
|
||||
|
||||
train_dirs="set[0-3]/output/PART_004"
|
||||
val_dirs="set4/output/PART_004"
|
||||
|
||||
in_files_1="disp.npy"
|
||||
in_files_2="vel.npy"
|
||||
tgt_files_1="disp.npy"
|
||||
tgt_files_2="vel.npy"
|
||||
|
||||
|
||||
srun m2m.py train \
|
||||
--train-in-patterns "$data_root_dir/$in_dir/$train_dirs/$in_files_1,$data_root_dir/$in_dir/$train_dirs/$in_files_2" \
|
||||
--train-tgt-patterns "$data_root_dir/$tgt_dir/$train_dirs/$tgt_files_1,$data_root_dir/$tgt_dir/$train_dirs/$tgt_files_2" \
|
||||
--val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files_1,$data_root_dir/$in_dir/$val_dirs/$in_files_2" \
|
||||
--val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files_1,$data_root_dir/$tgt_dir/$val_dirs/$tgt_files_2" \
|
||||
--in-norms cosmology.dis,cosmology.vel --tgt-norms cosmology.dis,cosmology.vel --augment --crop 88 --pad 20 --scale-factor 2 \
|
||||
--model VNet --adv-model PatchGAN --cgan \
|
||||
--lr 0.0001 --adv-lr 0.0004 --batches 1 --loader-workers 0 \
|
||||
--epochs 128 --seed $RANDOM \
|
||||
--cache --div-data
|
||||
# --load-state checkpoint.pth \
|
||||
|
||||
|
||||
date
|
Loading…
Reference in New Issue
Block a user