Change to nearest interpolation from linear in model super-resolution
To be consistent with the data super-resolution in data/fields.py
This commit is contained in:
parent
db3414e11c
commit
53ed5a91f4
@ -302,7 +302,7 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
|
|||||||
if args.cgan:
|
if args.cgan:
|
||||||
if hasattr(model, 'scale_factor') and model.scale_factor != 1:
|
if hasattr(model, 'scale_factor') and model.scale_factor != 1:
|
||||||
input = F.interpolate(input,
|
input = F.interpolate(input,
|
||||||
scale_factor=model.scale_factor, mode='trilinear')
|
scale_factor=model.scale_factor, mode='nearest')
|
||||||
input = narrow_like(input, output)
|
input = narrow_like(input, output)
|
||||||
output = torch.cat([input, output], dim=1)
|
output = torch.cat([input, output], dim=1)
|
||||||
target = torch.cat([input, target], dim=1)
|
target = torch.cat([input, target], dim=1)
|
||||||
@ -412,7 +412,7 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion,
|
|||||||
if args.cgan:
|
if args.cgan:
|
||||||
if hasattr(model, 'scale_factor') and model.scale_factor != 1:
|
if hasattr(model, 'scale_factor') and model.scale_factor != 1:
|
||||||
input = F.interpolate(input,
|
input = F.interpolate(input,
|
||||||
scale_factor=model.scale_factor, mode='trilinear')
|
scale_factor=model.scale_factor, mode='nearest')
|
||||||
input = narrow_like(input, output)
|
input = narrow_like(input, output)
|
||||||
output = torch.cat([input, output], dim=1)
|
output = torch.cat([input, output], dim=1)
|
||||||
target = torch.cat([input, target], dim=1)
|
target = torch.cat([input, target], dim=1)
|
||||||
|
Loading…
Reference in New Issue
Block a user