Change to nearest interpolation from linear in model super-resolution

To be consistent with the data super-resolution in data/fields.py
This commit is contained in:
Yin Li 2020-02-13 11:35:29 -05:00
parent db3414e11c
commit 53ed5a91f4

View File

@ -302,7 +302,7 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
if args.cgan: if args.cgan:
if hasattr(model, 'scale_factor') and model.scale_factor != 1: if hasattr(model, 'scale_factor') and model.scale_factor != 1:
input = F.interpolate(input, input = F.interpolate(input,
scale_factor=model.scale_factor, mode='trilinear') scale_factor=model.scale_factor, mode='nearest')
input = narrow_like(input, output) input = narrow_like(input, output)
output = torch.cat([input, output], dim=1) output = torch.cat([input, output], dim=1)
target = torch.cat([input, target], dim=1) target = torch.cat([input, target], dim=1)
@ -412,7 +412,7 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion,
if args.cgan: if args.cgan:
if hasattr(model, 'scale_factor') and model.scale_factor != 1: if hasattr(model, 'scale_factor') and model.scale_factor != 1:
input = F.interpolate(input, input = F.interpolate(input,
scale_factor=model.scale_factor, mode='trilinear') scale_factor=model.scale_factor, mode='nearest')
input = narrow_like(input, output) input = narrow_like(input, output)
output = torch.cat([input, output], dim=1) output = torch.cat([input, output], dim=1)
target = torch.cat([input, target], dim=1) target = torch.cat([input, target], dim=1)