Move map loss computation forward in training
This commit is contained in:
parent
e383ec3977
commit
e8039dcccc
@ -192,8 +192,8 @@ def gpu_worker(local_rank, node, args):
|
|||||||
min_loss = state['min_loss']
|
min_loss = state['min_loss']
|
||||||
if 'adv_model' not in state and args.adv:
|
if 'adv_model' not in state and args.adv:
|
||||||
min_loss = None # restarting with adversary wipes the record
|
min_loss = None # restarting with adversary wipes the record
|
||||||
print('checkpoint at epoch {} loaded from {}'.format(
|
print('state at epoch {} loaded from {}'.format(
|
||||||
state['epoch'], args.load_state))
|
state['epoch'], args.load_state), flush=True)
|
||||||
|
|
||||||
del state
|
del state
|
||||||
else:
|
else:
|
||||||
@ -292,21 +292,25 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
|
|||||||
target = target.to(device, non_blocking=True)
|
target = target.to(device, non_blocking=True)
|
||||||
|
|
||||||
output = model(input)
|
output = model(input)
|
||||||
if args.noise_chan > 0:
|
|
||||||
input = input[:, :-args.noise_chan] # remove noise channels
|
|
||||||
|
|
||||||
target = narrow_like(target, output) # FIXME pad
|
target = narrow_like(target, output) # FIXME pad
|
||||||
|
if args.noise_chan > 0:
|
||||||
|
input = input[:, :-args.noise_chan] # remove noise channels
|
||||||
|
if args.adv and args.cgan:
|
||||||
|
if hasattr(model, 'scale_factor') and model.scale_factor != 1:
|
||||||
|
input = F.interpolate(input,
|
||||||
|
scale_factor=model.scale_factor, mode='nearest')
|
||||||
|
input = narrow_like(input, output)
|
||||||
|
|
||||||
|
loss = criterion(output, target)
|
||||||
|
epoch_loss[0] += loss.item()
|
||||||
|
|
||||||
# discriminator
|
|
||||||
if args.adv:
|
if args.adv:
|
||||||
if args.cgan:
|
if args.cgan:
|
||||||
if hasattr(model, 'scale_factor') and model.scale_factor != 1:
|
|
||||||
input = F.interpolate(input,
|
|
||||||
scale_factor=model.scale_factor, mode='nearest')
|
|
||||||
input = narrow_like(input, output)
|
|
||||||
output = torch.cat([input, output], dim=1)
|
output = torch.cat([input, output], dim=1)
|
||||||
target = torch.cat([input, target], dim=1)
|
target = torch.cat([input, target], dim=1)
|
||||||
|
|
||||||
|
# discriminator
|
||||||
set_requires_grad(adv_model, True)
|
set_requires_grad(adv_model, True)
|
||||||
|
|
||||||
eval = adv_model([output.detach(), target])
|
eval = adv_model([output.detach(), target])
|
||||||
@ -320,11 +324,7 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
|
|||||||
adv_loss.backward()
|
adv_loss.backward()
|
||||||
adv_optimizer.step()
|
adv_optimizer.step()
|
||||||
|
|
||||||
loss = criterion(output, target)
|
# generator adversarial loss
|
||||||
epoch_loss[0] += loss.item()
|
|
||||||
|
|
||||||
# generator adversarial loss
|
|
||||||
if args.adv:
|
|
||||||
set_requires_grad(adv_model, False)
|
set_requires_grad(adv_model, False)
|
||||||
|
|
||||||
eval_out = adv_model(output)
|
eval_out = adv_model(output)
|
||||||
@ -372,8 +372,8 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
|
|||||||
}, global_step=epoch+1)
|
}, global_step=epoch+1)
|
||||||
|
|
||||||
skip_chan = sum(args.in_chan) if args.adv and args.cgan else 0
|
skip_chan = sum(args.in_chan) if args.adv and args.cgan else 0
|
||||||
logger.add_figure('fig/epoch/train/in',
|
logger.add_figure('fig/epoch/train/in', fig3d(input[-1]),
|
||||||
fig3d(narrow_like(input, output)[-1]), global_step =epoch+1)
|
global_step =epoch+1)
|
||||||
logger.add_figure('fig/epoch/train/out',
|
logger.add_figure('fig/epoch/train/out',
|
||||||
fig3d(output[-1, skip_chan:], target[-1, skip_chan:],
|
fig3d(output[-1, skip_chan:], target[-1, skip_chan:],
|
||||||
output[-1, skip_chan:] - target[-1, skip_chan:]),
|
output[-1, skip_chan:] - target[-1, skip_chan:]),
|
||||||
@ -401,19 +401,21 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion,
|
|||||||
target = target.to(device, non_blocking=True)
|
target = target.to(device, non_blocking=True)
|
||||||
|
|
||||||
output = model(input)
|
output = model(input)
|
||||||
if args.noise_chan > 0:
|
|
||||||
input = input[:, :-args.noise_chan] # remove noise channels
|
|
||||||
|
|
||||||
target = narrow_like(target, output) # FIXME pad
|
target = narrow_like(target, output) # FIXME pad
|
||||||
|
if args.noise_chan > 0:
|
||||||
|
input = input[:, :-args.noise_chan] # remove noise channels
|
||||||
|
if args.adv and args.cgan:
|
||||||
|
if hasattr(model, 'scale_factor') and model.scale_factor != 1:
|
||||||
|
input = F.interpolate(input,
|
||||||
|
scale_factor=model.scale_factor, mode='nearest')
|
||||||
|
input = narrow_like(input, output)
|
||||||
|
|
||||||
loss = criterion(output, target)
|
loss = criterion(output, target)
|
||||||
epoch_loss[0] += loss.item()
|
epoch_loss[0] += loss.item()
|
||||||
|
|
||||||
if args.adv:
|
if args.adv:
|
||||||
if args.cgan:
|
if args.cgan:
|
||||||
if hasattr(model, 'scale_factor') and model.scale_factor != 1:
|
|
||||||
input = F.interpolate(input,
|
|
||||||
scale_factor=model.scale_factor, mode='nearest')
|
|
||||||
input = narrow_like(input, output)
|
|
||||||
output = torch.cat([input, output], dim=1)
|
output = torch.cat([input, output], dim=1)
|
||||||
target = torch.cat([input, target], dim=1)
|
target = torch.cat([input, target], dim=1)
|
||||||
|
|
||||||
@ -445,9 +447,9 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion,
|
|||||||
}, global_step=epoch+1)
|
}, global_step=epoch+1)
|
||||||
|
|
||||||
skip_chan = sum(args.in_chan) if args.adv and args.cgan else 0
|
skip_chan = sum(args.in_chan) if args.adv and args.cgan else 0
|
||||||
logger.add_figure('fig/epoch/val/in',
|
logger.add_figure('fig/epoch/val/in', fig3d(input[-1]),
|
||||||
fig3d(narrow_like(input, output)[-1]), global_step =epoch+1)
|
global_step =epoch+1)
|
||||||
logger.add_figure('fig/epoch/val',
|
logger.add_figure('fig/epoch/val/out',
|
||||||
fig3d(output[-1, skip_chan:], target[-1, skip_chan:],
|
fig3d(output[-1, skip_chan:], target[-1, skip_chan:],
|
||||||
output[-1, skip_chan:] - target[-1, skip_chan:]),
|
output[-1, skip_chan:] - target[-1, skip_chan:]),
|
||||||
global_step =epoch+1)
|
global_step =epoch+1)
|
||||||
|
Loading…
Reference in New Issue
Block a user