Change noise channel removal from input

This commit is contained in:
Yin Li 2020-02-05 14:40:45 -05:00
parent ef0235f97b
commit a5c48e71b0

View file

@ -254,15 +254,15 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
target = target.to(args.device, non_blocking=True)
output = model(input)
target = narrow_like(target, output) # FIXME pad
if args.noise_chan > 0:
input = input[:, :-args.noise_chan] # remove noise channels
target = narrow_like(target, output) # FIXME pad
loss = criterion(output, target)
epoch_loss[0] += loss.item()
# generator adversarial loss
if args.adv:
if args.noise_chan > 0:
input = input[:, :-args.noise_chan] # remove noise channels
if args.cgan:
if hasattr(model, 'scale_factor') and model.scale_factor != 1:
input = F.interpolate(input,
@ -351,14 +351,14 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion, args):
target = target.to(args.device, non_blocking=True)
output = model(input)
target = narrow_like(target, output) # FIXME pad
if args.noise_chan > 0:
input = input[:, :-args.noise_chan] # remove noise channels
target = narrow_like(target, output) # FIXME pad
loss = criterion(output, target)
epoch_loss[0] += loss.item()
if args.adv:
if args.noise_chan > 0:
input = input[:, :-args.noise_chan] # remove noise channels
if args.cgan:
if hasattr(model, 'scale_factor') and model.scale_factor != 1:
input = F.interpolate(input,