Remove adversary

This commit is contained in:
Yin Li 2020-07-14 21:07:05 -04:00
parent 6fa85f8285
commit e8744d6c7b
5 changed files with 13 additions and 222 deletions

View file

@ -106,26 +106,6 @@ def add_train_args(parser):
help='multiplicative data augmentation, (log-normal) std, '
'same factor for all fields')
parser.add_argument('--adv-model', type=str,
help='enable adversary model from .models')
parser.add_argument('--adv-model-spectral-norm', action='store_true',
help='enable spectral normalization on the adversary model')
parser.add_argument('--adv-criterion', default='BCEWithLogitsLoss', type=str,
help='adversarial criterion from torch.nn')
parser.add_argument('--cgan', action='store_true',
help='enable conditional GAN')
parser.add_argument('--adv-start', default=0, type=int,
help='epoch to start adversarial training')
parser.add_argument('--adv-label-smoothing', default=1, type=float,
help='label of real samples for the adversary model, '
'e.g. 0.9 for label smoothing and 1 to disable')
parser.add_argument('--loss-fraction', default=0.5, type=float,
help='final fraction of loss (vs adv-loss)')
parser.add_argument('--instance-noise', default=0, type=float,
help='noise added to the adversary inputs to stabilize training')
parser.add_argument('--instance-noise-batches', default=1e4, type=float,
help='noise annealing duration')
parser.add_argument('--optimizer', default='Adam', type=str,
help='optimizer from torch.optim')
parser.add_argument('--lr', type=float, required=True,
@ -134,10 +114,6 @@ def add_train_args(parser):
# help='momentum')
parser.add_argument('--weight-decay', default=0, type=float,
help='weight decay')
parser.add_argument('--adv-lr', type=float,
help='initial adversary learning rate')
parser.add_argument('--adv-weight-decay', type=float,
help='adversary weight decay')
parser.add_argument('--reduce-lr-on-plateau', action='store_true',
help='Enable ReduceLROnPlateau learning rate scheduler')
parser.add_argument('--init-weight-std', type=float,
@ -187,19 +163,6 @@ def set_train_args(args):
args.val = args.val_in_patterns is not None and \
args.val_tgt_patterns is not None
args.adv = args.adv_model is not None
if args.adv:
if args.adv_lr is None:
args.adv_lr = args.lr
if args.adv_weight_decay is None:
args.adv_weight_decay = args.weight_decay
if args.cgan and not args.adv:
args.cgan =False
warnings.warn('Disabling cgan given adversary is disabled',
RuntimeWarning)
def set_test_args(args):
set_common_args(args)

View file

@ -17,10 +17,7 @@ from torch.utils.tensorboard import SummaryWriter
from .data import FieldDataset
from .data.figures import plt_slices
from . import models
from .models import (narrow_like, resample, Lag2Eul,
adv_model_wrapper, adv_criterion_wrapper,
add_spectral_norm, rm_spectral_norm,
InstanceNoise)
from .models import narrow_cast, resample, Lag2Eul
from .utils import import_attr, load_model_state_dict
@ -138,33 +135,6 @@ def gpu_worker(local_rank, node, args):
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
factor=0.1, patience=10, verbose=True)
adv_model = adv_criterion = adv_optimizer = adv_scheduler = None
if args.adv:
adv_model = import_attr(args.adv_model, models.__name__, args.callback_at)
adv_model = adv_model_wrapper(adv_model)
adv_model = adv_model(sum(args.in_chan + args.out_chan)
if args.cgan else sum(args.out_chan), 1)
if args.adv_model_spectral_norm:
add_spectral_norm(adv_model)
adv_model.to(device)
adv_model = DistributedDataParallel(adv_model, device_ids=[device],
process_group=dist.new_group())
adv_criterion = import_attr(args.adv_criterion, nn.__name__, args.callback_at)
adv_criterion = adv_criterion_wrapper(adv_criterion)
adv_criterion = adv_criterion()
adv_criterion.to(device)
adv_optimizer = import_attr(args.optimizer, optim.__name__, args.callback_at)
adv_optimizer = adv_optimizer(
adv_model.parameters(),
lr=args.adv_lr,
betas=(0.5, 0.999),
weight_decay=args.adv_weight_decay,
)
adv_scheduler = optim.lr_scheduler.ReduceLROnPlateau(adv_optimizer,
factor=0.1, patience=10, verbose=True)
if (args.load_state == ckpt_link and not os.path.isfile(ckpt_link)
or not args.load_state):
def init_weights(m):
@ -182,9 +152,6 @@ def gpu_worker(local_rank, node, args):
if args.init_weight_std is not None:
model.apply(init_weights)
if args.adv:
adv_model.apply(init_weights)
start_epoch = 0
if rank == 0:
@ -197,16 +164,10 @@ def gpu_worker(local_rank, node, args):
load_model_state_dict(model.module, state['model'],
strict=args.load_state_strict)
if args.adv and 'adv_model' in state:
load_model_state_dict(adv_model.module, state['adv_model'],
strict=args.load_state_strict)
torch.set_rng_state(state['rng'].cpu()) # move rng state back
if rank == 0:
min_loss = state['min_loss']
if args.adv and 'adv_model' not in state:
min_loss = None # restarting with adversary wipes the record
print('state at epoch {} loaded from {}'.format(
state['epoch'], args.load_state), flush=True)
@ -223,35 +184,26 @@ def gpu_worker(local_rank, node, args):
pprint(vars(args))
sys.stdout.flush()
if args.adv:
args.instance_noise = InstanceNoise(args.instance_noise,
args.instance_noise_batches)
for epoch in range(start_epoch, args.epochs):
train_sampler.set_epoch(epoch)
train_loss = train(epoch, train_loader,
model, dis2den, criterion, optimizer, scheduler,
adv_model, adv_criterion, adv_optimizer, adv_scheduler,
logger, device, args)
epoch_loss = train_loss
if args.val:
val_loss = validate(epoch, val_loader,
model, dis2den, criterion, adv_model, adv_criterion,
val_loss = validate(epoch, val_loader, model, dis2den, criterion,
logger, device, args)
epoch_loss = val_loss
if args.reduce_lr_on_plateau and epoch >= args.adv_start:
if args.reduce_lr_on_plateau:
scheduler.step(epoch_loss[0])
if args.adv:
adv_scheduler.step(epoch_loss[0])
if rank == 0:
logger.flush()
if ((min_loss is None or epoch_loss[0] < min_loss[0])
and epoch >= args.adv_start):
if min_loss is None or epoch_loss[0] < min_loss[0]:
min_loss = epoch_loss
state = {
@ -260,8 +212,6 @@ def gpu_worker(local_rank, node, args):
'rng': torch.get_rng_state(),
'min_loss': min_loss,
}
if args.adv:
state['adv_model'] = adv_model.module.state_dict()
state_file = 'state_{}.pt'.format(epoch + 1)
torch.save(state, state_file)
@ -275,24 +225,13 @@ def gpu_worker(local_rank, node, args):
def train(epoch, loader, model, dis2den, criterion, optimizer, scheduler,
adv_model, adv_criterion, adv_optimizer, adv_scheduler,
logger, device, args):
model.train()
if args.adv:
adv_model.train()
rank = dist.get_rank()
world_size = dist.get_world_size()
# loss, loss_adv, adv_loss, adv_loss_fake, adv_loss_real
# loss: generator (model) supervised loss
# loss_adv: generator (model) adversarial loss
# adv_loss: discriminator (adv_model) loss
epoch_loss = torch.zeros(5, dtype=torch.float64, device=device)
fake = torch.zeros([1], dtype=torch.float32, device=device)
real = torch.ones([1], dtype=torch.float32, device=device)
adv_real = torch.full([1], args.adv_label_smoothing, dtype=torch.float32,
device=device)
for i, (input, target) in enumerate(loader):
input = input.to(device, non_blocking=True)
@ -314,45 +253,6 @@ def train(epoch, loader, model, dis2den, criterion, optimizer, scheduler,
loss = criterion(output, target)
epoch_loss[0] += loss.item()
if args.adv and epoch >= args.adv_start:
noise_std = args.instance_noise.std()
if noise_std > 0:
noise = noise_std * torch.randn_like(output)
output = output + noise.detach()
noise = noise_std * torch.randn_like(target)
target = target + noise.detach()
del noise
if args.cgan:
output = torch.cat([input, output], dim=1)
target = torch.cat([input, target], dim=1)
# discriminator
set_requires_grad(adv_model, True)
eval = adv_model([output.detach(), target])
adv_loss_fake, adv_loss_real = adv_criterion(eval, [fake, adv_real])
epoch_loss[3] += adv_loss_fake.item()
epoch_loss[4] += adv_loss_real.item()
adv_loss = 0.5 * (adv_loss_fake + adv_loss_real)
epoch_loss[2] += adv_loss.item()
adv_optimizer.zero_grad()
adv_loss.backward()
adv_optimizer.step()
# generator adversarial loss
set_requires_grad(adv_model, False)
eval_out = adv_model(output)
loss_adv, = adv_criterion(eval_out, real)
epoch_loss[1] += loss_adv.item()
ratio = loss.item() / (loss_adv.item() + 1e-8)
frac = args.loss_fraction
if epoch >= args.adv_start:
loss = frac * loss + (1 - frac) * ratio * loss_adv
optimizer.zero_grad()
loss.backward()
optimizer.step()
@ -364,14 +264,6 @@ def train(epoch, loader, model, dis2den, criterion, optimizer, scheduler,
if rank == 0:
logger.add_scalar('loss/batch/train', loss.item(),
global_step=batch)
if args.adv and epoch >= args.adv_start:
logger.add_scalar('loss/batch/train/adv/G',
loss_adv.item(), global_step=batch)
logger.add_scalars('loss/batch/train/adv/D', {
'total': adv_loss.item(),
'fake': adv_loss_fake.item(),
'real': adv_loss_real.item(),
}, global_step=batch)
# gradients of the weights of the first and the last layer
grads = list(p.grad for n, p in model.named_parameters()
@ -380,60 +272,28 @@ def train(epoch, loader, model, dis2den, criterion, optimizer, scheduler,
grads = [g.detach().norm().item() for g in grads]
logger.add_scalar('grad/first', grads[0], global_step=batch)
logger.add_scalar('grad/last', grads[-1], global_step=batch)
if args.adv and epoch >= args.adv_start:
grads = list(p.grad for n, p in adv_model.named_parameters()
if '.weight' in n)
grads = [grads[0], grads[-1]]
grads = [g.detach().norm().item() for g in grads]
logger.add_scalars('grad/adv/first', grads[0],
global_step=batch)
logger.add_scalars('grad/adv/last', grads[-1],
global_step=batch)
if args.adv and epoch >= args.adv_start and noise_std > 0:
logger.add_scalar('instance_noise', noise_std,
global_step=batch)
dist.all_reduce(epoch_loss)
epoch_loss /= len(loader) * world_size
if rank == 0:
logger.add_scalar('loss/epoch/train', epoch_loss[0],
global_step=epoch+1)
if args.adv and epoch >= args.adv_start:
logger.add_scalar('loss/epoch/train/adv/G', epoch_loss[1],
global_step=epoch+1)
logger.add_scalars('loss/epoch/train/adv/D', {
'total': epoch_loss[2],
'fake': epoch_loss[3],
'real': epoch_loss[4],
}, global_step=epoch+1)
skip_chan = 0
if args.adv and epoch >= args.adv_start and args.cgan:
skip_chan = sum(args.in_chan)
logger.add_figure('fig/epoch/train', plt_slices(
input[-1],
output[-1, skip_chan:],
target[-1, skip_chan:],
output[-1, skip_chan:] - target[-1, skip_chan:],
input[-1], output[-1], target[-1], output[-1] - target[-1],
title=['in', 'out', 'tgt', 'out - tgt'],
), global_step=epoch+1)
return epoch_loss
def validate(epoch, loader, model, dis2den, criterion, adv_model, adv_criterion,
logger, device, args):
def validate(epoch, loader, model, dis2den, criterion, logger, device, args):
model.eval()
if args.adv:
adv_model.eval()
rank = dist.get_rank()
world_size = dist.get_world_size()
epoch_loss = torch.zeros(5, dtype=torch.float64, device=device)
fake = torch.zeros([1], dtype=torch.float32, device=device)
real = torch.ones([1], dtype=torch.float32, device=device)
with torch.no_grad():
for input, target in loader:
@ -452,46 +312,14 @@ def validate(epoch, loader, model, dis2den, criterion, adv_model, adv_criterion,
loss = criterion(output, target)
epoch_loss[0] += loss.item()
if args.adv and epoch >= args.adv_start:
if args.cgan:
output = torch.cat([input, output], dim=1)
target = torch.cat([input, target], dim=1)
# discriminator
eval = adv_model([output, target])
adv_loss_fake, adv_loss_real = adv_criterion(eval, [fake, real])
epoch_loss[3] += adv_loss_fake.item()
epoch_loss[4] += adv_loss_real.item()
adv_loss = 0.5 * (adv_loss_fake + adv_loss_real)
epoch_loss[2] += adv_loss.item()
# generator adversarial loss
eval_out, _ = adv_criterion.split_input(eval, [fake, real])
loss_adv, = adv_criterion(eval_out, real)
epoch_loss[1] += loss_adv.item()
dist.all_reduce(epoch_loss)
epoch_loss /= len(loader) * world_size
if rank == 0:
logger.add_scalar('loss/epoch/val', epoch_loss[0],
global_step=epoch+1)
if args.adv and epoch >= args.adv_start:
logger.add_scalar('loss/epoch/val/adv/G', epoch_loss[1],
global_step=epoch+1)
logger.add_scalars('loss/epoch/val/adv/D', {
'total': epoch_loss[2],
'fake': epoch_loss[3],
'real': epoch_loss[4],
}, global_step=epoch+1)
skip_chan = 0
if args.adv and epoch >= args.adv_start and args.cgan:
skip_chan = sum(args.in_chan)
logger.add_figure('fig/epoch/val', plt_slices(
input[-1],
output[-1, skip_chan:],
target[-1, skip_chan:],
output[-1, skip_chan:] - target[-1, skip_chan:],
input[-1], output[-1], target[-1], output[-1] - target[-1],
title=['in', 'out', 'tgt', 'out - tgt'],
), global_step=epoch+1)

View file

@ -37,8 +37,8 @@ srun m2m.py train \
--val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files" \
--val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \
--in-norms cosmology.dis --tgt-norms cosmology.dis --augment --crop 128 --pad 20 \
--model VNet --adv-model UNet --cgan \
--lr 0.0001 --adv-lr 0.0004 --batches 1 --loader-workers 0 \
--model VNet \
--lr 0.0001 --batches 1 --loader-workers 0 \
--epochs 1024 --seed $RANDOM

View file

@ -37,8 +37,8 @@ srun m2m.py train \
--train-in-patterns "$data_root_dir/$in_dir/$train_dirs/$in_files_1,$data_root_dir/$in_dir/$train_dirs/$in_files_2" \
--train-tgt-patterns "$data_root_dir/$tgt_dir/$train_dirs/$tgt_files_1,$data_root_dir/$tgt_dir/$train_dirs/$tgt_files_2" \
--in-norms cosmology.dis,cosmology.vel --tgt-norms cosmology.dis,cosmology.vel --augment --crop 88 --pad 20 --scale-factor 2 \
--model VNet --adv-model PatchGAN --cgan \
--lr 0.0001 --adv-lr 0.0004 --batches 1 --loader-workers 0 \
--model VNet \
--lr 0.0001 --batches 1 --loader-workers 0 \
--epochs 1024 --seed $RANDOM

View file

@ -37,8 +37,8 @@ srun m2m.py train \
--val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files" \
--val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \
--in-norms cosmology.vel --tgt-norms cosmology.vel --augment --crop 128 --pad 20 \
--model VNet --adv-model UNet --cgan \
--lr 0.0001 --adv-lr 0.0004 --batches 1 --loader-workers 0 \
--model VNet \
--lr 0.0001 --batches 1 --loader-workers 0 \
--epochs 1024 --seed $RANDOM