diff --git a/map2map/args.py b/map2map/args.py index dcce6a8..835517d 100644 --- a/map2map/args.py +++ b/map2map/args.py @@ -173,6 +173,10 @@ def add_test_args(parser): parser.add_argument('--test-tgt-patterns', type=str_list, required=True, help='comma-sep. list of glob patterns for test target data') + parser.add_argument('--num-threads', type=int, + help='number of CPU threads when cuda is unavailable. ' + 'Default is the number of CPUs on the node by slurm') + def str_list(s): return s.split(',') diff --git a/map2map/test.py b/map2map/test.py index 1cdd6c8..1d8e3f7 100644 --- a/map2map/test.py +++ b/map2map/test.py @@ -1,4 +1,6 @@ +import os import sys +import warnings from pprint import pprint import numpy as np import torch @@ -12,6 +14,22 @@ from .utils import import_attr, load_model_state_dict def test(args): + if torch.cuda.is_available(): + if torch.cuda.device_count() > 1: + warnings.warn('Not parallelized but given more than 1 GPUs') + + os.environ['CUDA_VISIBLE_DEVICES'] = '0' + device = torch.device('cuda', 0) + + torch.backends.cudnn.benchmark = True + else: # CPU multithreading + device = torch.device('cpu') + + if args.num_threads is None: + args.num_threads = int(os.environ['SLURM_CPUS_ON_NODE']) + + torch.set_num_threads(args.num_threads) + print('pytorch {}'.format(torch.__version__)) pprint(vars(args)) sys.stdout.flush() @@ -41,6 +59,7 @@ def test(args): batch_size=args.batch_size, shuffle=False, num_workers=args.loader_workers, + pin_memory=True, ) style_size = test_dataset.style_size @@ -50,10 +69,13 @@ def test(args): model = import_attr(args.model, models, callback_at=args.callback_at) model = model(style_size, sum(in_chan), sum(out_chan), scale_factor=args.scale_factor, **args.misc_kwargs) - criterion = import_attr(args.criterion, torch.nn, callback_at=args.callback_at) - criterion = criterion() + model.to(device) + + criterion = import_attr(args.criterion, torch.nn, models, + callback_at=args.callback_at) + criterion = criterion() + criterion.to(device) - device = torch.device('cpu') state = torch.load(args.load_state, map_location=device) load_model_state_dict(model, state['model'], strict=args.load_state_strict) print('model state at epoch {} loaded from {}'.format( @@ -66,8 +88,21 @@ def test(args): for i, data in enumerate(test_loader): style, input, target = data['style'], data['input'], data['target'] + style = style.to(device, non_blocking=True) + input = input.to(device, non_blocking=True) + target = target.to(device, non_blocking=True) + output = model(input, style) + if i < 5: + print('##### sample :', i) + print('style shape :', style.shape) + print('input shape :', input.shape) + print('output shape :', output.shape) + print('target shape :', target.shape) + input, output, target = narrow_cast(input, output, target) + if i < 5: + print('narrowed shape :', output.shape, flush=True) loss = criterion(output, target) diff --git a/scripts/example-test.slurm b/scripts/example-test.slurm index 4e7398b..969e1c0 100644 --- a/scripts/example-test.slurm +++ b/scripts/example-test.slurm @@ -2,12 +2,14 @@ #SBATCH --job-name=R2D2 #SBATCH --output=%x-%j.out - #SBATCH --partition=cpu_partition - +#SBATCH --nodes=1 #SBATCH --exclusive -#SBATCH --nodes=2 -#SBATCH --time=1-00:00:00 +##SBATCH --partition=gpu_partition +##SBATCH --gres=gpu:1 +##SBATCH --ntasks=1 +##SBATCH --cpus-per-task=8 +#SBATCH --time=0-01:00:00 hostname; pwd; date @@ -22,9 +24,6 @@ hostname; pwd; date #conda info -export OMP_NUM_THREADS=$SLURM_CPUS_ON_NODE # use MKL-DNN - - m2m.py test \ --test-in-patterns "test/R0-*.npy,test/R1-*.npy" \ --test-tgt-patterns "test/D0-*.npy,test/D1-*.npy" \