Add cuda backend to inference

parent e20a3e3f62
commit 0d4ae3424e
@@ -173,6 +173,10 @@ def add_test_args(parser):
     parser.add_argument('--test-tgt-patterns', type=str_list, required=True,
             help='comma-sep. list of glob patterns for test target data')
 
+    parser.add_argument('--num-threads', type=int,
+            help='number of CPU threads when cuda is unavailable. '
+                 'Default is the number of CPUs on the node by slurm')
+
 
 def str_list(s):
     return s.split(',')
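As a standalone sketch of the argument handling this hunk extends: `str_list` turns a comma-separated flag into a Python list, and `--num-threads` stays `None` until resolved later from Slurm. The parser construction and example invocation below are illustrative, not part of the commit.

import argparse


def str_list(s):
    # split a comma-separated string into a list, e.g. 'a,b' -> ['a', 'b']
    return s.split(',')


parser = argparse.ArgumentParser()
parser.add_argument('--test-tgt-patterns', type=str_list, required=True,
        help='comma-sep. list of glob patterns for test target data')
parser.add_argument('--num-threads', type=int,
        help='number of CPU threads when cuda is unavailable')

args = parser.parse_args(['--test-tgt-patterns', 'test/D0-*.npy,test/D1-*.npy'])
print(args.test_tgt_patterns)  # ['test/D0-*.npy', 'test/D1-*.npy']
print(args.num_threads)        # None until filled in, e.g. from SLURM_CPUS_ON_NODE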
@@ -1,4 +1,6 @@
+import os
 import sys
+import warnings
 from pprint import pprint
 import numpy as np
 import torch
@@ -12,6 +14,22 @@ from .utils import import_attr, load_model_state_dict
 
 
 def test(args):
+    if torch.cuda.is_available():
+        if torch.cuda.device_count() > 1:
+            warnings.warn('Not parallelized but given more than 1 GPUs')
+
+        os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+        device = torch.device('cuda', 0)
+
+        torch.backends.cudnn.benchmark = True
+    else:  # CPU multithreading
+        device = torch.device('cpu')
+
+        if args.num_threads is None:
+            args.num_threads = int(os.environ['SLURM_CPUS_ON_NODE'])
+
+        torch.set_num_threads(args.num_threads)
+
     print('pytorch {}'.format(torch.__version__))
     pprint(vars(args))
     sys.stdout.flush()
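The device-selection block above, consolidated into a runnable sketch. The commit requires SLURM_CPUS_ON_NODE to be set; the os.cpu_count() fallback here is an assumption for running outside Slurm.

import os
import warnings

import torch


def pick_device(num_threads=None):
    # choose a single GPU when available, otherwise fall back to CPU threading
    if torch.cuda.is_available():
        if torch.cuda.device_count() > 1:
            warnings.warn('Not parallelized but given more than 1 GPUs')
        os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # pin the process to the first GPU
        device = torch.device('cuda', 0)
        torch.backends.cudnn.benchmark = True  # autotune conv kernels for fixed shapes
    else:
        device = torch.device('cpu')
        if num_threads is None:
            # the commit reads SLURM_CPUS_ON_NODE directly; the fallback is an assumption
            num_threads = int(os.environ.get('SLURM_CPUS_ON_NODE', os.cpu_count()))
        torch.set_num_threads(num_threads)
    return device


print(pick_device())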
@@ -41,6 +59,7 @@ def test(args):
         batch_size=args.batch_size,
         shuffle=False,
         num_workers=args.loader_workers,
+        pin_memory=True,
     )
 
     style_size = test_dataset.style_size
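pin_memory=True makes the loader allocate batches in page-locked host memory, which is what lets the later .to(device, non_blocking=True) copies run asynchronously. A small sketch with a synthetic stand-in dataset (the dataset and shapes are made up for illustration):

import torch
from torch.utils.data import DataLoader, TensorDataset

# synthetic stand-in for the test dataset; shapes are illustrative
dataset = TensorDataset(torch.randn(16, 3, 8, 8, 8))
loader = DataLoader(
    dataset,
    batch_size=4,
    shuffle=False,
    num_workers=0,
    pin_memory=True,  # page-locked buffers enable async host-to-device copies
)

for (batch,) in loader:
    # with CUDA available, batch now lives in pinned memory, so
    # batch.to(device, non_blocking=True) can overlap with compute
    print(batch.shape)
    break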
@@ -50,10 +69,13 @@ def test(args):
     model = import_attr(args.model, models, callback_at=args.callback_at)
     model = model(style_size, sum(in_chan), sum(out_chan),
                   scale_factor=args.scale_factor, **args.misc_kwargs)
-    criterion = import_attr(args.criterion, torch.nn, callback_at=args.callback_at)
-    criterion = criterion()
+    model.to(device)
+
+    criterion = import_attr(args.criterion, torch.nn, models,
+                            callback_at=args.callback_at)
+    criterion = criterion()
+    criterion.to(device)
 
-    device = torch.device('cpu')
     state = torch.load(args.load_state, map_location=device)
     load_model_state_dict(model, state['model'], strict=args.load_state_strict)
     print('model state at epoch {} loaded from {}'.format(
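Moving both the model and the criterion to the selected device, and loading the checkpoint with map_location=device, keeps all tensors on one device regardless of where the state was saved. A minimal sketch of that pattern (nn.Linear and nn.MSELoss are stand-ins for the real model and criterion):

import torch
from torch import nn

device = torch.device('cuda', 0) if torch.cuda.is_available() else torch.device('cpu')

model = nn.Linear(4, 2)  # stand-in for the real model
model.to(device)

criterion = nn.MSELoss()
criterion.to(device)  # a no-op for stateless losses, but safe and uniform

# round-trip a checkpoint; map_location remaps storages saved on another device
torch.save({'model': model.state_dict()}, 'state.pt')
state = torch.load('state.pt', map_location=device)
model.load_state_dict(state['model'], strict=True)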
@@ -66,8 +88,21 @@ def test(args):
         for i, data in enumerate(test_loader):
             style, input, target = data['style'], data['input'], data['target']
+
+            style = style.to(device, non_blocking=True)
+            input = input.to(device, non_blocking=True)
+            target = target.to(device, non_blocking=True)
+
             output = model(input, style)
+            if i < 5:
+                print('##### sample :', i)
+                print('style shape :', style.shape)
+                print('input shape :', input.shape)
+                print('output shape :', output.shape)
+                print('target shape :', target.shape)
 
             input, output, target = narrow_cast(input, output, target)
+            if i < 5:
+                print('narrowed shape :', output.shape, flush=True)
 
             loss = criterion(output, target)
 
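The inference loop's shape of this change, in isolation: pinned batches plus non_blocking=True transfers, with shape printouts for the first few samples. The torch.no_grad() wrapper is assumed from context rather than shown in the hunk, and nn.Identity stands in for the real model:

import torch
from torch import nn

device = torch.device('cuda', 0) if torch.cuda.is_available() else torch.device('cpu')
model = nn.Identity().to(device)   # stand-in for the real model
criterion = nn.MSELoss().to(device)

batches = [(torch.randn(2, 3), torch.randn(2, 3)) for _ in range(3)]

with torch.no_grad():  # inference only; assumed from surrounding code
    for i, (input, target) in enumerate(batches):
        # non_blocking only helps when the source tensor is pinned (pin_memory=True)
        input = input.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        output = model(input)
        if i < 5:  # mirror the commit's debug prints for the first few samples
            print('##### sample :', i)
            print('output shape :', output.shape)

        loss = criterion(output, target)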
@@ -2,12 +2,14 @@
 
 #SBATCH --job-name=R2D2
 #SBATCH --output=%x-%j.out
 
 #SBATCH --partition=cpu_partition
+#SBATCH --nodes=1
 #SBATCH --exclusive
-#SBATCH --nodes=2
-#SBATCH --time=1-00:00:00
+##SBATCH --partition=gpu_partition
+##SBATCH --gres=gpu:1
+##SBATCH --ntasks=1
+##SBATCH --cpus-per-task=8
+#SBATCH --time=0-01:00:00
 
 
 hostname; pwd; date
@@ -22,9 +24,6 @@ hostname; pwd; date
 #conda info
 
 
-export OMP_NUM_THREADS=$SLURM_CPUS_ON_NODE  # use MKL-DNN
-
-
 m2m.py test \
     --test-in-patterns "test/R0-*.npy,test/R1-*.npy" \
     --test-tgt-patterns "test/D0-*.npy,test/D1-*.npy" \
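The OMP_NUM_THREADS export is dropped because the thread count now gets set inside Python via torch.set_num_threads, driven by the same Slurm variable. A small sketch of the equivalent Python-side setup (the os.cpu_count() fallback is an assumption for non-Slurm runs):

import os

import torch

# replaces `export OMP_NUM_THREADS=$SLURM_CPUS_ON_NODE` from the old script:
# set PyTorch's intra-op thread pool directly in Python instead
num_threads = int(os.environ.get('SLURM_CPUS_ON_NODE', os.cpu_count()))
torch.set_num_threads(num_threads)
print('intra-op threads:', torch.get_num_threads())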