Add styles to test.py
This commit is contained in:
parent
01c2e45430
commit
3eaca964ed
@ -18,6 +18,7 @@ def test(args):
|
|||||||
sys.stdout.flush()
|
sys.stdout.flush()
|
||||||
|
|
||||||
test_dataset = FieldDataset(
|
test_dataset = FieldDataset(
|
||||||
|
style_pattern=args.test_style_pattern,
|
||||||
in_patterns=args.test_in_patterns,
|
in_patterns=args.test_in_patterns,
|
||||||
tgt_patterns=args.test_tgt_patterns,
|
tgt_patterns=args.test_tgt_patterns,
|
||||||
in_norms=args.in_norms,
|
in_norms=args.in_norms,
|
||||||
@ -42,10 +43,12 @@ def test(args):
|
|||||||
num_workers=args.loader_workers,
|
num_workers=args.loader_workers,
|
||||||
)
|
)
|
||||||
|
|
||||||
in_chan, out_chan = test_dataset.in_chan, test_dataset.tgt_chan
|
style_size = test_dataset.style_size
|
||||||
|
in_chan = test_dataset.in_chan
|
||||||
|
out_chan = test_dataset.tgt_chan
|
||||||
|
|
||||||
model = import_attr(args.model, models, callback_at=args.callback_at)
|
model = import_attr(args.model, models, callback_at=args.callback_at)
|
||||||
model = model(sum(in_chan), sum(out_chan), scale_factor=args.scale_factor)
|
model = model(style_size, sum(in_chan), sum(out_chan), scale_factor=args.scale_factor)
|
||||||
criterion = import_attr(args.criterion, torch.nn, callback_at=args.callback_at)
|
criterion = import_attr(args.criterion, torch.nn, callback_at=args.callback_at)
|
||||||
criterion = criterion()
|
criterion = criterion()
|
||||||
|
|
||||||
@ -61,8 +64,8 @@ def test(args):
|
|||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for i, (input, target) in enumerate(test_loader):
|
for i, (style, input, target) in enumerate(test_loader):
|
||||||
output = model(input)
|
output = model(input, style)
|
||||||
input, output, target = narrow_cast(input, output, target)
|
input, output, target = narrow_cast(input, output, target)
|
||||||
|
|
||||||
loss = criterion(output, target)
|
loss = criterion(output, target)
|
||||||
|
Loading…
Reference in New Issue
Block a user