Add styles to test.py

This commit is contained in:
Yin Li 2021-03-24 15:13:21 -04:00
parent 01c2e45430
commit 3eaca964ed

View File

@ -18,6 +18,7 @@ def test(args):
sys.stdout.flush() sys.stdout.flush()
test_dataset = FieldDataset( test_dataset = FieldDataset(
style_pattern=args.test_style_pattern,
in_patterns=args.test_in_patterns, in_patterns=args.test_in_patterns,
tgt_patterns=args.test_tgt_patterns, tgt_patterns=args.test_tgt_patterns,
in_norms=args.in_norms, in_norms=args.in_norms,
@ -42,10 +43,12 @@ def test(args):
num_workers=args.loader_workers, num_workers=args.loader_workers,
) )
in_chan, out_chan = test_dataset.in_chan, test_dataset.tgt_chan style_size = test_dataset.style_size
in_chan = test_dataset.in_chan
out_chan = test_dataset.tgt_chan
model = import_attr(args.model, models, callback_at=args.callback_at) model = import_attr(args.model, models, callback_at=args.callback_at)
model = model(sum(in_chan), sum(out_chan), scale_factor=args.scale_factor) model = model(style_size, sum(in_chan), sum(out_chan), scale_factor=args.scale_factor)
criterion = import_attr(args.criterion, torch.nn, callback_at=args.callback_at) criterion = import_attr(args.criterion, torch.nn, callback_at=args.callback_at)
criterion = criterion() criterion = criterion()
@ -61,8 +64,8 @@ def test(args):
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
for i, (input, target) in enumerate(test_loader): for i, (style, input, target) in enumerate(test_loader):
output = model(input) output = model(input, style)
input, output, target = narrow_cast(input, output, target) input, output, target = narrow_cast(input, output, target)
loss = criterion(output, target) loss = criterion(output, target)