Add testing

Yin Li 2019-12-01 18:53:38 -05:00
parent bcf95275f3
commit 0211eed0ec
11 changed files with 252 additions and 73 deletions

View File

@@ -21,11 +21,22 @@ def add_common_args(parser):
     parser.add_argument('--out-channels', type=int, required=True,
             help='number of output or target channels')
     parser.add_argument('--norms', type=str_list, help='comma-sep. list '
-            'of normalization functions from map2map.data.norms')
+            'of normalization functions from data.norms')
     parser.add_argument('--criterion', default='MSELoss',
             help='model criterion from torch.nn')
     parser.add_argument('--load-state', default='', type=str,
             help='path to load model, optimizer, rng, etc.')
+    parser.add_argument('--batches', default=1, type=int,
+            help='mini-batch size, per GPU in training or in total in testing')
+    parser.add_argument('--loader-workers', default=0, type=int,
+            help='number of data loading workers, per GPU in training or '
+            'in total in testing')
+    parser.add_argument('--pad-or-crop', default=0, type=int_tuple,
+            help='pad (>0) or crop (<0) the input data; '
+            'can be a int or a 6-tuple (by a comma-sep. list); '
+            'can be asymmetric to align the data with downsample '
+            'and upsample convolutions; '
+            'padding assumes periodic boundary condition')

 def add_train_args(parser):
@@ -39,12 +50,8 @@ def add_train_args(parser):
             help='comma-sep. list of glob patterns for validation input data')
     parser.add_argument('--val-tgt-patterns', type=str_list, required=True,
             help='comma-sep. list of glob patterns for validation target data')
-    parser.add_argument('--epochs', default=128, type=int,
+    parser.add_argument('--epochs', default=1024, type=int,
             help='total number of epochs to run')
-    parser.add_argument('--batches-per-gpu', default=8, type=int,
-            help='mini-batch size per GPU')
-    parser.add_argument('--loader-workers-per-gpu', default=4, type=int,
-            help='number of data loading workers per GPU')
     parser.add_argument('--augment', action='store_true',
             help='enable training data augmentation')
     parser.add_argument('--optimizer', default='Adam',
@@ -74,3 +81,13 @@ def add_test_args(parser):

 def str_list(s):
     return s.split(',')
+
+
+def int_tuple(t):
+    t = t.split(',')
+    t = tuple(int(i) for i in t)
+    if len(t) == 1:
+        t = t[0]
+    elif len(t) != 6:
+        raise ValueError('pad or crop size must be int or 6-tuple')
+    return t
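For reference, the new int_tuple type accepts either a single integer or a comma-separated 6-tuple, one value per face of the box; a quick illustration using the function defined above (the values are only illustrative):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--pad-or-crop', default=0, type=int_tuple)

    print(int_tuple('40'))              # 40, a single int later expanded to all 6 sides
    print(int_tuple('0,40,0,40,0,40'))  # (0, 40, 0, 40, 0, 40), one value per side
    print(parser.parse_args(['--pad-or-crop', '40']).pad_or_crop)  # 40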

View File

@@ -3,7 +3,7 @@ import numpy as np
 import torch
 from torch.utils.data import Dataset

-from . import norms
+from .norms import import_norm

 class FieldDataset(Dataset):
@@ -14,12 +14,15 @@ class FieldDataset(Dataset):
     Likewise `tgt_patterns` is for target fields.
     Input and target samples of all fields are matched by sorting the globbed files.

+    Input fields can be padded (>0) or cropped (<0) with `pad_or_crop`.
+    Padding assumes periodic boundary condition.
+
     Data augmentations are supported for scalar and vector fields.

-    `normalize` can be a list of callables to normalize each field.
+    `norms` can be a list of callables to normalize each field.
     """
-    def __init__(self, in_patterns, tgt_patterns, augment=False,
-            normalize=None, **kwargs):
+    def __init__(self, in_patterns, tgt_patterns, pad_or_crop=0, augment=False,
+            norms=None):
         in_file_lists = [sorted(glob(p)) for p in in_patterns]
         self.in_files = list(zip(* in_file_lists))
@@ -29,23 +32,31 @@ class FieldDataset(Dataset):
         assert len(self.in_files) == len(self.tgt_files), \
                 'input and target sample sizes do not match'

+        if isinstance(pad_or_crop, int):
+            pad_or_crop = (pad_or_crop,) * 6
+        assert isinstance(pad_or_crop, tuple) and len(pad_or_crop) == 6, \
+                'pad or crop size must be int or 6-tuple'
+        self.pad_or_crop = np.array((0,) * 2 + pad_or_crop).reshape(4, 2)
+
         self.augment = augment

-        self.normalize = normalize
-        if self.normalize is not None:
-            assert len(in_patterns) == len(self.normalize), \
-                    'numbers of normalization callables and input fields do not match'
-        # self.__dict__.update(kwargs)
+        if norms is not None:
+            assert len(in_patterns) == len(norms), \
+                    'numbers of normalization callables and input fields do not match'
+            norms = [import_norm(norm) for norm in norms if isinstance(norm, str)]
+        self.norms = norms

     def __len__(self):
         return len(self.in_files)

     def __getitem__(self, idx):
-        in_fields = [torch.from_numpy(np.load(f)).to(torch.float32)
-                for f in self.in_files[idx]]
-        tgt_fields = [torch.from_numpy(np.load(f)).to(torch.float32)
-                for f in self.tgt_files[idx]]
+        in_fields = [np.load(f) for f in self.in_files[idx]]
+        tgt_fields = [np.load(f) for f in self.tgt_files[idx]]
+
+        padcrop(in_fields, self.pad_or_crop)  # with numpy
+
+        in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]
+        tgt_fields = [torch.from_numpy(f).to(torch.float32) for f in tgt_fields]

         if self.augment:
             flip_axes = torch.randint(2, (3,), dtype=torch.bool)
@@ -59,18 +70,8 @@ class FieldDataset(Dataset):
             perm3d(in_fields, perm_axes)
             perm3d(tgt_fields, perm_axes)

-        if self.normalize is not None:
-            def get_norm(path):
-                path = path.split('.')
-                norm = norms
-                while path:
-                    norm = norm.__dict__[path.pop(0)]
-                return norm
-            for norm, ifield, tfield in zip(self.normalize, in_fields, tgt_fields):
-                if isinstance(norm, str):
-                    norm = get_norm(norm)
+        if self.norms is not None:
+            for norm, ifield, tfield in zip(self.norms, in_fields, tgt_fields):
                 norm(ifield)
                 norm(tfield)
@@ -80,6 +81,22 @@ class FieldDataset(Dataset):
         return in_fields, tgt_fields

+
+def padcrop(fields, width):
+    for i, x in enumerate(fields):
+        if (width >= 0).all():
+            x = np.pad(x, width, mode='wrap')
+        elif (width <= 0).all():
+            x = x[...,
+                    -width[1, 0] : width[1, 1],
+                    -width[2, 0] : width[2, 1],
+                    -width[3, 0] : width[3, 1],
+            ]
+        else:
+            raise NotImplementedError('mixed pad-and-crop not supported')
+        fields[i] = x
+
+
 def flip3d(fields, axes):
     for i, x in enumerate(fields):
         if x.size(0) == 3:  # flip vector components
@@ -90,6 +107,7 @@ def flip3d(fields, axes):
         fields[i] = x

+
 def perm3d(fields, axes):
     for i, x in enumerate(fields):
         if x.size(0) == 3:  # permutate vector components
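To see what the padding branch of padcrop does, here is a small standalone check of np.pad with mode='wrap'; the field shape and pad width below are made up for illustration:

    import numpy as np

    # a (1, 4, 4, 4) field with a leading channel axis, like the .npy samples above
    x = np.arange(64, dtype=np.float32).reshape(1, 4, 4, 4)

    # --pad-or-crop 2 becomes the 6-tuple (2,) * 6, then a (4, 2) width array
    # whose first row leaves the channel axis unpadded
    width = np.array((0,) * 2 + (2,) * 6).reshape(4, 2)

    y = np.pad(x, width, mode='wrap')  # periodic (wrap-around) boundary condition
    print(y.shape)  # (1, 8, 8, 8): each spatial dim gains 2 cells on both sides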

View File

@@ -1 +1,10 @@
+from importlib import import_module
+
 from . import cosmology
+
+
+def import_norm(path):
+    mod, func = path.rsplit('.', 1)
+    mod = import_module('.' + mod, __name__)
+    func = getattr(mod, func)
+    return func
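The `--norms` strings given on the command line are resolved into callables by import_norm; for example, the value cosmology.dis used by the Slurm scripts below would resolve roughly as sketched here (the map2map.data.norms package path is taken from the help text above, not shown in this diff):

    from map2map.data.norms import import_norm

    # 'cosmology.dis' -> submodule cosmology of the norms package, attribute dis
    dis_norm = import_norm('cosmology.dis')
    print(dis_norm)  # the normalization callable applied to each field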

View File

@@ -10,4 +10,4 @@ def main():
     if args.mode == 'train':
         train.node_worker(args)
     elif args.mode == 'test':
-        pass
+        test.test(args)

View File

@@ -4,15 +4,6 @@ import torch.nn as nn
 from .conv import ConvBlock, ResBlock, narrow_like

-
-class DownBlock(ConvBlock):
-    def __init__(self, in_channels, out_channels, seq='BADBA'):
-        super().__init__(in_channels, out_channels, seq=seq)
-
-
-class UpBlock(ConvBlock):
-    def __init__(self, in_channels, out_channels, seq='BAUBA'):
-        super().__init__(in_channels, out_channels, seq=seq)

 class UNet(nn.Module):
     def __init__(self, in_channels, out_channels):
         super().__init__()

View File

@@ -1,8 +1,58 @@
-import os
+import numpy as np
 import torch
 from torch.utils.data import DataLoader
-from torch.utils.tensorboard import SummaryWriter

 from .data import FieldDataset
 from .models import UNet, narrow_like

+
+def test(args):
+    test_dataset = FieldDataset(
+        in_patterns=args.test_in_patterns,
+        tgt_patterns=args.test_tgt_patterns,
+        augment=False,
+        norms=args.norms,
+        pad_or_crop=args.pad_or_crop,
+    )
+    test_loader = DataLoader(
+        test_dataset,
+        batch_size=args.batches,
+        shuffle=False,
+        num_workers=args.loader_workers,
+    )
+
+    model = UNet(args.in_channels, args.out_channels)
+    criterion = torch.nn.__dict__[args.criterion]()
+
+    device = torch.device('cpu')
+    state = torch.load(args.load_state, map_location=device)
+    from collections import OrderedDict
+    model_state = OrderedDict()
+    for k, v in state['model'].items():
+        model_k = k.replace('module.', '', 1)  # FIXME
+        model_state[model_k] = v
+    model.load_state_dict(model_state)
+    print('model state at epoch {} loaded from {}'.format(
+        state['epoch'], args.load_state))
+    del state
+
+    model.eval()
+
+    with torch.no_grad():
+        for i, (input, target) in enumerate(test_loader):
+            output = model(input)
+
+            if args.pad_or_crop > 0:  # FIXME
+                output = narrow_like(output, target)
+            else:
+                target = narrow_like(target, output)
+
+            loss = criterion(output, target)
+            print('sample {} loss: {}'.format(i, loss))
+
+            if args.norms is not None:
+                norm = test_dataset.norms[0]  # FIXME
+                norm(output, undo=True)
+
+            np.savez('{}.npz'.format(i), input=input.numpy(),
+                    output=output.numpy(), target=target.numpy())
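The per-sample results written by np.savez above can be inspected afterwards, for example for the first test sample:

    import numpy as np

    sample = np.load('0.npz')      # file written for loop index i = 0 in test()
    print(sample['input'].shape,   # padded (or cropped) input batch
          sample['output'].shape,  # model prediction, de-normalized if --norms is set
          sample['target'].shape)  # target; output and target shapes match after narrow_like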

View File

@@ -1,6 +1,5 @@
 import os
 import shutil
 import torch
 from torch.multiprocessing import spawn
 from torch.distributed import init_process_group, destroy_process_group, all_reduce
@@ -46,15 +45,16 @@ def gpu_worker(local_rank, args):
         in_patterns=args.train_in_patterns,
         tgt_patterns=args.train_tgt_patterns,
         augment=args.augment,
-        normalize=args.norms,
+        norms=args.norms,
+        pad_or_crop=args.pad_or_crop,
     )
     train_sampler = DistributedSampler(train_dataset, shuffle=True)
     train_loader = DataLoader(
         train_dataset,
-        batch_size=args.batches_per_gpu,
+        batch_size=args.batches,
         shuffle=False,
         sampler=train_sampler,
-        num_workers=args.loader_workers_per_gpu,
+        num_workers=args.loader_workers,
         pin_memory=True
     )
@@ -62,15 +62,16 @@ def gpu_worker(local_rank, args):
         in_patterns=args.val_in_patterns,
         tgt_patterns=args.val_tgt_patterns,
         augment=False,
-        normalize=args.norms,
+        norms=args.norms,
+        pad_or_crop=args.pad_or_crop,
     )
     val_sampler = DistributedSampler(val_dataset, shuffle=False)
     val_loader = DataLoader(
         val_dataset,
-        batch_size=args.batches_per_gpu,
+        batch_size=args.batches,
         shuffle=False,
         sampler=val_sampler,
-        num_workers=args.loader_workers_per_gpu,
+        num_workers=args.loader_workers,
         pin_memory=True
     )
@@ -90,17 +91,17 @@ def gpu_worker(local_rank, args):
     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)

     if args.load_state:
-        checkpoint = torch.load(args.load_state, map_location=args.device)
-        args.start_epoch = checkpoint['epoch']
-        model.load_state_dict(checkpoint['model'])
-        optimizer.load_state_dict(checkpoint['optimizer'])
-        scheduler.load_state_dict(checkpoint['scheduler'])
-        torch.set_rng_state(checkpoint['rng'].cpu())  # move rng state back
+        state = torch.load(args.load_state, map_location=args.device)
+        args.start_epoch = state['epoch']
+        model.load_state_dict(state['model'])
+        optimizer.load_state_dict(state['optimizer'])
+        scheduler.load_state_dict(state['scheduler'])
+        torch.set_rng_state(state['rng'].cpu())  # move rng state back
         if args.rank == 0:
-            min_loss = checkpoint['min_loss']
-            print('checkpoint of epoch {} loaded from {}'.format(
-                checkpoint['epoch'], args.load_state))
-        del checkpoint
+            min_loss = state['min_loss']
+            print('checkpoint at epoch {} loaded from {}'.format(
+                state['epoch'], args.load_state))
+        del state
     else:
         args.start_epoch = 0
         if args.rank == 0:
@@ -125,7 +126,7 @@ def gpu_worker(local_rank, args):
         if args.rank == 0:
             args.logger.close()

-            checkpoint = {
+            state = {
                 'epoch': epoch + 1,
                 'model': model.state_dict(),
                 'optimizer' : optimizer.state_dict(),
@@ -134,8 +135,8 @@ def gpu_worker(local_rank, args):
                 'min_loss': min_loss,
             }
             filename = 'checkpoint.pth'
-            torch.save(checkpoint, filename)
-            del checkpoint
+            torch.save(state, filename)
+            del state

             if min_loss is None or val_loss < min_loss:
                 min_loss = val_loss
@@ -152,7 +153,7 @@ def train(epoch, loader, model, criterion, optimizer, args):
         target = target.to(args.device, non_blocking=True)

         output = model(input)
-        target = narrow_like(target, output)
+        target = narrow_like(target, output)  # FIXME pad
         loss = criterion(output, target)
@@ -167,7 +168,6 @@ def train(epoch, loader, model, criterion, optimizer, args):
         if args.rank == 0:
             args.logger.add_scalar('loss/train', loss.item(), global_step=batch)
-            # f'max GPU mem: {torch.cuda.max_memory_allocated()} allocated, {torch.cuda.max_memory_cached()} cached')

 def validate(epoch, loader, model, criterion, args):
     model.eval()
@@ -180,7 +180,7 @@ def validate(epoch, loader, model, criterion, args):
         target = target.to(args.device, non_blocking=True)

         output = model(input)
-        target = narrow_like(target, output)
+        target = narrow_like(target, output)  # FIXME pad
         loss += criterion(output, target)
@@ -189,6 +189,4 @@ def validate(epoch, loader, model, criterion, args):
     if args.rank == 0:
         args.logger.add_scalar('loss/val', loss.item(), global_step=epoch+1)
-        # f'max GPU mem: {torch.cuda.max_memory_allocated()} allocated, {torch.cuda.max_memory_cached()} cached')

     return loss.item()
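narrow_like itself comes from .models and is not touched by this commit; the call sites above (and the "FIXME pad" notes) rely on it trimming the larger tensor down to the smaller one. As a point of reference only, its assumed behavior is sketched below; this is an assumption, not the actual implementation:

    def narrow_like(a, b):
        # Assumed behavior: center-crop tensor a over its spatial dims (2 and up)
        # so that they match b; batch and channel dims are left untouched.
        for d in range(2, a.dim()):
            extra = a.shape[d] - b.shape[d]
            a = a.narrow(d, extra // 2, a.shape[d] - extra)
        return a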

View File

@@ -0,0 +1,48 @@
#!/bin/bash
#SBATCH --job-name=dis2dis-test
#SBATCH --output=%x-%j.out
#SBATCH --error=%x-%j.err
#SBATCH --partition=ccm
#SBATCH --exclusive
#SBATCH --nodes=1
#SBATCH --mem=0
#SBATCH --time=1-00:00:00
hostname; pwd; date
module load gcc openmpi2
module load cuda/10.1.243_418.87.00 cudnn/v7.6.2-cuda-10.1
source $HOME/anaconda3/bin/activate torch
export OMP_NUM_THREADS=$SLURM_CPUS_ON_NODE
echo OMP_NUM_THREADS = $OMP_NUM_THREADS
data_root_dir="/mnt/ceph/users/yinli/Quijote"
in_dir="linear"
tgt_dir="nonlin"
test_dirs="0" # FIXME
files="dis/128x???.npy"
in_files="$files"
tgt_files="$files"
srun m2m.py test \
--test-in-patterns "$data_root_dir/$in_dir/$test_dirs/$in_files" \
--test-tgt-patterns "$data_root_dir/$tgt_dir/$test_dirs/$tgt_files" \
--in-channels 3 --out-channels 3 --norms cosmology.dis \
--batches 1 --loader-workers 0 --pad-or-crop 40 \
--load-state best_model.pth
date

View File

@@ -11,7 +11,7 @@
 #SBATCH --exclusive
 #SBATCH --nodes=2
 #SBATCH --mem=0
-#SBATCH --time=2-00:00:00
+#SBATCH --time=7-00:00:00

 hostname; pwd; date
@@ -46,7 +46,7 @@ srun m2m.py train \
     --val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files" \
     --val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \
     --in-channels 3 --out-channels 3 --norms cosmology.dis --augment \
-    --epochs 128 --batches-per-gpu 4 --loader-workers-per-gpu 4
+    --epochs 1024 --batches 3 --loader-workers 3 --lr 0.0002

 # --load-state checkpoint.pth

View File

@@ -0,0 +1,48 @@
#!/bin/bash
#SBATCH --job-name=vel2vel-test
#SBATCH --output=%x-%j.out
#SBATCH --error=%x-%j.err
#SBATCH --partition=ccm
#SBATCH --exclusive
#SBATCH --nodes=1
#SBATCH --mem=0
#SBATCH --time=1-00:00:00
hostname; pwd; date
module load gcc openmpi2
module load cuda/10.1.243_418.87.00 cudnn/v7.6.2-cuda-10.1
source $HOME/anaconda3/bin/activate torch
export OMP_NUM_THREADS=$SLURM_CPUS_ON_NODE
echo OMP_NUM_THREADS = $OMP_NUM_THREADS
data_root_dir="/mnt/ceph/users/yinli/Quijote"
in_dir="linear"
tgt_dir="nonlin"
test_dirs="0" # FIXME
files="vel/128x???.npy"
in_files="$files"
tgt_files="$files"
srun m2m.py test \
--test-in-patterns "$data_root_dir/$in_dir/$test_dirs/$in_files" \
--test-tgt-patterns "$data_root_dir/$tgt_dir/$test_dirs/$tgt_files" \
--in-channels 3 --out-channels 3 --norms cosmology.vel \
--batches 1 --loader-workers 0 --pad-or-crop 40 \
--load-state best_model.pth
date

View File

@@ -11,7 +11,7 @@
 #SBATCH --exclusive
 #SBATCH --nodes=2
 #SBATCH --mem=0
-#SBATCH --time=2-00:00:00
+#SBATCH --time=7-00:00:00

 hostname; pwd; date
@@ -46,7 +46,7 @@ srun m2m.py train \
     --val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files" \
    --val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \
     --in-channels 3 --out-channels 3 --norms cosmology.vel --augment \
-    --epochs 128 --batches-per-gpu 4 --loader-workers-per-gpu 4
+    --epochs 1024 --batches 3 --loader-workers 3 --lr 0.0002

 # --load-state checkpoint.pth