Change delayed adversary to a smooth annealing scheme

Yin Li 2020-02-06 19:04:30 -05:00
parent 291dfb24b3
commit cd63324724
3 changed files with 109 additions and 79 deletions

View File

@@ -71,9 +71,10 @@ def add_train_args(parser):
             help='enable minimum reduction in adversarial criterion')
     parser.add_argument('--cgan', action='store_true',
             help='enable conditional GAN')
-    parser.add_argument('--adv-delay', default=0, type=int,
-            help='epoch before updating the generator with adversarial loss, '
-            'and the learning rate with scheduler')
+    parser.add_argument('--loss-halflife', default=10, type=float,
+            help='half-life (epoch) to anneal loss while enhancing adv-loss')
+    parser.add_argument('--loss-fraction', default=0.5, type=float,
+            help='final fraction of loss (vs adv-loss)')
 
     parser.add_argument('--optimizer', default='Adam', type=str,
             help='optimizer from torch.optim')
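Note on the new flags: --adv-delay (a hard switch-on epoch) is replaced by --loss-halflife and --loss-fraction, which control a smooth blend between the supervised and adversarial losses. A minimal sketch, not part of the commit, exercising the two flags with a bare ArgumentParser; argparse turns the dashes into underscores, so they come back as args.loss_halflife and args.loss_fraction:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--loss-halflife', default=10, type=float,
        help='half-life (epoch) to anneal loss while enhancing adv-loss')
parser.add_argument('--loss-fraction', default=0.5, type=float,
        help='final fraction of loss (vs adv-loss)')

args = parser.parse_args(['--loss-halflife', '5'])
print(args.loss_halflife, args.loss_fraction)  # 5.0 0.5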

View File

@@ -10,8 +10,9 @@ class PatchGAN(nn.Module):
         self.convs = nn.Sequential(
             ConvBlock(in_chan, 32, seq='CA'),
             ConvBlock(32, 64, seq='CBA'),
-            ConvBlock(64, 128, seq='CBA'),
-            nn.Conv3d(128, out_chan, 1)
+            ConvBlock(64, seq='CBA'),
+            ConvBlock(64, 32, seq='CBA'),
+            nn.Conv3d(32, out_chan, 1)
         )
 
     def forward(self, x):
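For orientation, a self-contained sketch of the revised channel progression written with plain Conv3d layers. Only the channel widths are taken from the diff (old head: 64 -> 128 -> out_chan; new: 64 -> 64 -> 32 -> out_chan); the kernel size, LeakyReLU activation, and BatchNorm placement are my assumptions about what ConvBlock's 'C', 'B', 'A' sequence expands to:

import torch
import torch.nn as nn

def patchgan_channels(in_chan, out_chan):
    # illustrative stand-in for the ConvBlock stack above
    return nn.Sequential(
        nn.Conv3d(in_chan, 32, 3), nn.LeakyReLU(),
        nn.Conv3d(32, 64, 3), nn.BatchNorm3d(64), nn.LeakyReLU(),
        nn.Conv3d(64, 64, 3), nn.BatchNorm3d(64), nn.LeakyReLU(),
        nn.Conv3d(64, 32, 3), nn.BatchNorm3d(32), nn.LeakyReLU(),
        nn.Conv3d(32, out_chan, 1),  # 1x1x1 conv producing the patch scores
    )

x = torch.randn(1, 3, 16, 16, 16)
print(patchgan_channels(3, 1)(x).shape)  # torch.Size([1, 1, 8, 8, 8])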

View File

@@ -16,7 +16,24 @@ from .models import narrow_like
 from .models.adversary import adv_model_wrapper, adv_criterion_wrapper
 
 
+def set_runtime_default_args(args):
+    args.val = args.val_in_patterns is not None and \
+            args.val_tgt_patterns is not None
+
+    args.adv = args.adv_model is not None
+
+    if args.adv:
+        if args.adv_lr is None:
+            args.adv_lr = args.lr
+        if args.adv_weight_decay is None:
+            args.adv_weight_decay = args.weight_decay
+
+        args.adv_epoch = 0  # epoch when adversary is initiated
+
+
 def node_worker(args):
+    set_runtime_default_args(args)
+
     torch.manual_seed(args.seed)  # NOTE: why here not in gpu_worker?
     #torch.backends.cudnn.deterministic = True  # NOTE: test perf
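A usage sketch of the new helper on a bare Namespace (mine, not repository code; the function body is copied from the hunk above and the attribute names follow the diff):

from argparse import Namespace

def set_runtime_default_args(args):
    args.val = args.val_in_patterns is not None and \
            args.val_tgt_patterns is not None

    args.adv = args.adv_model is not None

    if args.adv:
        if args.adv_lr is None:
            args.adv_lr = args.lr
        if args.adv_weight_decay is None:
            args.adv_weight_decay = args.weight_decay

        args.adv_epoch = 0  # epoch when adversary is initiated

args = Namespace(val_in_patterns=None, val_tgt_patterns=None,  # no validation set
        adv_model='PatchGAN', adv_lr=None, adv_weight_decay=None,
        lr=1e-4, weight_decay=0)

set_runtime_default_args(args)
print(args.val, args.adv, args.adv_lr, args.adv_epoch)  # False True 0.0001 0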
@@ -27,22 +44,21 @@ def node_worker(args):
     node = int(os.environ['SLURM_NODEID'])
 
     if node == 0:
         print(args)
-    args.node = node
 
-    spawn(gpu_worker, args=(args,), nprocs=args.gpus_per_node)
+    spawn(gpu_worker, args=(node, args), nprocs=args.gpus_per_node)
 
 
-def gpu_worker(local_rank, args):
-    args.device = torch.device('cuda', local_rank)
-    torch.cuda.device(args.device)
+def gpu_worker(local_rank, node, args):
+    device = torch.device('cuda', local_rank)
+    torch.cuda.device(device)
 
-    args.rank = args.gpus_per_node * args.node + local_rank
+    rank = args.gpus_per_node * node + local_rank
 
     dist.init_process_group(
         backend=args.dist_backend,
         init_method='env://',
         world_size=args.world_size,
-        rank=args.rank
+        rank=rank,
     )
 
     train_dataset = FieldDataset(
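With node now passed explicitly through spawn instead of being stashed on args, the global rank is a plain local computation. A small illustrative example (numbers are mine) of the same formula for 2 nodes with 4 GPUs each:

gpus_per_node = 4
for node in range(2):
    for local_rank in range(gpus_per_node):
        rank = gpus_per_node * node + local_rank  # as computed in gpu_worker
        print(f'node {node}, local_rank {local_rank} -> global rank {rank}')
# node 0 hosts ranks 0-3, node 1 hosts ranks 4-7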
@@ -59,11 +75,9 @@ def gpu_worker(local_rank, args):
         shuffle=args.div_data,
         sampler=None if args.div_data else train_sampler,
         num_workers=args.loader_workers,
-        pin_memory=True
+        pin_memory=True,
     )
 
-    args.val = args.val_in_patterns is not None and \
-            args.val_tgt_patterns is not None
     if args.val:
         val_dataset = FieldDataset(
             in_patterns=args.val_in_patterns,
@@ -80,20 +94,20 @@ def gpu_worker(local_rank, args):
             shuffle=False,
             sampler=None if args.div_data else val_sampler,
             num_workers=args.loader_workers,
-            pin_memory=True
+            pin_memory=True,
         )
 
     args.in_chan, args.out_chan = train_dataset.in_chan, train_dataset.tgt_chan
 
     model = getattr(models, args.model)
     model = model(sum(args.in_chan) + args.noise_chan, sum(args.out_chan))
-    model.to(args.device)
-    model = DistributedDataParallel(model, device_ids=[args.device],
+    model.to(device)
+    model = DistributedDataParallel(model, device_ids=[device],
             process_group=dist.new_group())
 
     criterion = getattr(torch.nn, args.criterion)
     criterion = criterion()
-    criterion.to(args.device)
+    criterion.to(device)
 
     optimizer = getattr(torch.optim, args.optimizer)
     optimizer = optimizer(
@@ -107,25 +121,19 @@ def gpu_worker(local_rank, args):
             factor=0.1, patience=10, verbose=True)
 
     adv_model = adv_criterion = adv_optimizer = adv_scheduler = None
-    args.adv = args.adv_model is not None
     if args.adv:
         adv_model = getattr(models, args.adv_model)
         adv_model = adv_model_wrapper(adv_model)
         adv_model = adv_model(sum(args.in_chan + args.out_chan)
                 if args.cgan else sum(args.out_chan), 1)
-        adv_model.to(args.device)
-        adv_model = DistributedDataParallel(adv_model, device_ids=[args.device],
+        adv_model.to(device)
+        adv_model = DistributedDataParallel(adv_model, device_ids=[device],
                 process_group=dist.new_group())
 
         adv_criterion = getattr(torch.nn, args.adv_criterion)
         adv_criterion = adv_criterion_wrapper(adv_criterion)
         adv_criterion = adv_criterion(reduction='min' if args.min_reduction else 'mean')
-        adv_criterion.to(args.device)
-
-        if args.adv_lr is None:
-            args.adv_lr = args.lr
-        if args.adv_weight_decay is None:
-            args.adv_weight_decay = args.weight_decay
+        adv_criterion.to(device)
 
         adv_optimizer = getattr(torch.optim, args.optimizer)
         adv_optimizer = adv_optimizer(
@@ -138,23 +146,32 @@ def gpu_worker(local_rank, args):
             factor=0.1, patience=10, verbose=True)
 
     if args.load_state:
-        state = torch.load(args.load_state, map_location=args.device)
-        args.start_epoch = state['epoch']
-        args.adv_delay += args.start_epoch
+        state = torch.load(args.load_state, map_location=device)
+
+        start_epoch = state['epoch']
+
         model.module.load_state_dict(state['model'])
         optimizer.load_state_dict(state['optimizer'])
         scheduler.load_state_dict(state['scheduler'])
+
         if 'adv_model' in state and args.adv:
+            args.adv_epoch = state['adv_epoch']
+
             adv_model.module.load_state_dict(state['adv_model'])
             adv_optimizer.load_state_dict(state['adv_optimizer'])
             adv_scheduler.load_state_dict(state['adv_scheduler'])
+        elif 'adv_model' not in state and args.adv:
+            args.adv_epoch = start_epoch
+
         torch.set_rng_state(state['rng'].cpu())  # move rng state back
-        if args.rank == 0:
+
+        if rank == 0:
             min_loss = state['min_loss']
             if 'adv_model' not in state and args.adv:
                 min_loss = None  # restarting with adversary wipes the record
+
             print('checkpoint at epoch {} loaded from {}'.format(
                 state['epoch'], args.load_state))
 
         del state
     else:
         # def init_weights(m):
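A sketch (mine, not repository code) of the adv_epoch bookkeeping on resume, in isolation: checkpoints written before the adversary was enabled lack the 'adv_model' key, so the adversary is treated as starting at the resumed epoch rather than at epoch 0.

def resume_adv_epoch(state, adv_enabled, default=0):
    start_epoch = state['epoch']
    if adv_enabled and 'adv_model' in state:
        return state['adv_epoch']   # adversary already running; keep its start epoch
    if adv_enabled:
        return start_epoch          # adversary switched on at this restart
    return default                  # no adversary: value is unused

# e.g. resuming a pre-adversary checkpoint at epoch 42 with --adv-model set:
print(resume_adv_epoch({'epoch': 42}, adv_enabled=True))  # 42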
@@ -166,46 +183,46 @@ def gpu_worker(local_rank, args):
         #         m.bias.data.fill_(0)
         # model.apply(init_weights)
         #
-        args.start_epoch = 0
+        start_epoch = 0
 
-        if args.rank == 0:
+        if rank == 0:
             min_loss = None
 
     torch.backends.cudnn.benchmark = True  # NOTE: test perf
 
-    if args.rank == 0:
-        args.logger = SummaryWriter()
+    if rank == 0:
+        logger = SummaryWriter()
 
-    for epoch in range(args.start_epoch, args.epochs):
+    for epoch in range(start_epoch, args.epochs):
         if not args.div_data:
             train_sampler.set_epoch(epoch)
 
         train_loss = train(epoch, train_loader,
             model, criterion, optimizer, scheduler,
             adv_model, adv_criterion, adv_optimizer, adv_scheduler,
-            args)
+            logger, device, args)
         epoch_loss = train_loss
 
         if args.val:
             val_loss = validate(epoch, val_loader,
                 model, criterion, adv_model, adv_criterion,
-                args)
+                logger, device, args)
             epoch_loss = val_loss
 
-        if epoch >= args.adv_delay:
+        if epoch >= args.loss_halflife:
             scheduler.step(epoch_loss[0])
             if args.adv:
                 adv_scheduler.step(epoch_loss[0])
         else:
-            scheduler.last_epoch = epoch
+            scheduler.last_epoch = epoch  # HACK due to lack of better option
             if args.adv:
                 adv_scheduler.last_epoch = epoch
 
-        if args.rank == 0:
-            print(end='', flush=True)
-            args.logger.close()
+        if rank == 0:
+            logger.close()
 
             is_best = min_loss is None or epoch_loss[0] < min_loss[0]
-            if is_best and epoch >= args.adv_delay:
+            if is_best and epoch >= args.loss_halflife:
                 min_loss = epoch_loss
 
             state = {
@@ -218,6 +235,7 @@ def gpu_worker(local_rank, args):
             }
             if args.adv:
                 state.update({
+                    'adv_epoch': args.adv_epoch,
                     'adv_model': adv_model.module.state_dict(),
                     'adv_optimizer': adv_optimizer.state_dict(),
                     'adv_scheduler': adv_scheduler.state_dict(),
@@ -236,22 +254,26 @@ def gpu_worker(local_rank, args):
 def train(epoch, loader, model, criterion, optimizer, scheduler,
-        adv_model, adv_criterion, adv_optimizer, adv_scheduler, args):
+        adv_model, adv_criterion, adv_optimizer, adv_scheduler,
+        logger, device, args):
     model.train()
     if args.adv:
         adv_model.train()
 
+    rank = dist.get_rank()
+    world_size = dist.get_world_size()
+
     # loss, loss_adv, adv_loss, adv_loss_fake, adv_loss_real
     # loss: generator (model) supervised loss
     # loss_adv: generator (model) adversarial loss
     # adv_loss: discriminator (adv_model) loss
-    epoch_loss = torch.zeros(5, dtype=torch.float64, device=args.device)
-    real = torch.ones(1, dtype=torch.float32, device=args.device)
-    fake = torch.zeros(1, dtype=torch.float32, device=args.device)
+    epoch_loss = torch.zeros(5, dtype=torch.float64, device=device)
+    real = torch.ones(1, dtype=torch.float32, device=device)
+    fake = torch.zeros(1, dtype=torch.float32, device=device)
 
     for i, (input, target) in enumerate(loader):
-        input = input.to(args.device, non_blocking=True)
-        target = target.to(args.device, non_blocking=True)
+        input = input.to(device, non_blocking=True)
+        target = target.to(device, non_blocking=True)
 
         output = model(input)
         if args.noise_chan > 0:
@@ -275,9 +297,11 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
             loss_adv, = adv_criterion(eval_out, real)
             epoch_loss[1] += loss_adv.item()
 
-            if epoch >= args.adv_delay:
-                loss_fac = loss.item() / (loss_adv.item() + 1e-8)
-                loss += loss_fac * (loss_adv - loss_adv.item())  # FIXME does this work?
+            r = loss.item() / (loss_adv.item() + 1e-8)
+            f = args.loss_fraction
+            e = epoch - args.adv_epoch
+            d = 0.5 ** (e / args.loss_halflife)
+            loss = (f + (1 - f) * d) * loss + (1 - f) * (1 - d) * r * loss_adv
 
         optimizer.zero_grad()
         loss.backward()
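How the blend behaves: r rescales the adversarial term to the current magnitude of the supervised loss (both are detached .item() values, so r is a constant factor), f = loss_fraction is the floor weight kept on the supervised loss, and d = 0.5**(e / loss_halflife) decays from 1 with the given half-life, counted from adv_epoch. The supervised weight f + (1 - f) d therefore starts at 1 and anneals toward f, while the rescaled adversarial term picks up the complementary weight (1 - f)(1 - d); the two weights always sum to 1. A standalone sketch of just the weights (not repository code; illustrative values f = 0.5, half-life 10 epochs, adv_epoch = 0):

f, halflife = 0.5, 10.0

for e in (0, 10, 20, 40):
    d = 0.5 ** (e / halflife)
    w_sup = f + (1 - f) * d      # weight on the supervised loss
    w_adv = (1 - f) * (1 - d)    # weight on the (rescaled) adversarial loss
    assert abs(w_sup + w_adv - 1) < 1e-12  # weights always sum to 1
    print(f'epoch {e:>2}: supervised {w_sup:.3f}, adversarial {w_adv:.3f}')
# epoch  0: supervised 1.000, adversarial 0.000
# epoch 10: supervised 0.750, adversarial 0.250
# epoch 20: supervised 0.625, adversarial 0.375
# epoch 40: supervised 0.531, adversarial 0.469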
@@ -299,37 +323,37 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
         batch = epoch * len(loader) + i + 1
         if batch % args.log_interval == 0:
             dist.all_reduce(loss)
-            loss /= args.world_size
-            if args.rank == 0:
-                args.logger.add_scalar('loss/batch/train', loss.item(),
+            loss /= world_size
+            if rank == 0:
+                logger.add_scalar('loss/batch/train', loss.item(),
                         global_step=batch)
                 if args.adv:
-                    args.logger.add_scalar('loss/batch/train/adv/G',
+                    logger.add_scalar('loss/batch/train/adv/G',
                             loss_adv.item(), global_step=batch)
-                    args.logger.add_scalars('loss/batch/train/adv/D', {
+                    logger.add_scalars('loss/batch/train/adv/D', {
                         'total': adv_loss.item(),
                         'fake': adv_loss_fake.item(),
                         'real': adv_loss_real.item(),
                     }, global_step=batch)
 
     dist.all_reduce(epoch_loss)
-    epoch_loss /= len(loader) * args.world_size
-    if args.rank == 0:
-        args.logger.add_scalar('loss/epoch/train', epoch_loss[0],
+    epoch_loss /= len(loader) * world_size
+    if rank == 0:
+        logger.add_scalar('loss/epoch/train', epoch_loss[0],
                 global_step=epoch+1)
         if args.adv:
-            args.logger.add_scalar('loss/epoch/train/adv/G', epoch_loss[1],
+            logger.add_scalar('loss/epoch/train/adv/G', epoch_loss[1],
                     global_step=epoch+1)
-            args.logger.add_scalars('loss/epoch/train/adv/D', {
+            logger.add_scalars('loss/epoch/train/adv/D', {
                 'total': epoch_loss[2],
                 'fake': epoch_loss[3],
                 'real': epoch_loss[4],
             }, global_step=epoch+1)
 
         skip_chan = sum(args.in_chan) if args.adv and args.cgan else 0
-        args.logger.add_figure('fig/epoch/train/in',
+        logger.add_figure('fig/epoch/train/in',
                 fig3d(narrow_like(input, output)[-1]), global_step=epoch+1)
-        args.logger.add_figure('fig/epoch/train/out',
+        logger.add_figure('fig/epoch/train/out',
                 fig3d(output[-1, skip_chan:], target[-1, skip_chan:],
                     output[-1, skip_chan:] - target[-1, skip_chan:]),
                 global_step=epoch+1)
@@ -337,19 +361,23 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
     return epoch_loss
 
 
-def validate(epoch, loader, model, criterion, adv_model, adv_criterion, args):
+def validate(epoch, loader, model, criterion, adv_model, adv_criterion,
+        logger, device, args):
     model.eval()
     if args.adv:
         adv_model.eval()
 
-    epoch_loss = torch.zeros(5, dtype=torch.float64, device=args.device)
-    fake = torch.zeros(1, dtype=torch.float32, device=args.device)
-    real = torch.ones(1, dtype=torch.float32, device=args.device)
+    rank = dist.get_rank()
+    world_size = dist.get_world_size()
+
+    epoch_loss = torch.zeros(5, dtype=torch.float64, device=device)
+    fake = torch.zeros(1, dtype=torch.float32, device=device)
+    real = torch.ones(1, dtype=torch.float32, device=device)
 
     with torch.no_grad():
         for input, target in loader:
-            input = input.to(args.device, non_blocking=True)
-            target = target.to(args.device, non_blocking=True)
+            input = input.to(device, non_blocking=True)
+            target = target.to(device, non_blocking=True)
 
             output = model(input)
             if args.noise_chan > 0:
@@ -382,23 +410,23 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion, args):
                 epoch_loss[1] += loss_adv.item()
 
     dist.all_reduce(epoch_loss)
-    epoch_loss /= len(loader) * args.world_size
-    if args.rank == 0:
-        args.logger.add_scalar('loss/epoch/val', epoch_loss[0],
+    epoch_loss /= len(loader) * world_size
+    if rank == 0:
+        logger.add_scalar('loss/epoch/val', epoch_loss[0],
                 global_step=epoch+1)
         if args.adv:
-            args.logger.add_scalar('loss/epoch/val/adv/G', epoch_loss[1],
+            logger.add_scalar('loss/epoch/val/adv/G', epoch_loss[1],
                     global_step=epoch+1)
-            args.logger.add_scalars('loss/epoch/val/adv/D', {
+            logger.add_scalars('loss/epoch/val/adv/D', {
                 'total': epoch_loss[2],
                 'fake': epoch_loss[3],
                 'real': epoch_loss[4],
             }, global_step=epoch+1)
 
         skip_chan = sum(args.in_chan) if args.adv and args.cgan else 0
-        args.logger.add_figure('fig/epoch/val/in',
+        logger.add_figure('fig/epoch/val/in',
                 fig3d(narrow_like(input, output)[-1]), global_step=epoch+1)
-        args.logger.add_figure('fig/epoch/val',
+        logger.add_figure('fig/epoch/val',
                 fig3d(output[-1, skip_chan:], target[-1, skip_chan:],
                     output[-1, skip_chan:] - target[-1, skip_chan:]),
                 global_step=epoch+1)