Add styles to test.py
This commit is contained in:
parent
01c2e45430
commit
3eaca964ed
@ -18,6 +18,7 @@ def test(args):
|
||||
sys.stdout.flush()
|
||||
|
||||
test_dataset = FieldDataset(
|
||||
style_pattern=args.test_style_pattern,
|
||||
in_patterns=args.test_in_patterns,
|
||||
tgt_patterns=args.test_tgt_patterns,
|
||||
in_norms=args.in_norms,
|
||||
@ -42,10 +43,12 @@ def test(args):
|
||||
num_workers=args.loader_workers,
|
||||
)
|
||||
|
||||
in_chan, out_chan = test_dataset.in_chan, test_dataset.tgt_chan
|
||||
style_size = test_dataset.style_size
|
||||
in_chan = test_dataset.in_chan
|
||||
out_chan = test_dataset.tgt_chan
|
||||
|
||||
model = import_attr(args.model, models, callback_at=args.callback_at)
|
||||
model = model(sum(in_chan), sum(out_chan), scale_factor=args.scale_factor)
|
||||
model = model(style_size, sum(in_chan), sum(out_chan), scale_factor=args.scale_factor)
|
||||
criterion = import_attr(args.criterion, torch.nn, callback_at=args.callback_at)
|
||||
criterion = criterion()
|
||||
|
||||
@ -61,8 +64,8 @@ def test(args):
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
for i, (input, target) in enumerate(test_loader):
|
||||
output = model(input)
|
||||
for i, (style, input, target) in enumerate(test_loader):
|
||||
output = model(input, style)
|
||||
input, output, target = narrow_cast(input, output, target)
|
||||
|
||||
loss = criterion(output, target)
|
||||
|
Loading…
Reference in New Issue
Block a user