Add --adv-start as epoch to start adversarial training

This is similar to the deprecated --adv-delay, but specify
absolute epoch (--adv-delay counts from start_epoch in resumed training)
This commit is contained in:
Yin Li 2020-02-13 15:42:27 -05:00
parent e8039dcccc
commit d2840f01b0
3 changed files with 29 additions and 35 deletions

View File

@ -76,6 +76,8 @@ def add_train_args(parser):
help='enable minimum reduction in adversarial criterion')
parser.add_argument('--cgan', action='store_true',
help='enable conditional GAN')
parser.add_argument('--adv-start', default=0, type=int,
help='epoch to start adversarial training')
parser.add_argument('--loss-fraction', default=0.5, type=float,
help='final fraction of loss (vs adv-loss)')
parser.add_argument('--loss-halflife', default=20, type=float,

View File

@ -3,8 +3,8 @@ import sys
from pprint import pformat
def load_model_state_dict(model, state_dict, strict=True):
bad_keys = model.load_state_dict(state_dict, strict)
def load_model_state_dict(module, state_dict, strict=True):
bad_keys = module.load_state_dict(state_dict, strict)
if len(bad_keys.missing_keys) > 0:
warnings.warn('Missing keys in state_dict:\n{}'.format(

View File

@ -31,8 +31,6 @@ def set_runtime_default_args(args):
if args.adv_weight_decay is None:
args.adv_weight_decay = args.weight_decay
args.adv_epoch = 0 # epoch when adversary is initiated
def node_worker(args):
set_runtime_default_args(args)
@ -177,20 +175,15 @@ def gpu_worker(local_rank, node, args):
load_model_state_dict(model.module, state['model'],
strict=args.load_state_strict)
if args.adv:
if 'adv_model' in state:
args.adv_epoch = state['adv_epoch']
if args.adv and 'adv_model' in state:
load_model_state_dict(adv_model.module, state['adv_model'],
strict=args.load_state_strict)
else:
args.adv_epoch = start_epoch
torch.set_rng_state(state['rng'].cpu()) # move rng state back
if rank == 0:
min_loss = state['min_loss']
if 'adv_model' not in state and args.adv:
if args.adv and 'adv_model' not in state:
min_loss = None # restarting with adversary wipes the record
print('state at epoch {} loaded from {}'.format(
state['epoch'], args.load_state), flush=True)
@ -233,7 +226,7 @@ def gpu_worker(local_rank, node, args):
logger, device, args)
epoch_loss = val_loss
if args.reduce_lr_on_plateau:
if args.reduce_lr_on_plateau and epoch >= args.adv_start:
scheduler.step(epoch_loss[0])
if args.adv:
adv_scheduler.step(epoch_loss[0])
@ -242,7 +235,7 @@ def gpu_worker(local_rank, node, args):
logger.close()
good = min_loss is None or epoch_loss[0] < min_loss[0]
if good:
if good and epoch >= args.adv_start:
min_loss = epoch_loss
state = {
@ -252,10 +245,7 @@ def gpu_worker(local_rank, node, args):
'min_loss': min_loss,
}
if args.adv:
state.update({
'adv_epoch': args.adv_epoch,
'adv_model': adv_model.module.state_dict(),
})
state['adv_model'] = adv_model.module.state_dict()
ckpt_file = 'checkpoint.pth'
state_file = 'state_{}.pth'
torch.save(state, ckpt_file)
@ -296,7 +286,6 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
target = narrow_like(target, output) # FIXME pad
if args.noise_chan > 0:
input = input[:, :-args.noise_chan] # remove noise channels
if args.adv and args.cgan:
if hasattr(model, 'scale_factor') and model.scale_factor != 1:
input = F.interpolate(input,
scale_factor=model.scale_factor, mode='nearest')
@ -305,7 +294,7 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
loss = criterion(output, target)
epoch_loss[0] += loss.item()
if args.adv:
if args.adv and epoch >= args.adv_start:
if args.cgan:
output = torch.cat([input, output], dim=1)
target = torch.cat([input, target], dim=1)
@ -333,7 +322,7 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
r = loss.item() / (loss_adv.item() + 1e-8)
f = args.loss_fraction
e = epoch - args.adv_epoch
e = epoch - args.adv_start
d = 0.5 ** (e / args.loss_halflife)
loss = (f + (1 - f) * d) * loss + (1 - f) * (1 - d) * r * loss_adv
@ -348,7 +337,7 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
if rank == 0:
logger.add_scalar('loss/batch/train', loss.item(),
global_step=batch)
if args.adv:
if args.adv and epoch >= args.adv_start:
logger.add_scalar('loss/batch/train/adv/G',
loss_adv.item(), global_step=batch)
logger.add_scalars('loss/batch/train/adv/D', {
@ -362,7 +351,7 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
if rank == 0:
logger.add_scalar('loss/epoch/train', epoch_loss[0],
global_step=epoch+1)
if args.adv:
if args.adv and epoch >= args.adv_start:
logger.add_scalar('loss/epoch/train/adv/G', epoch_loss[1],
global_step=epoch+1)
logger.add_scalars('loss/epoch/train/adv/D', {
@ -371,7 +360,9 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
'real': epoch_loss[4],
}, global_step=epoch+1)
skip_chan = sum(args.in_chan) if args.adv and args.cgan else 0
skip_chan = 0
if args.adv and epoch >= args.adv_start and args.cgan:
skip_chan = sum(args.in_chan)
logger.add_figure('fig/epoch/train/in', fig3d(input[-1]),
global_step =epoch+1)
logger.add_figure('fig/epoch/train/out',
@ -405,7 +396,6 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion,
target = narrow_like(target, output) # FIXME pad
if args.noise_chan > 0:
input = input[:, :-args.noise_chan] # remove noise channels
if args.adv and args.cgan:
if hasattr(model, 'scale_factor') and model.scale_factor != 1:
input = F.interpolate(input,
scale_factor=model.scale_factor, mode='nearest')
@ -414,7 +404,7 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion,
loss = criterion(output, target)
epoch_loss[0] += loss.item()
if args.adv:
if args.adv and epoch >= args.adv_start:
if args.cgan:
output = torch.cat([input, output], dim=1)
target = torch.cat([input, target], dim=1)
@ -437,7 +427,7 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion,
if rank == 0:
logger.add_scalar('loss/epoch/val', epoch_loss[0],
global_step=epoch+1)
if args.adv:
if args.adv and epoch >= args.adv_start:
logger.add_scalar('loss/epoch/val/adv/G', epoch_loss[1],
global_step=epoch+1)
logger.add_scalars('loss/epoch/val/adv/D', {
@ -446,7 +436,9 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion,
'real': epoch_loss[4],
}, global_step=epoch+1)
skip_chan = sum(args.in_chan) if args.adv and args.cgan else 0
skip_chan = 0
if args.adv and epoch >= args.adv_start and args.cgan:
skip_chan = sum(args.in_chan)
logger.add_figure('fig/epoch/val/in', fig3d(input[-1]),
global_step =epoch+1)
logger.add_figure('fig/epoch/val/out',