Change noise channel removal from input
This commit is contained in:
parent
ef0235f97b
commit
a5c48e71b0
@ -254,15 +254,15 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
|
|||||||
target = target.to(args.device, non_blocking=True)
|
target = target.to(args.device, non_blocking=True)
|
||||||
|
|
||||||
output = model(input)
|
output = model(input)
|
||||||
target = narrow_like(target, output) # FIXME pad
|
if args.noise_chan > 0:
|
||||||
|
input = input[:, :-args.noise_chan] # remove noise channels
|
||||||
|
|
||||||
|
target = narrow_like(target, output) # FIXME pad
|
||||||
loss = criterion(output, target)
|
loss = criterion(output, target)
|
||||||
epoch_loss[0] += loss.item()
|
epoch_loss[0] += loss.item()
|
||||||
|
|
||||||
# generator adversarial loss
|
# generator adversarial loss
|
||||||
if args.adv:
|
if args.adv:
|
||||||
if args.noise_chan > 0:
|
|
||||||
input = input[:, :-args.noise_chan] # remove noise channels
|
|
||||||
if args.cgan:
|
if args.cgan:
|
||||||
if hasattr(model, 'scale_factor') and model.scale_factor != 1:
|
if hasattr(model, 'scale_factor') and model.scale_factor != 1:
|
||||||
input = F.interpolate(input,
|
input = F.interpolate(input,
|
||||||
@ -351,14 +351,14 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion, args):
|
|||||||
target = target.to(args.device, non_blocking=True)
|
target = target.to(args.device, non_blocking=True)
|
||||||
|
|
||||||
output = model(input)
|
output = model(input)
|
||||||
target = narrow_like(target, output) # FIXME pad
|
if args.noise_chan > 0:
|
||||||
|
input = input[:, :-args.noise_chan] # remove noise channels
|
||||||
|
|
||||||
|
target = narrow_like(target, output) # FIXME pad
|
||||||
loss = criterion(output, target)
|
loss = criterion(output, target)
|
||||||
epoch_loss[0] += loss.item()
|
epoch_loss[0] += loss.item()
|
||||||
|
|
||||||
if args.adv:
|
if args.adv:
|
||||||
if args.noise_chan > 0:
|
|
||||||
input = input[:, :-args.noise_chan] # remove noise channels
|
|
||||||
if args.cgan:
|
if args.cgan:
|
||||||
if hasattr(model, 'scale_factor') and model.scale_factor != 1:
|
if hasattr(model, 'scale_factor') and model.scale_factor != 1:
|
||||||
input = F.interpolate(input,
|
input = F.interpolate(input,
|
||||||
|
Loading…
Reference in New Issue
Block a user