Add super-resolution
This commit is contained in:
parent
ba42bc6a55
commit
84a369d4ed
@ -25,6 +25,9 @@ def add_common_args(parser):
|
|||||||
parser.add_argument('--pad', default=0, type=int,
|
parser.add_argument('--pad', default=0, type=int,
|
||||||
help='size to pad the input data beyond the crop size, assuming '
|
help='size to pad the input data beyond the crop size, assuming '
|
||||||
'periodic boundary condition')
|
'periodic boundary condition')
|
||||||
|
parser.add_argument('--scale-factor', default=1, type=int,
|
||||||
|
help='input upsampling factor for super-resolution purpose, in '
|
||||||
|
'which case crop and pad will be taken at the original resolution')
|
||||||
|
|
||||||
parser.add_argument('--model', required=True, type=str,
|
parser.add_argument('--model', required=True, type=str,
|
||||||
help='model from .models')
|
help='model from .models')
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
from glob import glob
|
from glob import glob
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
from .norms import import_norm
|
from .norms import import_norm
|
||||||
@ -22,11 +23,15 @@ class FieldDataset(Dataset):
|
|||||||
Input and target fields can be cropped.
|
Input and target fields can be cropped.
|
||||||
Input fields can be padded assuming periodic boundary condition.
|
Input fields can be padded assuming periodic boundary condition.
|
||||||
|
|
||||||
|
Input can be upsampled by `scale_factor` for super-resolution purpose,
|
||||||
|
in which case `crop`, `pad`, and other spatial attributes will be taken
|
||||||
|
at the original resolution.
|
||||||
|
|
||||||
`cache` enables data caching.
|
`cache` enables data caching.
|
||||||
`div_data` enables data division, useful when combined with caching.
|
`div_data` enables data division, useful when combined with caching.
|
||||||
"""
|
"""
|
||||||
def __init__(self, in_patterns, tgt_patterns,
|
def __init__(self, in_patterns, tgt_patterns, in_norms=None, tgt_norms=None,
|
||||||
in_norms=None, tgt_norms=None, augment=False, crop=None, pad=0,
|
augment=False, crop=None, pad=0, scale_factor=1,
|
||||||
cache=False, div_data=False, rank=None, world_size=None,
|
cache=False, div_data=False, rank=None, world_size=None,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
in_file_lists = [sorted(glob(p)) for p in in_patterns]
|
in_file_lists = [sorted(glob(p)) for p in in_patterns]
|
||||||
@ -79,6 +84,10 @@ class FieldDataset(Dataset):
|
|||||||
assert isinstance(pad, int), 'only support symmetric padding for now'
|
assert isinstance(pad, int), 'only support symmetric padding for now'
|
||||||
self.pad = np.broadcast_to(pad, (self.ndim, 2))
|
self.pad = np.broadcast_to(pad, (self.ndim, 2))
|
||||||
|
|
||||||
|
assert isinstance(scale_factor, int) and scale_factor >= 1, \
|
||||||
|
"only support integer upsampling"
|
||||||
|
self.scale_factor = scale_factor
|
||||||
|
|
||||||
self.cache = cache
|
self.cache = cache
|
||||||
if self.cache:
|
if self.cache:
|
||||||
self.in_fields = {}
|
self.in_fields = {}
|
||||||
@ -104,7 +113,7 @@ class FieldDataset(Dataset):
|
|||||||
in_fields = [np.load(f) for f in self.in_files[idx]]
|
in_fields = [np.load(f) for f in self.in_files[idx]]
|
||||||
tgt_fields = [np.load(f) for f in self.tgt_files[idx]]
|
tgt_fields = [np.load(f) for f in self.tgt_files[idx]]
|
||||||
|
|
||||||
in_fields = crop(in_fields, start, self.crop, self.pad)
|
in_fields = crop(in_fields, start, self.crop, self.pad, self.scale_factor)
|
||||||
tgt_fields = crop(tgt_fields, start, self.crop, np.zeros_like(self.pad))
|
tgt_fields = crop(tgt_fields, start, self.crop, np.zeros_like(self.pad))
|
||||||
|
|
||||||
in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]
|
in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]
|
||||||
@ -131,16 +140,26 @@ class FieldDataset(Dataset):
|
|||||||
|
|
||||||
in_fields = torch.cat(in_fields, dim=0)
|
in_fields = torch.cat(in_fields, dim=0)
|
||||||
tgt_fields = torch.cat(tgt_fields, dim=0)
|
tgt_fields = torch.cat(tgt_fields, dim=0)
|
||||||
#print(in_fields.shape, tgt_fields.shape)
|
|
||||||
|
|
||||||
return in_fields, tgt_fields
|
return in_fields, tgt_fields
|
||||||
|
|
||||||
|
|
||||||
def crop(fields, start, crop, pad):
|
def crop(fields, start, crop, pad, scale_factor=1):
|
||||||
new_fields = []
|
new_fields = []
|
||||||
for x in fields:
|
for x in fields:
|
||||||
for d, (i, N, (p0, p1)) in enumerate(zip(start, crop, pad)):
|
for d, (i, N, (p0, p1)) in enumerate(zip(start, crop, pad)):
|
||||||
x = x.take(range(i - p0, i + N + p1), axis=1 + d, mode='wrap')
|
start, stop = i - p0, i + N + p1
|
||||||
|
# add buffer for linear interpolation
|
||||||
|
if scale_factor > 1:
|
||||||
|
start, stop = start - 1, stop + 1
|
||||||
|
x = x.take(range(start, stop), axis=1 + d, mode='wrap')
|
||||||
|
|
||||||
|
if scale_factor > 1:
|
||||||
|
x = F.interpolate(x, scale_factor=scale_factor, mode='trilinear')
|
||||||
|
# remove buffer
|
||||||
|
for d, (N, (p0, p1)) in enumerate(zip(crop, pad)):
|
||||||
|
start, stop = scale_factor, N + p0 + p1 - scale_factor
|
||||||
|
x = x.take(range(start, stop), axis=1 + d)
|
||||||
|
|
||||||
new_fields.append(x)
|
new_fields.append(x)
|
||||||
|
|
||||||
|
52
scripts/srsgan.slurm
Normal file
52
scripts/srsgan.slurm
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
#SBATCH --job-name=srsgan
|
||||||
|
#SBATCH --output=%x-%j.out
|
||||||
|
|
||||||
|
#SBATCH --partition=rtx
|
||||||
|
#SBATCH --gres=gpu:4
|
||||||
|
|
||||||
|
#SBATCH --exclusive
|
||||||
|
#SBATCH --nodes=1
|
||||||
|
#SBATCH --time=7-00:00:00
|
||||||
|
|
||||||
|
|
||||||
|
hostname; pwd; date
|
||||||
|
|
||||||
|
|
||||||
|
#module load gcc python3
|
||||||
|
source $HOME/anaconda/bin/activate torch
|
||||||
|
|
||||||
|
|
||||||
|
export MASTER_ADDR=$HOSTNAME
|
||||||
|
export MASTER_PORT=60606
|
||||||
|
|
||||||
|
|
||||||
|
data_root_dir="/scratch1/06431/yueyingn/dmo-50MPC-train"
|
||||||
|
|
||||||
|
in_dir="low-resl"
|
||||||
|
tgt_dir="high-resl"
|
||||||
|
|
||||||
|
train_dirs="set[0-3]/output/PART_004"
|
||||||
|
val_dirs="set4/output/PART_004"
|
||||||
|
|
||||||
|
in_files_1="disp.npy"
|
||||||
|
in_files_2="vel.npy"
|
||||||
|
tgt_files_1="disp.npy"
|
||||||
|
tgt_files_2="vel.npy"
|
||||||
|
|
||||||
|
|
||||||
|
srun m2m.py train \
|
||||||
|
--train-in-patterns "$data_root_dir/$in_dir/$train_dirs/$in_files_1,$data_root_dir/$in_dir/$train_dirs/$in_files_2" \
|
||||||
|
--train-tgt-patterns "$data_root_dir/$tgt_dir/$train_dirs/$tgt_files_1,$data_root_dir/$tgt_dir/$train_dirs/$tgt_files_2" \
|
||||||
|
--val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files_1,$data_root_dir/$in_dir/$val_dirs/$in_files_2" \
|
||||||
|
--val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files_1,$data_root_dir/$tgt_dir/$val_dirs/$tgt_files_2" \
|
||||||
|
--in-norms cosmology.dis,cosmology.vel --tgt-norms cosmology.dis,cosmology.vel --augment --crop 88 --pad 20 --scale-factor 2 \
|
||||||
|
--model VNet --adv-model PatchGAN --cgan \
|
||||||
|
--lr 0.0001 --adv-lr 0.0004 --batches 1 --loader-workers 0 \
|
||||||
|
--epochs 128 --seed $RANDOM \
|
||||||
|
--cache --div-data
|
||||||
|
# --load-state checkpoint.pth \
|
||||||
|
|
||||||
|
|
||||||
|
date
|
Loading…
Reference in New Issue
Block a user