From 0211eed0ec1bd453254037ecb5f9c7ab52514003 Mon Sep 17 00:00:00 2001
From: Yin Li
Date: Sun, 1 Dec 2019 18:53:38 -0500
Subject: [PATCH] Add testing

---
 map2map/args.py                | 29 ++++++++++++---
 map2map/data/fields.py         | 68 +++++++++++++++++++++-------------
 map2map/data/norms/__init__.py |  9 +++++
 map2map/main.py                |  2 +-
 map2map/models/unet.py         |  9 -----
 map2map/test.py                | 56 ++++++++++++++++++++++++++--
 map2map/train.py               | 48 ++++++++++++------------
 scripts/dis2dis-test.slurm     | 48 ++++++++++++++++++++++++
 scripts/dis2dis.slurm          |  4 +-
 scripts/vel2vel-test.slurm     | 48 ++++++++++++++++++++++++
 scripts/vel2vel.slurm          |  4 +-
 11 files changed, 252 insertions(+), 73 deletions(-)
 create mode 100644 scripts/dis2dis-test.slurm
 create mode 100644 scripts/vel2vel-test.slurm

diff --git a/map2map/args.py b/map2map/args.py
index 3987100..43f41d7 100644
--- a/map2map/args.py
+++ b/map2map/args.py
@@ -21,11 +21,22 @@ def add_common_args(parser):
     parser.add_argument('--out-channels', type=int, required=True,
         help='number of output or target channels')
     parser.add_argument('--norms', type=str_list, help='comma-sep. list '
-        'of normalization functions from map2map.data.norms')
+        'of normalization functions from data.norms')
     parser.add_argument('--criterion', default='MSELoss',
         help='model criterion from torch.nn')
     parser.add_argument('--load-state', default='', type=str,
         help='path to load model, optimizer, rng, etc.')
+    parser.add_argument('--batches', default=1, type=int,
+        help='mini-batch size, per GPU in training or in total in testing')
+    parser.add_argument('--loader-workers', default=0, type=int,
+        help='number of data loading workers, per GPU in training or '
+            'in total in testing')
+    parser.add_argument('--pad-or-crop', default=0, type=int_tuple,
+        help='pad (>0) or crop (<0) the input data; '
+            'can be an int or a 6-tuple (given as a comma-sep. list); '
+            'can be asymmetric to align the data with downsample '
+            'and upsample convolutions; '
+            'padding assumes periodic boundary condition')
 
 
 def add_train_args(parser):
@@ -39,12 +50,8 @@ def add_train_args(parser):
         help='comma-sep. list of glob patterns for validation input data')
     parser.add_argument('--val-tgt-patterns', type=str_list, required=True,
         help='comma-sep. list of glob patterns for validation target data')
-    parser.add_argument('--epochs', default=128, type=int,
+    parser.add_argument('--epochs', default=1024, type=int,
         help='total number of epochs to run')
-    parser.add_argument('--batches-per-gpu', default=8, type=int,
-        help='mini-batch size per GPU')
-    parser.add_argument('--loader-workers-per-gpu', default=4, type=int,
-        help='number of data loading workers per GPU')
     parser.add_argument('--augment', action='store_true',
         help='enable training data augmentation')
     parser.add_argument('--optimizer', default='Adam',
@@ -74,3 +81,13 @@ def add_test_args(parser):
 
 def str_list(s):
     return s.split(',')
+
+
+def int_tuple(t):
+    t = t.split(',')
+    t = tuple(int(i) for i in t)
+    if len(t) == 1:
+        t = t[0]
+    elif len(t) != 6:
+        raise ValueError('pad or crop size must be int or 6-tuple')
+    return t
diff --git a/map2map/data/fields.py b/map2map/data/fields.py
index e8fbc25..a1e897a 100644
--- a/map2map/data/fields.py
+++ b/map2map/data/fields.py
@@ -3,7 +3,7 @@ import numpy as np
 import torch
 from torch.utils.data import Dataset
 
-from . import norms
+from .norms import import_norm
 
 
 class FieldDataset(Dataset):
@@ -14,12 +14,15 @@ class FieldDataset(Dataset):
     Likewise `tgt_patterns` is for target fields.
     Input and target samples of all fields are matched by sorting the globbed files.
 
+    Input fields can be padded (>0) or cropped (<0) with `pad_or_crop`.
+    Padding assumes periodic boundary condition.
+
     Data augmentations are supported for scalar and vector fields.
 
-    `normalize` can be a list of callables to normalize each field.
+    `norms` can be a list of callables to normalize each field.
     """
-    def __init__(self, in_patterns, tgt_patterns, augment=False,
-            normalize=None, **kwargs):
+    def __init__(self, in_patterns, tgt_patterns, pad_or_crop=0, augment=False,
+            norms=None):
         in_file_lists = [sorted(glob(p)) for p in in_patterns]
         self.in_files = list(zip(* in_file_lists))
 
@@ -29,23 +32,31 @@ class FieldDataset(Dataset):
         assert len(self.in_files) == len(self.tgt_files), \
             'input and target sample sizes do not match'
 
+        if isinstance(pad_or_crop, int):
+            pad_or_crop = (pad_or_crop,) * 6
+        assert isinstance(pad_or_crop, tuple) and len(pad_or_crop) == 6, \
+            'pad or crop size must be int or 6-tuple'
+        self.pad_or_crop = np.array((0,) * 2 + pad_or_crop).reshape(4, 2)
+
         self.augment = augment
 
-        self.normalize = normalize
-        if self.normalize is not None:
-            assert len(in_patterns) == len(self.normalize), \
+        if norms is not None:
+            assert len(in_patterns) == len(norms), \
                 'numbers of normalization callables and input fields do not match'
-
-#        self.__dict__.update(kwargs)
+            norms = [import_norm(n) if isinstance(n, str) else n for n in norms]
+        self.norms = norms
 
     def __len__(self):
         return len(self.in_files)
 
     def __getitem__(self, idx):
-        in_fields = [torch.from_numpy(np.load(f)).to(torch.float32)
-                for f in self.in_files[idx]]
-        tgt_fields = [torch.from_numpy(np.load(f)).to(torch.float32)
-                for f in self.tgt_files[idx]]
+        in_fields = [np.load(f) for f in self.in_files[idx]]
+        tgt_fields = [np.load(f) for f in self.tgt_files[idx]]
+
+        padcrop(in_fields, self.pad_or_crop)  # with numpy
+
+        in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]
+        tgt_fields = [torch.from_numpy(f).to(torch.float32) for f in tgt_fields]
 
         if self.augment:
             flip_axes = torch.randint(2, (3,), dtype=torch.bool)
@@ -59,18 +70,8 @@ class FieldDataset(Dataset):
             perm3d(in_fields, perm_axes)
             perm3d(tgt_fields, perm_axes)
 
-        if self.normalize is not None:
-            def get_norm(path):
-                path = path.split('.')
-                norm = norms
-                while path:
-                    norm = norm.__dict__[path.pop(0)]
-                return norm
-
-            for norm, ifield, tfield in zip(self.normalize, in_fields, tgt_fields):
-                if isinstance(norm, str):
-                    norm = get_norm(norm)
-
+        if self.norms is not None:
+            for norm, ifield, tfield in zip(self.norms, in_fields, tgt_fields):
                 norm(ifield)
                 norm(tfield)
 
@@ -80,6 +81,22 @@ class FieldDataset(Dataset):
         return in_fields, tgt_fields
 
 
+def padcrop(fields, width):
+    for i, x in enumerate(fields):
+        if (width >= 0).all():
+            x = np.pad(x, width, mode='wrap')
+        elif (width <= 0).all():
+            x = x[...,
+                    -width[1, 0] : width[1, 1],
+                    -width[2, 0] : width[2, 1],
+                    -width[3, 0] : width[3, 1],
+            ]
+        else:
+            raise NotImplementedError('mixed pad-and-crop not supported')
+
+        fields[i] = x
+
+
 def flip3d(fields, axes):
     for i, x in enumerate(fields):
         if x.size(0) == 3:  # flip vector components
@@ -90,6 +107,7 @@ def flip3d(fields, axes):
 
         fields[i] = x
 
+
 def perm3d(fields, axes):
     for i, x in enumerate(fields):
         if x.size(0) == 3:  # permutate vector components
diff --git a/map2map/data/norms/__init__.py b/map2map/data/norms/__init__.py
index 167c422..b139755 100644
--- a/map2map/data/norms/__init__.py
+++ b/map2map/data/norms/__init__.py
@@ -1 +1,10 @@
+from importlib import import_module
+
 from . import cosmology
+
+
+def import_norm(path):
+    mod, func = path.rsplit('.', 1)
+    mod = import_module('.' + mod, __name__)
+    func = getattr(mod, func)
+    return func
diff --git a/map2map/main.py b/map2map/main.py
index b9615a2..da06840 100644
--- a/map2map/main.py
+++ b/map2map/main.py
@@ -10,4 +10,4 @@ def main():
     if args.mode == 'train':
         train.node_worker(args)
     elif args.mode == 'test':
-        pass
+        test.test(args)
diff --git a/map2map/models/unet.py b/map2map/models/unet.py
index 89c89f9..c73b859 100644
--- a/map2map/models/unet.py
+++ b/map2map/models/unet.py
@@ -4,15 +4,6 @@ import torch.nn as nn
 from .conv import ConvBlock, ResBlock, narrow_like
 
 
-class DownBlock(ConvBlock):
-    def __init__(self, in_channels, out_channels, seq='BADBA'):
-        super().__init__(in_channels, out_channels, seq=seq)
-
-class UpBlock(ConvBlock):
-    def __init__(self, in_channels, out_channels, seq='BAUBA'):
-        super().__init__(in_channels, out_channels, seq=seq)
-
-
 class UNet(nn.Module):
     def __init__(self, in_channels, out_channels):
         super().__init__()
diff --git a/map2map/test.py b/map2map/test.py
index 2eab6bd..4319014 100644
--- a/map2map/test.py
+++ b/map2map/test.py
@@ -1,8 +1,58 @@
-import os
-
+import numpy as np
 import torch
 from torch.utils.data import DataLoader
-from torch.utils.tensorboard import SummaryWriter
 
 from .data import FieldDataset
 from .models import UNet, narrow_like
+
+
+def test(args):
+    test_dataset = FieldDataset(
+        in_patterns=args.test_in_patterns,
+        tgt_patterns=args.test_tgt_patterns,
+        augment=False,
+        norms=args.norms,
+        pad_or_crop=args.pad_or_crop,
+    )
+    test_loader = DataLoader(
+        test_dataset,
+        batch_size=args.batches,
+        shuffle=False,
+        num_workers=args.loader_workers,
+    )
+
+    model = UNet(args.in_channels, args.out_channels)
+    criterion = torch.nn.__dict__[args.criterion]()
+
+    device = torch.device('cpu')
+    state = torch.load(args.load_state, map_location=device)
+    from collections import OrderedDict
+    model_state = OrderedDict()
+    for k, v in state['model'].items():
+        model_k = k.replace('module.', '', 1)  # FIXME
+        model_state[model_k] = v
+    model.load_state_dict(model_state)
+    print('model state at epoch {} loaded from {}'.format(
+        state['epoch'], args.load_state))
+    del state
+
+    model.eval()
+
+    with torch.no_grad():
+        for i, (input, target) in enumerate(test_loader):
+            output = model(input)
+            if np.all(np.asarray(args.pad_or_crop) > 0):  # FIXME
+                output = narrow_like(output, target)
+            else:
+                target = narrow_like(target, output)
+
+            loss = criterion(output, target)
+
+            print('sample {} loss: {}'.format(i, loss.item()))
+
+            if args.norms is not None:
+                norm = test_dataset.norms[0]  # FIXME
+                norm(output, undo=True)
+
+            np.savez('{}.npz'.format(i), input=input.numpy(),
+                    output=output.numpy(), target=target.numpy())
diff --git a/map2map/train.py b/map2map/train.py
index 659b142..2ae7d87 100644
--- a/map2map/train.py
+++ b/map2map/train.py
@@ -1,6 +1,5 @@
 import os
 import shutil
-
 import torch
 from torch.multiprocessing import spawn
 from torch.distributed import init_process_group, destroy_process_group, all_reduce
@@ -46,15 +45,16 @@ def gpu_worker(local_rank, args):
         in_patterns=args.train_in_patterns,
         tgt_patterns=args.train_tgt_patterns,
         augment=args.augment,
-        normalize=args.norms,
+        norms=args.norms,
+        pad_or_crop=args.pad_or_crop,
     )
     train_sampler = DistributedSampler(train_dataset, shuffle=True)
     train_loader = DataLoader(
         train_dataset,
-        batch_size=args.batches_per_gpu,
+        batch_size=args.batches,
         shuffle=False,
         sampler=train_sampler,
-        num_workers=args.loader_workers_per_gpu,
+        num_workers=args.loader_workers,
         pin_memory=True
     )
 
@@ -62,15 +62,16 @@ def gpu_worker(local_rank, args):
         in_patterns=args.val_in_patterns,
         tgt_patterns=args.val_tgt_patterns,
         augment=False,
-        normalize=args.norms,
+        norms=args.norms,
+        pad_or_crop=args.pad_or_crop,
     )
     val_sampler = DistributedSampler(val_dataset, shuffle=False)
     val_loader = DataLoader(
         val_dataset,
-        batch_size=args.batches_per_gpu,
+        batch_size=args.batches,
         shuffle=False,
         sampler=val_sampler,
-        num_workers=args.loader_workers_per_gpu,
+        num_workers=args.loader_workers,
         pin_memory=True
     )
 
@@ -90,17 +91,17 @@ def gpu_worker(local_rank, args):
     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
 
     if args.load_state:
-        checkpoint = torch.load(args.load_state, map_location=args.device)
-        args.start_epoch = checkpoint['epoch']
-        model.load_state_dict(checkpoint['model'])
-        optimizer.load_state_dict(checkpoint['optimizer'])
-        scheduler.load_state_dict(checkpoint['scheduler'])
-        torch.set_rng_state(checkpoint['rng'].cpu())  # move rng state back
+        state = torch.load(args.load_state, map_location=args.device)
+        args.start_epoch = state['epoch']
+        model.load_state_dict(state['model'])
+        optimizer.load_state_dict(state['optimizer'])
+        scheduler.load_state_dict(state['scheduler'])
+        torch.set_rng_state(state['rng'].cpu())  # move rng state back
         if args.rank == 0:
-            min_loss = checkpoint['min_loss']
-            print('checkpoint of epoch {} loaded from {}'.format(
-                checkpoint['epoch'], args.load_state))
-            del checkpoint
+            min_loss = state['min_loss']
+            print('checkpoint at epoch {} loaded from {}'.format(
+                state['epoch'], args.load_state))
+            del state
     else:
         args.start_epoch = 0
         if args.rank == 0:
@@ -125,7 +126,7 @@ def gpu_worker(local_rank, args):
         if args.rank == 0:
             args.logger.close()
 
-            checkpoint = {
+            state = {
                 'epoch': epoch + 1,
                 'model': model.state_dict(),
                 'optimizer' : optimizer.state_dict(),
@@ -134,8 +135,8 @@ def gpu_worker(local_rank, args):
                 'min_loss': min_loss,
             }
             filename='checkpoint.pth'
-            torch.save(checkpoint, filename)
-            del checkpoint
+            torch.save(state, filename)
+            del state
 
             if min_loss is None or val_loss < min_loss:
                 min_loss = val_loss
@@ -152,7 +153,7 @@ def train(epoch, loader, model, criterion, optimizer, args):
         target = target.to(args.device, non_blocking=True)
 
         output = model(input)
-        target = narrow_like(target, output)
+        target = narrow_like(target, output)  # FIXME pad
 
         loss = criterion(output, target)
 
@@ -167,7 +168,6 @@ def train(epoch, loader, model, criterion, optimizer, args):
 
         if args.rank == 0:
             args.logger.add_scalar('loss/train', loss.item(), global_step=batch)
-#            f'max GPU mem: {torch.cuda.max_memory_allocated()} allocated, {torch.cuda.max_memory_cached()} cached')
 
 
 def validate(epoch, loader, model, criterion, args):
@@ -180,7 +180,7 @@ def validate(epoch, loader, model, criterion, args):
             target = target.to(args.device, non_blocking=True)
 
             output = model(input)
-            target = narrow_like(target, output)
+            target = narrow_like(target, output)  # FIXME pad
 
             loss += criterion(output, target)
 
@@ -189,6 +189,4 @@ def validate(epoch, loader, model, criterion, args):
     if args.rank == 0:
         args.logger.add_scalar('loss/val', loss.item(), global_step=epoch+1)
 
-#            f'max GPU mem: {torch.cuda.max_memory_allocated()} allocated, {torch.cuda.max_memory_cached()} cached')
-
     return loss.item()
diff --git a/scripts/dis2dis-test.slurm b/scripts/dis2dis-test.slurm
new file mode 100644
index 0000000..65a71fb
--- /dev/null
+++ b/scripts/dis2dis-test.slurm
@@ -0,0 +1,48 @@
+#!/bin/bash
+
+#SBATCH --job-name=dis2dis-test
+#SBATCH --output=%x-%j.out
+#SBATCH --error=%x-%j.err
+
+#SBATCH --partition=ccm
+
+#SBATCH --exclusive
+#SBATCH --nodes=1
+#SBATCH --mem=0
+#SBATCH --time=1-00:00:00
+
+
+hostname; pwd; date
+
+
+module load gcc openmpi2
+module load cuda/10.1.243_418.87.00 cudnn/v7.6.2-cuda-10.1
+
+source $HOME/anaconda3/bin/activate torch
+
+
+export OMP_NUM_THREADS=$SLURM_CPUS_ON_NODE
+echo OMP_NUM_THREADS = $OMP_NUM_THREADS
+
+
+data_root_dir="/mnt/ceph/users/yinli/Quijote"
+
+in_dir="linear"
+tgt_dir="nonlin"
+
+test_dirs="0" # FIXME
+
+files="dis/128x???.npy"
+in_files="$files"
+tgt_files="$files"
+
+
+srun m2m.py test \
+    --test-in-patterns "$data_root_dir/$in_dir/$test_dirs/$in_files" \
+    --test-tgt-patterns "$data_root_dir/$tgt_dir/$test_dirs/$tgt_files" \
+    --in-channels 3 --out-channels 3 --norms cosmology.dis \
+    --batches 1 --loader-workers 0 --pad-or-crop 40 \
+    --load-state best_model.pth
+
+
+date
diff --git a/scripts/dis2dis.slurm b/scripts/dis2dis.slurm
index 36cf91e..61cdc3f 100644
--- a/scripts/dis2dis.slurm
+++ b/scripts/dis2dis.slurm
@@ -11,7 +11,7 @@
 #SBATCH --exclusive
 #SBATCH --nodes=2
 #SBATCH --mem=0
-#SBATCH --time=2-00:00:00
+#SBATCH --time=7-00:00:00
 
 
 hostname; pwd; date
@@ -46,7 +46,7 @@ srun m2m.py train \
     --val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files" \
     --val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \
     --in-channels 3 --out-channels 3 --norms cosmology.dis --augment \
-    --epochs 128 --batches-per-gpu 4 --loader-workers-per-gpu 4 \
+    --epochs 1024 --batches 3 --loader-workers 3 \
     --lr 0.0002
 
 # --load-state checkpoint.pth
diff --git a/scripts/vel2vel-test.slurm b/scripts/vel2vel-test.slurm
new file mode 100644
index 0000000..c6486fb
--- /dev/null
+++ b/scripts/vel2vel-test.slurm
@@ -0,0 +1,48 @@
+#!/bin/bash
+
+#SBATCH --job-name=vel2vel-test
+#SBATCH --output=%x-%j.out
+#SBATCH --error=%x-%j.err
+
+#SBATCH --partition=ccm
+
+#SBATCH --exclusive
+#SBATCH --nodes=1
+#SBATCH --mem=0
+#SBATCH --time=1-00:00:00
+
+
+hostname; pwd; date
+
+
+module load gcc openmpi2
+module load cuda/10.1.243_418.87.00 cudnn/v7.6.2-cuda-10.1
+
+source $HOME/anaconda3/bin/activate torch
+
+
+export OMP_NUM_THREADS=$SLURM_CPUS_ON_NODE
+echo OMP_NUM_THREADS = $OMP_NUM_THREADS
+
+
+data_root_dir="/mnt/ceph/users/yinli/Quijote"
+
+in_dir="linear"
+tgt_dir="nonlin"
+
+test_dirs="0" # FIXME
+
+files="vel/128x???.npy"
+in_files="$files"
+tgt_files="$files"
+
+
+srun m2m.py test \
+    --test-in-patterns "$data_root_dir/$in_dir/$test_dirs/$in_files" \
+    --test-tgt-patterns "$data_root_dir/$tgt_dir/$test_dirs/$tgt_files" \
+    --in-channels 3 --out-channels 3 --norms cosmology.vel \
+    --batches 1 --loader-workers 0 --pad-or-crop 40 \
+    --load-state best_model.pth
+
+
+date
diff --git a/scripts/vel2vel.slurm b/scripts/vel2vel.slurm
index f765282..2e3caa9 100644
--- a/scripts/vel2vel.slurm
+++ b/scripts/vel2vel.slurm
@@ -11,7 +11,7 @@
 #SBATCH --exclusive
 #SBATCH --nodes=2
 #SBATCH --mem=0
-#SBATCH --time=2-00:00:00
+#SBATCH --time=7-00:00:00
 
 
 hostname; pwd; date
@@ -46,7 +46,7 @@ srun m2m.py train \
     --val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files" \
    --val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \
     --in-channels 3 --out-channels 3 --norms cosmology.vel --augment \
-    --epochs 128 --batches-per-gpu 4 --loader-workers-per-gpu 4 \
+    --epochs 1024 --batches 3 --loader-workers 3 \
     --lr 0.0002
 
 # --load-state checkpoint.pth
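--

A quick illustration of the new `--pad-or-crop` option (annotation only, not
part of the patch): `int_tuple` in args.py parses the flag into an int or a
6-tuple, and `FieldDataset` expands it into a (4, 2) array of per-axis
(before, after) widths, leaving the channel axis untouched. The sketch below
traces that flow for a symmetric pad of 40 and a symmetric crop of 2, using a
hypothetical (channel, x, y, z) input array; the shapes follow fields.py above.

    import numpy as np

    pad_or_crop = 40                      # as parsed by int_tuple('40')
    if isinstance(pad_or_crop, int):
        pad_or_crop = (pad_or_crop,) * 6  # (40, 40, 40, 40, 40, 40)
    width = np.array((0,) * 2 + pad_or_crop).reshape(4, 2)
    # width[0] == (0, 0) leaves the channel axis alone;
    # each spatial axis gets (before, after) = (40, 40)

    field = np.zeros((3, 128, 128, 128))        # hypothetical input field
    padded = np.pad(field, width, mode='wrap')  # periodic boundary condition
    assert padded.shape == (3, 208, 208, 208)

    # negative values crop instead, by slicing as in padcrop()
    width = np.array((0,) * 2 + (-2,) * 6).reshape(4, 2)
    cropped = field[...,
            -width[1, 0] : width[1, 1],  # i.e. field[..., 2:-2, 2:-2, 2:-2]
            -width[2, 0] : width[2, 1],
            -width[3, 0] : width[3, 1],
    ]
    assert cropped.shape == (3, 124, 124, 124)

The asymmetric 6-tuple form exists so the padded input can be aligned exactly
with the shrinkage of the downsample and upsample convolutions. One caveat: a
6-tuple that mixes zeros into an otherwise negative (crop) spec still passes
the `(width <= 0).all()` branch but produces a stop index of 0 and hence an
empty slice, so strictly positive (pad) or strictly negative (crop) widths
are assumed.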