Move map loss computation forward in training

This commit is contained in:
Yin Li 2020-02-13 14:40:51 -05:00
parent e383ec3977
commit e8039dcccc

View file

@ -192,8 +192,8 @@ def gpu_worker(local_rank, node, args):
min_loss = state['min_loss']
if 'adv_model' not in state and args.adv:
min_loss = None # restarting with adversary wipes the record
print('checkpoint at epoch {} loaded from {}'.format(
state['epoch'], args.load_state))
print('state at epoch {} loaded from {}'.format(
state['epoch'], args.load_state), flush=True)
del state
else:
@ -292,21 +292,25 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
target = target.to(device, non_blocking=True)
output = model(input)
if args.noise_chan > 0:
input = input[:, :-args.noise_chan] # remove noise channels
target = narrow_like(target, output) # FIXME pad
if args.noise_chan > 0:
input = input[:, :-args.noise_chan] # remove noise channels
if args.adv and args.cgan:
if hasattr(model, 'scale_factor') and model.scale_factor != 1:
input = F.interpolate(input,
scale_factor=model.scale_factor, mode='nearest')
input = narrow_like(input, output)
loss = criterion(output, target)
epoch_loss[0] += loss.item()
# discriminator
if args.adv:
if args.cgan:
if hasattr(model, 'scale_factor') and model.scale_factor != 1:
input = F.interpolate(input,
scale_factor=model.scale_factor, mode='nearest')
input = narrow_like(input, output)
output = torch.cat([input, output], dim=1)
target = torch.cat([input, target], dim=1)
# discriminator
set_requires_grad(adv_model, True)
eval = adv_model([output.detach(), target])
@ -320,11 +324,7 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
adv_loss.backward()
adv_optimizer.step()
loss = criterion(output, target)
epoch_loss[0] += loss.item()
# generator adversarial loss
if args.adv:
# generator adversarial loss
set_requires_grad(adv_model, False)
eval_out = adv_model(output)
@ -372,8 +372,8 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
}, global_step=epoch+1)
skip_chan = sum(args.in_chan) if args.adv and args.cgan else 0
logger.add_figure('fig/epoch/train/in',
fig3d(narrow_like(input, output)[-1]), global_step =epoch+1)
logger.add_figure('fig/epoch/train/in', fig3d(input[-1]),
global_step =epoch+1)
logger.add_figure('fig/epoch/train/out',
fig3d(output[-1, skip_chan:], target[-1, skip_chan:],
output[-1, skip_chan:] - target[-1, skip_chan:]),
@ -401,19 +401,21 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion,
target = target.to(device, non_blocking=True)
output = model(input)
if args.noise_chan > 0:
input = input[:, :-args.noise_chan] # remove noise channels
target = narrow_like(target, output) # FIXME pad
if args.noise_chan > 0:
input = input[:, :-args.noise_chan] # remove noise channels
if args.adv and args.cgan:
if hasattr(model, 'scale_factor') and model.scale_factor != 1:
input = F.interpolate(input,
scale_factor=model.scale_factor, mode='nearest')
input = narrow_like(input, output)
loss = criterion(output, target)
epoch_loss[0] += loss.item()
if args.adv:
if args.cgan:
if hasattr(model, 'scale_factor') and model.scale_factor != 1:
input = F.interpolate(input,
scale_factor=model.scale_factor, mode='nearest')
input = narrow_like(input, output)
output = torch.cat([input, output], dim=1)
target = torch.cat([input, target], dim=1)
@ -445,9 +447,9 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion,
}, global_step=epoch+1)
skip_chan = sum(args.in_chan) if args.adv and args.cgan else 0
logger.add_figure('fig/epoch/val/in',
fig3d(narrow_like(input, output)[-1]), global_step =epoch+1)
logger.add_figure('fig/epoch/val',
logger.add_figure('fig/epoch/val/in', fig3d(input[-1]),
global_step =epoch+1)
logger.add_figure('fig/epoch/val/out',
fig3d(output[-1, skip_chan:], target[-1, skip_chan:],
output[-1, skip_chan:] - target[-1, skip_chan:]),
global_step =epoch+1)