Add --adv-start as epoch to start adversarial training
This is similar to the deprecated --adv-delay, but specify absolute epoch (--adv-delay counts from start_epoch in resumed training)
This commit is contained in:
parent
e8039dcccc
commit
d2840f01b0
@ -76,6 +76,8 @@ def add_train_args(parser):
|
|||||||
help='enable minimum reduction in adversarial criterion')
|
help='enable minimum reduction in adversarial criterion')
|
||||||
parser.add_argument('--cgan', action='store_true',
|
parser.add_argument('--cgan', action='store_true',
|
||||||
help='enable conditional GAN')
|
help='enable conditional GAN')
|
||||||
|
parser.add_argument('--adv-start', default=0, type=int,
|
||||||
|
help='epoch to start adversarial training')
|
||||||
parser.add_argument('--loss-fraction', default=0.5, type=float,
|
parser.add_argument('--loss-fraction', default=0.5, type=float,
|
||||||
help='final fraction of loss (vs adv-loss)')
|
help='final fraction of loss (vs adv-loss)')
|
||||||
parser.add_argument('--loss-halflife', default=20, type=float,
|
parser.add_argument('--loss-halflife', default=20, type=float,
|
||||||
|
@ -3,8 +3,8 @@ import sys
|
|||||||
from pprint import pformat
|
from pprint import pformat
|
||||||
|
|
||||||
|
|
||||||
def load_model_state_dict(model, state_dict, strict=True):
|
def load_model_state_dict(module, state_dict, strict=True):
|
||||||
bad_keys = model.load_state_dict(state_dict, strict)
|
bad_keys = module.load_state_dict(state_dict, strict)
|
||||||
|
|
||||||
if len(bad_keys.missing_keys) > 0:
|
if len(bad_keys.missing_keys) > 0:
|
||||||
warnings.warn('Missing keys in state_dict:\n{}'.format(
|
warnings.warn('Missing keys in state_dict:\n{}'.format(
|
||||||
|
@ -31,8 +31,6 @@ def set_runtime_default_args(args):
|
|||||||
if args.adv_weight_decay is None:
|
if args.adv_weight_decay is None:
|
||||||
args.adv_weight_decay = args.weight_decay
|
args.adv_weight_decay = args.weight_decay
|
||||||
|
|
||||||
args.adv_epoch = 0 # epoch when adversary is initiated
|
|
||||||
|
|
||||||
|
|
||||||
def node_worker(args):
|
def node_worker(args):
|
||||||
set_runtime_default_args(args)
|
set_runtime_default_args(args)
|
||||||
@ -177,20 +175,15 @@ def gpu_worker(local_rank, node, args):
|
|||||||
load_model_state_dict(model.module, state['model'],
|
load_model_state_dict(model.module, state['model'],
|
||||||
strict=args.load_state_strict)
|
strict=args.load_state_strict)
|
||||||
|
|
||||||
if args.adv:
|
if args.adv and 'adv_model' in state:
|
||||||
if 'adv_model' in state:
|
load_model_state_dict(adv_model.module, state['adv_model'],
|
||||||
args.adv_epoch = state['adv_epoch']
|
strict=args.load_state_strict)
|
||||||
|
|
||||||
load_model_state_dict(adv_model.module, state['adv_model'],
|
|
||||||
strict=args.load_state_strict)
|
|
||||||
else:
|
|
||||||
args.adv_epoch = start_epoch
|
|
||||||
|
|
||||||
torch.set_rng_state(state['rng'].cpu()) # move rng state back
|
torch.set_rng_state(state['rng'].cpu()) # move rng state back
|
||||||
|
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
min_loss = state['min_loss']
|
min_loss = state['min_loss']
|
||||||
if 'adv_model' not in state and args.adv:
|
if args.adv and 'adv_model' not in state:
|
||||||
min_loss = None # restarting with adversary wipes the record
|
min_loss = None # restarting with adversary wipes the record
|
||||||
print('state at epoch {} loaded from {}'.format(
|
print('state at epoch {} loaded from {}'.format(
|
||||||
state['epoch'], args.load_state), flush=True)
|
state['epoch'], args.load_state), flush=True)
|
||||||
@ -233,7 +226,7 @@ def gpu_worker(local_rank, node, args):
|
|||||||
logger, device, args)
|
logger, device, args)
|
||||||
epoch_loss = val_loss
|
epoch_loss = val_loss
|
||||||
|
|
||||||
if args.reduce_lr_on_plateau:
|
if args.reduce_lr_on_plateau and epoch >= args.adv_start:
|
||||||
scheduler.step(epoch_loss[0])
|
scheduler.step(epoch_loss[0])
|
||||||
if args.adv:
|
if args.adv:
|
||||||
adv_scheduler.step(epoch_loss[0])
|
adv_scheduler.step(epoch_loss[0])
|
||||||
@ -242,7 +235,7 @@ def gpu_worker(local_rank, node, args):
|
|||||||
logger.close()
|
logger.close()
|
||||||
|
|
||||||
good = min_loss is None or epoch_loss[0] < min_loss[0]
|
good = min_loss is None or epoch_loss[0] < min_loss[0]
|
||||||
if good:
|
if good and epoch >= args.adv_start:
|
||||||
min_loss = epoch_loss
|
min_loss = epoch_loss
|
||||||
|
|
||||||
state = {
|
state = {
|
||||||
@ -252,10 +245,7 @@ def gpu_worker(local_rank, node, args):
|
|||||||
'min_loss': min_loss,
|
'min_loss': min_loss,
|
||||||
}
|
}
|
||||||
if args.adv:
|
if args.adv:
|
||||||
state.update({
|
state['adv_model'] = adv_model.module.state_dict()
|
||||||
'adv_epoch': args.adv_epoch,
|
|
||||||
'adv_model': adv_model.module.state_dict(),
|
|
||||||
})
|
|
||||||
ckpt_file = 'checkpoint.pth'
|
ckpt_file = 'checkpoint.pth'
|
||||||
state_file = 'state_{}.pth'
|
state_file = 'state_{}.pth'
|
||||||
torch.save(state, ckpt_file)
|
torch.save(state, ckpt_file)
|
||||||
@ -296,16 +286,15 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
|
|||||||
target = narrow_like(target, output) # FIXME pad
|
target = narrow_like(target, output) # FIXME pad
|
||||||
if args.noise_chan > 0:
|
if args.noise_chan > 0:
|
||||||
input = input[:, :-args.noise_chan] # remove noise channels
|
input = input[:, :-args.noise_chan] # remove noise channels
|
||||||
if args.adv and args.cgan:
|
if hasattr(model, 'scale_factor') and model.scale_factor != 1:
|
||||||
if hasattr(model, 'scale_factor') and model.scale_factor != 1:
|
input = F.interpolate(input,
|
||||||
input = F.interpolate(input,
|
scale_factor=model.scale_factor, mode='nearest')
|
||||||
scale_factor=model.scale_factor, mode='nearest')
|
|
||||||
input = narrow_like(input, output)
|
input = narrow_like(input, output)
|
||||||
|
|
||||||
loss = criterion(output, target)
|
loss = criterion(output, target)
|
||||||
epoch_loss[0] += loss.item()
|
epoch_loss[0] += loss.item()
|
||||||
|
|
||||||
if args.adv:
|
if args.adv and epoch >= args.adv_start:
|
||||||
if args.cgan:
|
if args.cgan:
|
||||||
output = torch.cat([input, output], dim=1)
|
output = torch.cat([input, output], dim=1)
|
||||||
target = torch.cat([input, target], dim=1)
|
target = torch.cat([input, target], dim=1)
|
||||||
@ -333,7 +322,7 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
|
|||||||
|
|
||||||
r = loss.item() / (loss_adv.item() + 1e-8)
|
r = loss.item() / (loss_adv.item() + 1e-8)
|
||||||
f = args.loss_fraction
|
f = args.loss_fraction
|
||||||
e = epoch - args.adv_epoch
|
e = epoch - args.adv_start
|
||||||
d = 0.5 ** (e / args.loss_halflife)
|
d = 0.5 ** (e / args.loss_halflife)
|
||||||
loss = (f + (1 - f) * d) * loss + (1 - f) * (1 - d) * r * loss_adv
|
loss = (f + (1 - f) * d) * loss + (1 - f) * (1 - d) * r * loss_adv
|
||||||
|
|
||||||
@ -348,7 +337,7 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
|
|||||||
if rank == 0:
|
if rank == 0:
|
||||||
logger.add_scalar('loss/batch/train', loss.item(),
|
logger.add_scalar('loss/batch/train', loss.item(),
|
||||||
global_step=batch)
|
global_step=batch)
|
||||||
if args.adv:
|
if args.adv and epoch >= args.adv_start:
|
||||||
logger.add_scalar('loss/batch/train/adv/G',
|
logger.add_scalar('loss/batch/train/adv/G',
|
||||||
loss_adv.item(), global_step=batch)
|
loss_adv.item(), global_step=batch)
|
||||||
logger.add_scalars('loss/batch/train/adv/D', {
|
logger.add_scalars('loss/batch/train/adv/D', {
|
||||||
@ -362,7 +351,7 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
|
|||||||
if rank == 0:
|
if rank == 0:
|
||||||
logger.add_scalar('loss/epoch/train', epoch_loss[0],
|
logger.add_scalar('loss/epoch/train', epoch_loss[0],
|
||||||
global_step=epoch+1)
|
global_step=epoch+1)
|
||||||
if args.adv:
|
if args.adv and epoch >= args.adv_start:
|
||||||
logger.add_scalar('loss/epoch/train/adv/G', epoch_loss[1],
|
logger.add_scalar('loss/epoch/train/adv/G', epoch_loss[1],
|
||||||
global_step=epoch+1)
|
global_step=epoch+1)
|
||||||
logger.add_scalars('loss/epoch/train/adv/D', {
|
logger.add_scalars('loss/epoch/train/adv/D', {
|
||||||
@ -371,7 +360,9 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
|
|||||||
'real': epoch_loss[4],
|
'real': epoch_loss[4],
|
||||||
}, global_step=epoch+1)
|
}, global_step=epoch+1)
|
||||||
|
|
||||||
skip_chan = sum(args.in_chan) if args.adv and args.cgan else 0
|
skip_chan = 0
|
||||||
|
if args.adv and epoch >= args.adv_start and args.cgan:
|
||||||
|
skip_chan = sum(args.in_chan)
|
||||||
logger.add_figure('fig/epoch/train/in', fig3d(input[-1]),
|
logger.add_figure('fig/epoch/train/in', fig3d(input[-1]),
|
||||||
global_step =epoch+1)
|
global_step =epoch+1)
|
||||||
logger.add_figure('fig/epoch/train/out',
|
logger.add_figure('fig/epoch/train/out',
|
||||||
@ -405,16 +396,15 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion,
|
|||||||
target = narrow_like(target, output) # FIXME pad
|
target = narrow_like(target, output) # FIXME pad
|
||||||
if args.noise_chan > 0:
|
if args.noise_chan > 0:
|
||||||
input = input[:, :-args.noise_chan] # remove noise channels
|
input = input[:, :-args.noise_chan] # remove noise channels
|
||||||
if args.adv and args.cgan:
|
if hasattr(model, 'scale_factor') and model.scale_factor != 1:
|
||||||
if hasattr(model, 'scale_factor') and model.scale_factor != 1:
|
input = F.interpolate(input,
|
||||||
input = F.interpolate(input,
|
scale_factor=model.scale_factor, mode='nearest')
|
||||||
scale_factor=model.scale_factor, mode='nearest')
|
|
||||||
input = narrow_like(input, output)
|
input = narrow_like(input, output)
|
||||||
|
|
||||||
loss = criterion(output, target)
|
loss = criterion(output, target)
|
||||||
epoch_loss[0] += loss.item()
|
epoch_loss[0] += loss.item()
|
||||||
|
|
||||||
if args.adv:
|
if args.adv and epoch >= args.adv_start:
|
||||||
if args.cgan:
|
if args.cgan:
|
||||||
output = torch.cat([input, output], dim=1)
|
output = torch.cat([input, output], dim=1)
|
||||||
target = torch.cat([input, target], dim=1)
|
target = torch.cat([input, target], dim=1)
|
||||||
@ -437,7 +427,7 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion,
|
|||||||
if rank == 0:
|
if rank == 0:
|
||||||
logger.add_scalar('loss/epoch/val', epoch_loss[0],
|
logger.add_scalar('loss/epoch/val', epoch_loss[0],
|
||||||
global_step=epoch+1)
|
global_step=epoch+1)
|
||||||
if args.adv:
|
if args.adv and epoch >= args.adv_start:
|
||||||
logger.add_scalar('loss/epoch/val/adv/G', epoch_loss[1],
|
logger.add_scalar('loss/epoch/val/adv/G', epoch_loss[1],
|
||||||
global_step=epoch+1)
|
global_step=epoch+1)
|
||||||
logger.add_scalars('loss/epoch/val/adv/D', {
|
logger.add_scalars('loss/epoch/val/adv/D', {
|
||||||
@ -446,7 +436,9 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion,
|
|||||||
'real': epoch_loss[4],
|
'real': epoch_loss[4],
|
||||||
}, global_step=epoch+1)
|
}, global_step=epoch+1)
|
||||||
|
|
||||||
skip_chan = sum(args.in_chan) if args.adv and args.cgan else 0
|
skip_chan = 0
|
||||||
|
if args.adv and epoch >= args.adv_start and args.cgan:
|
||||||
|
skip_chan = sum(args.in_chan)
|
||||||
logger.add_figure('fig/epoch/val/in', fig3d(input[-1]),
|
logger.add_figure('fig/epoch/val/in', fig3d(input[-1]),
|
||||||
global_step =epoch+1)
|
global_step =epoch+1)
|
||||||
logger.add_figure('fig/epoch/val/out',
|
logger.add_figure('fig/epoch/val/out',
|
||||||
|
Loading…
Reference in New Issue
Block a user