Add --adv-start as epoch to start adversarial training
This is similar to the deprecated --adv-delay, but specify absolute epoch (--adv-delay counts from start_epoch in resumed training)
This commit is contained in:
parent
e8039dcccc
commit
d2840f01b0
@ -76,6 +76,8 @@ def add_train_args(parser):
|
||||
help='enable minimum reduction in adversarial criterion')
|
||||
parser.add_argument('--cgan', action='store_true',
|
||||
help='enable conditional GAN')
|
||||
parser.add_argument('--adv-start', default=0, type=int,
|
||||
help='epoch to start adversarial training')
|
||||
parser.add_argument('--loss-fraction', default=0.5, type=float,
|
||||
help='final fraction of loss (vs adv-loss)')
|
||||
parser.add_argument('--loss-halflife', default=20, type=float,
|
||||
|
@ -3,8 +3,8 @@ import sys
|
||||
from pprint import pformat
|
||||
|
||||
|
||||
def load_model_state_dict(model, state_dict, strict=True):
|
||||
bad_keys = model.load_state_dict(state_dict, strict)
|
||||
def load_model_state_dict(module, state_dict, strict=True):
|
||||
bad_keys = module.load_state_dict(state_dict, strict)
|
||||
|
||||
if len(bad_keys.missing_keys) > 0:
|
||||
warnings.warn('Missing keys in state_dict:\n{}'.format(
|
||||
|
@ -31,8 +31,6 @@ def set_runtime_default_args(args):
|
||||
if args.adv_weight_decay is None:
|
||||
args.adv_weight_decay = args.weight_decay
|
||||
|
||||
args.adv_epoch = 0 # epoch when adversary is initiated
|
||||
|
||||
|
||||
def node_worker(args):
|
||||
set_runtime_default_args(args)
|
||||
@ -177,20 +175,15 @@ def gpu_worker(local_rank, node, args):
|
||||
load_model_state_dict(model.module, state['model'],
|
||||
strict=args.load_state_strict)
|
||||
|
||||
if args.adv:
|
||||
if 'adv_model' in state:
|
||||
args.adv_epoch = state['adv_epoch']
|
||||
|
||||
if args.adv and 'adv_model' in state:
|
||||
load_model_state_dict(adv_model.module, state['adv_model'],
|
||||
strict=args.load_state_strict)
|
||||
else:
|
||||
args.adv_epoch = start_epoch
|
||||
|
||||
torch.set_rng_state(state['rng'].cpu()) # move rng state back
|
||||
|
||||
if rank == 0:
|
||||
min_loss = state['min_loss']
|
||||
if 'adv_model' not in state and args.adv:
|
||||
if args.adv and 'adv_model' not in state:
|
||||
min_loss = None # restarting with adversary wipes the record
|
||||
print('state at epoch {} loaded from {}'.format(
|
||||
state['epoch'], args.load_state), flush=True)
|
||||
@ -233,7 +226,7 @@ def gpu_worker(local_rank, node, args):
|
||||
logger, device, args)
|
||||
epoch_loss = val_loss
|
||||
|
||||
if args.reduce_lr_on_plateau:
|
||||
if args.reduce_lr_on_plateau and epoch >= args.adv_start:
|
||||
scheduler.step(epoch_loss[0])
|
||||
if args.adv:
|
||||
adv_scheduler.step(epoch_loss[0])
|
||||
@ -242,7 +235,7 @@ def gpu_worker(local_rank, node, args):
|
||||
logger.close()
|
||||
|
||||
good = min_loss is None or epoch_loss[0] < min_loss[0]
|
||||
if good:
|
||||
if good and epoch >= args.adv_start:
|
||||
min_loss = epoch_loss
|
||||
|
||||
state = {
|
||||
@ -252,10 +245,7 @@ def gpu_worker(local_rank, node, args):
|
||||
'min_loss': min_loss,
|
||||
}
|
||||
if args.adv:
|
||||
state.update({
|
||||
'adv_epoch': args.adv_epoch,
|
||||
'adv_model': adv_model.module.state_dict(),
|
||||
})
|
||||
state['adv_model'] = adv_model.module.state_dict()
|
||||
ckpt_file = 'checkpoint.pth'
|
||||
state_file = 'state_{}.pth'
|
||||
torch.save(state, ckpt_file)
|
||||
@ -296,7 +286,6 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
|
||||
target = narrow_like(target, output) # FIXME pad
|
||||
if args.noise_chan > 0:
|
||||
input = input[:, :-args.noise_chan] # remove noise channels
|
||||
if args.adv and args.cgan:
|
||||
if hasattr(model, 'scale_factor') and model.scale_factor != 1:
|
||||
input = F.interpolate(input,
|
||||
scale_factor=model.scale_factor, mode='nearest')
|
||||
@ -305,7 +294,7 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
|
||||
loss = criterion(output, target)
|
||||
epoch_loss[0] += loss.item()
|
||||
|
||||
if args.adv:
|
||||
if args.adv and epoch >= args.adv_start:
|
||||
if args.cgan:
|
||||
output = torch.cat([input, output], dim=1)
|
||||
target = torch.cat([input, target], dim=1)
|
||||
@ -333,7 +322,7 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
|
||||
|
||||
r = loss.item() / (loss_adv.item() + 1e-8)
|
||||
f = args.loss_fraction
|
||||
e = epoch - args.adv_epoch
|
||||
e = epoch - args.adv_start
|
||||
d = 0.5 ** (e / args.loss_halflife)
|
||||
loss = (f + (1 - f) * d) * loss + (1 - f) * (1 - d) * r * loss_adv
|
||||
|
||||
@ -348,7 +337,7 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
|
||||
if rank == 0:
|
||||
logger.add_scalar('loss/batch/train', loss.item(),
|
||||
global_step=batch)
|
||||
if args.adv:
|
||||
if args.adv and epoch >= args.adv_start:
|
||||
logger.add_scalar('loss/batch/train/adv/G',
|
||||
loss_adv.item(), global_step=batch)
|
||||
logger.add_scalars('loss/batch/train/adv/D', {
|
||||
@ -362,7 +351,7 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
|
||||
if rank == 0:
|
||||
logger.add_scalar('loss/epoch/train', epoch_loss[0],
|
||||
global_step=epoch+1)
|
||||
if args.adv:
|
||||
if args.adv and epoch >= args.adv_start:
|
||||
logger.add_scalar('loss/epoch/train/adv/G', epoch_loss[1],
|
||||
global_step=epoch+1)
|
||||
logger.add_scalars('loss/epoch/train/adv/D', {
|
||||
@ -371,7 +360,9 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
|
||||
'real': epoch_loss[4],
|
||||
}, global_step=epoch+1)
|
||||
|
||||
skip_chan = sum(args.in_chan) if args.adv and args.cgan else 0
|
||||
skip_chan = 0
|
||||
if args.adv and epoch >= args.adv_start and args.cgan:
|
||||
skip_chan = sum(args.in_chan)
|
||||
logger.add_figure('fig/epoch/train/in', fig3d(input[-1]),
|
||||
global_step =epoch+1)
|
||||
logger.add_figure('fig/epoch/train/out',
|
||||
@ -405,7 +396,6 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion,
|
||||
target = narrow_like(target, output) # FIXME pad
|
||||
if args.noise_chan > 0:
|
||||
input = input[:, :-args.noise_chan] # remove noise channels
|
||||
if args.adv and args.cgan:
|
||||
if hasattr(model, 'scale_factor') and model.scale_factor != 1:
|
||||
input = F.interpolate(input,
|
||||
scale_factor=model.scale_factor, mode='nearest')
|
||||
@ -414,7 +404,7 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion,
|
||||
loss = criterion(output, target)
|
||||
epoch_loss[0] += loss.item()
|
||||
|
||||
if args.adv:
|
||||
if args.adv and epoch >= args.adv_start:
|
||||
if args.cgan:
|
||||
output = torch.cat([input, output], dim=1)
|
||||
target = torch.cat([input, target], dim=1)
|
||||
@ -437,7 +427,7 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion,
|
||||
if rank == 0:
|
||||
logger.add_scalar('loss/epoch/val', epoch_loss[0],
|
||||
global_step=epoch+1)
|
||||
if args.adv:
|
||||
if args.adv and epoch >= args.adv_start:
|
||||
logger.add_scalar('loss/epoch/val/adv/G', epoch_loss[1],
|
||||
global_step=epoch+1)
|
||||
logger.add_scalars('loss/epoch/val/adv/D', {
|
||||
@ -446,7 +436,9 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion,
|
||||
'real': epoch_loss[4],
|
||||
}, global_step=epoch+1)
|
||||
|
||||
skip_chan = sum(args.in_chan) if args.adv and args.cgan else 0
|
||||
skip_chan = 0
|
||||
if args.adv and epoch >= args.adv_start and args.cgan:
|
||||
skip_chan = sum(args.in_chan)
|
||||
logger.add_figure('fig/epoch/val/in', fig3d(input[-1]),
|
||||
global_step =epoch+1)
|
||||
logger.add_figure('fig/epoch/val/out',
|
||||
|
Loading…
Reference in New Issue
Block a user