diff --git a/map2map/test.py b/map2map/test.py index cb32a2b..a5d1d42 100644 --- a/map2map/test.py +++ b/map2map/test.py @@ -18,6 +18,7 @@ def test(args): sys.stdout.flush() test_dataset = FieldDataset( + style_pattern=args.test_style_pattern, in_patterns=args.test_in_patterns, tgt_patterns=args.test_tgt_patterns, in_norms=args.in_norms, @@ -42,10 +43,12 @@ def test(args): num_workers=args.loader_workers, ) - in_chan, out_chan = test_dataset.in_chan, test_dataset.tgt_chan + style_size = test_dataset.style_size + in_chan = test_dataset.in_chan + out_chan = test_dataset.tgt_chan model = import_attr(args.model, models, callback_at=args.callback_at) - model = model(sum(in_chan), sum(out_chan), scale_factor=args.scale_factor) + model = model(style_size, sum(in_chan), sum(out_chan), scale_factor=args.scale_factor) criterion = import_attr(args.criterion, torch.nn, callback_at=args.callback_at) criterion = criterion() @@ -61,8 +64,8 @@ def test(args): model.eval() with torch.no_grad(): - for i, (input, target) in enumerate(test_loader): - output = model(input) + for i, (style, input, target) in enumerate(test_loader): + output = model(input, style) input, output, target = narrow_cast(input, output, target) loss = criterion(output, target)