Move map loss computation forward in training
This commit is contained in:
parent
e383ec3977
commit
e8039dcccc
@ -192,8 +192,8 @@ def gpu_worker(local_rank, node, args):
|
||||
min_loss = state['min_loss']
|
||||
if 'adv_model' not in state and args.adv:
|
||||
min_loss = None # restarting with adversary wipes the record
|
||||
print('checkpoint at epoch {} loaded from {}'.format(
|
||||
state['epoch'], args.load_state))
|
||||
print('state at epoch {} loaded from {}'.format(
|
||||
state['epoch'], args.load_state), flush=True)
|
||||
|
||||
del state
|
||||
else:
|
||||
@ -292,21 +292,25 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
|
||||
target = target.to(device, non_blocking=True)
|
||||
|
||||
output = model(input)
|
||||
if args.noise_chan > 0:
|
||||
input = input[:, :-args.noise_chan] # remove noise channels
|
||||
|
||||
target = narrow_like(target, output) # FIXME pad
|
||||
|
||||
# discriminator
|
||||
if args.adv:
|
||||
if args.cgan:
|
||||
if args.noise_chan > 0:
|
||||
input = input[:, :-args.noise_chan] # remove noise channels
|
||||
if args.adv and args.cgan:
|
||||
if hasattr(model, 'scale_factor') and model.scale_factor != 1:
|
||||
input = F.interpolate(input,
|
||||
scale_factor=model.scale_factor, mode='nearest')
|
||||
input = narrow_like(input, output)
|
||||
|
||||
loss = criterion(output, target)
|
||||
epoch_loss[0] += loss.item()
|
||||
|
||||
if args.adv:
|
||||
if args.cgan:
|
||||
output = torch.cat([input, output], dim=1)
|
||||
target = torch.cat([input, target], dim=1)
|
||||
|
||||
# discriminator
|
||||
set_requires_grad(adv_model, True)
|
||||
|
||||
eval = adv_model([output.detach(), target])
|
||||
@ -320,11 +324,7 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
|
||||
adv_loss.backward()
|
||||
adv_optimizer.step()
|
||||
|
||||
loss = criterion(output, target)
|
||||
epoch_loss[0] += loss.item()
|
||||
|
||||
# generator adversarial loss
|
||||
if args.adv:
|
||||
set_requires_grad(adv_model, False)
|
||||
|
||||
eval_out = adv_model(output)
|
||||
@ -372,8 +372,8 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
|
||||
}, global_step=epoch+1)
|
||||
|
||||
skip_chan = sum(args.in_chan) if args.adv and args.cgan else 0
|
||||
logger.add_figure('fig/epoch/train/in',
|
||||
fig3d(narrow_like(input, output)[-1]), global_step =epoch+1)
|
||||
logger.add_figure('fig/epoch/train/in', fig3d(input[-1]),
|
||||
global_step =epoch+1)
|
||||
logger.add_figure('fig/epoch/train/out',
|
||||
fig3d(output[-1, skip_chan:], target[-1, skip_chan:],
|
||||
output[-1, skip_chan:] - target[-1, skip_chan:]),
|
||||
@ -401,19 +401,21 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion,
|
||||
target = target.to(device, non_blocking=True)
|
||||
|
||||
output = model(input)
|
||||
if args.noise_chan > 0:
|
||||
input = input[:, :-args.noise_chan] # remove noise channels
|
||||
|
||||
target = narrow_like(target, output) # FIXME pad
|
||||
if args.noise_chan > 0:
|
||||
input = input[:, :-args.noise_chan] # remove noise channels
|
||||
if args.adv and args.cgan:
|
||||
if hasattr(model, 'scale_factor') and model.scale_factor != 1:
|
||||
input = F.interpolate(input,
|
||||
scale_factor=model.scale_factor, mode='nearest')
|
||||
input = narrow_like(input, output)
|
||||
|
||||
loss = criterion(output, target)
|
||||
epoch_loss[0] += loss.item()
|
||||
|
||||
if args.adv:
|
||||
if args.cgan:
|
||||
if hasattr(model, 'scale_factor') and model.scale_factor != 1:
|
||||
input = F.interpolate(input,
|
||||
scale_factor=model.scale_factor, mode='nearest')
|
||||
input = narrow_like(input, output)
|
||||
output = torch.cat([input, output], dim=1)
|
||||
target = torch.cat([input, target], dim=1)
|
||||
|
||||
@ -445,9 +447,9 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion,
|
||||
}, global_step=epoch+1)
|
||||
|
||||
skip_chan = sum(args.in_chan) if args.adv and args.cgan else 0
|
||||
logger.add_figure('fig/epoch/val/in',
|
||||
fig3d(narrow_like(input, output)[-1]), global_step =epoch+1)
|
||||
logger.add_figure('fig/epoch/val',
|
||||
logger.add_figure('fig/epoch/val/in', fig3d(input[-1]),
|
||||
global_step =epoch+1)
|
||||
logger.add_figure('fig/epoch/val/out',
|
||||
fig3d(output[-1, skip_chan:], target[-1, skip_chan:],
|
||||
output[-1, skip_chan:] - target[-1, skip_chan:]),
|
||||
global_step =epoch+1)
|
||||
|
Loading…
Reference in New Issue
Block a user