Remove adversary
This commit is contained in:
parent
607bcf3f4c
commit
337d65de68
@ -106,26 +106,6 @@ def add_train_args(parser):
|
||||
help='multiplicative data augmentation, (log-normal) std, '
|
||||
'same factor for all fields')
|
||||
|
||||
parser.add_argument('--adv-model', type=str,
|
||||
help='enable adversary model from .models')
|
||||
parser.add_argument('--adv-model-spectral-norm', action='store_true',
|
||||
help='enable spectral normalization on the adversary model')
|
||||
parser.add_argument('--adv-criterion', default='BCEWithLogitsLoss', type=str,
|
||||
help='adversarial criterion from torch.nn')
|
||||
parser.add_argument('--cgan', action='store_true',
|
||||
help='enable conditional GAN')
|
||||
parser.add_argument('--adv-start', default=0, type=int,
|
||||
help='epoch to start adversarial training')
|
||||
parser.add_argument('--adv-label-smoothing', default=1, type=float,
|
||||
help='label of real samples for the adversary model, '
|
||||
'e.g. 0.9 for label smoothing and 1 to disable')
|
||||
parser.add_argument('--loss-fraction', default=0.5, type=float,
|
||||
help='final fraction of loss (vs adv-loss)')
|
||||
parser.add_argument('--instance-noise', default=0, type=float,
|
||||
help='noise added to the adversary inputs to stabilize training')
|
||||
parser.add_argument('--instance-noise-batches', default=1e4, type=float,
|
||||
help='noise annealing duration')
|
||||
|
||||
parser.add_argument('--optimizer', default='Adam', type=str,
|
||||
help='optimizer from torch.optim')
|
||||
parser.add_argument('--lr', type=float, required=True,
|
||||
@ -134,10 +114,6 @@ def add_train_args(parser):
|
||||
# help='momentum')
|
||||
parser.add_argument('--weight-decay', default=0, type=float,
|
||||
help='weight decay')
|
||||
parser.add_argument('--adv-lr', type=float,
|
||||
help='initial adversary learning rate')
|
||||
parser.add_argument('--adv-weight-decay', type=float,
|
||||
help='adversary weight decay')
|
||||
parser.add_argument('--reduce-lr-on-plateau', action='store_true',
|
||||
help='Enable ReduceLROnPlateau learning rate scheduler')
|
||||
parser.add_argument('--init-weight-std', type=float,
|
||||
@ -187,19 +163,6 @@ def set_train_args(args):
|
||||
args.val = args.val_in_patterns is not None and \
|
||||
args.val_tgt_patterns is not None
|
||||
|
||||
args.adv = args.adv_model is not None
|
||||
|
||||
if args.adv:
|
||||
if args.adv_lr is None:
|
||||
args.adv_lr = args.lr
|
||||
if args.adv_weight_decay is None:
|
||||
args.adv_weight_decay = args.weight_decay
|
||||
|
||||
if args.cgan and not args.adv:
|
||||
args.cgan =False
|
||||
warnings.warn('Disabling cgan given adversary is disabled',
|
||||
RuntimeWarning)
|
||||
|
||||
|
||||
def set_test_args(args):
|
||||
set_common_args(args)
|
||||
|
186
map2map/train.py
186
map2map/train.py
@ -17,10 +17,7 @@ from torch.utils.tensorboard import SummaryWriter
|
||||
from .data import FieldDataset
|
||||
from .data.figures import plt_slices
|
||||
from . import models
|
||||
from .models import (narrow_cast, resample, Lag2Eul
|
||||
adv_model_wrapper, adv_criterion_wrapper,
|
||||
add_spectral_norm, rm_spectral_norm,
|
||||
InstanceNoise)
|
||||
from .models import narrow_cast, resample, Lag2Eul
|
||||
from .utils import import_attr, load_model_state_dict
|
||||
|
||||
|
||||
@ -138,33 +135,6 @@ def gpu_worker(local_rank, node, args):
|
||||
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
|
||||
factor=0.1, patience=10, verbose=True)
|
||||
|
||||
adv_model = adv_criterion = adv_optimizer = adv_scheduler = None
|
||||
if args.adv:
|
||||
adv_model = import_attr(args.adv_model, models.__name__, args.callback_at)
|
||||
adv_model = adv_model_wrapper(adv_model)
|
||||
adv_model = adv_model(sum(args.in_chan + args.out_chan)
|
||||
if args.cgan else sum(args.out_chan), 1)
|
||||
if args.adv_model_spectral_norm:
|
||||
add_spectral_norm(adv_model)
|
||||
adv_model.to(device)
|
||||
adv_model = DistributedDataParallel(adv_model, device_ids=[device],
|
||||
process_group=dist.new_group())
|
||||
|
||||
adv_criterion = import_attr(args.adv_criterion, nn.__name__, args.callback_at)
|
||||
adv_criterion = adv_criterion_wrapper(adv_criterion)
|
||||
adv_criterion = adv_criterion()
|
||||
adv_criterion.to(device)
|
||||
|
||||
adv_optimizer = import_attr(args.optimizer, optim.__name__, args.callback_at)
|
||||
adv_optimizer = adv_optimizer(
|
||||
adv_model.parameters(),
|
||||
lr=args.adv_lr,
|
||||
betas=(0.5, 0.999),
|
||||
weight_decay=args.adv_weight_decay,
|
||||
)
|
||||
adv_scheduler = optim.lr_scheduler.ReduceLROnPlateau(adv_optimizer,
|
||||
factor=0.1, patience=10, verbose=True)
|
||||
|
||||
if (args.load_state == ckpt_link and not os.path.isfile(ckpt_link)
|
||||
or not args.load_state):
|
||||
def init_weights(m):
|
||||
@ -182,9 +152,6 @@ def gpu_worker(local_rank, node, args):
|
||||
if args.init_weight_std is not None:
|
||||
model.apply(init_weights)
|
||||
|
||||
if args.adv:
|
||||
adv_model.apply(init_weights)
|
||||
|
||||
start_epoch = 0
|
||||
|
||||
if rank == 0:
|
||||
@ -197,16 +164,10 @@ def gpu_worker(local_rank, node, args):
|
||||
load_model_state_dict(model.module, state['model'],
|
||||
strict=args.load_state_strict)
|
||||
|
||||
if args.adv and 'adv_model' in state:
|
||||
load_model_state_dict(adv_model.module, state['adv_model'],
|
||||
strict=args.load_state_strict)
|
||||
|
||||
torch.set_rng_state(state['rng'].cpu()) # move rng state back
|
||||
|
||||
if rank == 0:
|
||||
min_loss = state['min_loss']
|
||||
if args.adv and 'adv_model' not in state:
|
||||
min_loss = None # restarting with adversary wipes the record
|
||||
|
||||
print('state at epoch {} loaded from {}'.format(
|
||||
state['epoch'], args.load_state), flush=True)
|
||||
@ -223,35 +184,26 @@ def gpu_worker(local_rank, node, args):
|
||||
pprint(vars(args))
|
||||
sys.stdout.flush()
|
||||
|
||||
if args.adv:
|
||||
args.instance_noise = InstanceNoise(args.instance_noise,
|
||||
args.instance_noise_batches)
|
||||
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
train_sampler.set_epoch(epoch)
|
||||
|
||||
train_loss = train(epoch, train_loader,
|
||||
model, dis2den, criterion, optimizer, scheduler,
|
||||
adv_model, adv_criterion, adv_optimizer, adv_scheduler,
|
||||
logger, device, args)
|
||||
epoch_loss = train_loss
|
||||
|
||||
if args.val:
|
||||
val_loss = validate(epoch, val_loader,
|
||||
model, dis2den, criterion, adv_model, adv_criterion,
|
||||
val_loss = validate(epoch, val_loader, model, dis2den, criterion,
|
||||
logger, device, args)
|
||||
epoch_loss = val_loss
|
||||
|
||||
if args.reduce_lr_on_plateau and epoch >= args.adv_start:
|
||||
if args.reduce_lr_on_plateau:
|
||||
scheduler.step(epoch_loss[0])
|
||||
if args.adv:
|
||||
adv_scheduler.step(epoch_loss[0])
|
||||
|
||||
if rank == 0:
|
||||
logger.flush()
|
||||
|
||||
if ((min_loss is None or epoch_loss[0] < min_loss[0])
|
||||
and epoch >= args.adv_start):
|
||||
if min_loss is None or epoch_loss[0] < min_loss[0]:
|
||||
min_loss = epoch_loss
|
||||
|
||||
state = {
|
||||
@ -260,8 +212,6 @@ def gpu_worker(local_rank, node, args):
|
||||
'rng': torch.get_rng_state(),
|
||||
'min_loss': min_loss,
|
||||
}
|
||||
if args.adv:
|
||||
state['adv_model'] = adv_model.module.state_dict()
|
||||
|
||||
state_file = 'state_{}.pt'.format(epoch + 1)
|
||||
torch.save(state, state_file)
|
||||
@ -275,24 +225,13 @@ def gpu_worker(local_rank, node, args):
|
||||
|
||||
|
||||
def train(epoch, loader, model, dis2den, criterion, optimizer, scheduler,
|
||||
adv_model, adv_criterion, adv_optimizer, adv_scheduler,
|
||||
logger, device, args):
|
||||
model.train()
|
||||
if args.adv:
|
||||
adv_model.train()
|
||||
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
|
||||
# loss, loss_adv, adv_loss, adv_loss_fake, adv_loss_real
|
||||
# loss: generator (model) supervised loss
|
||||
# loss_adv: generator (model) adversarial loss
|
||||
# adv_loss: discriminator (adv_model) loss
|
||||
epoch_loss = torch.zeros(5, dtype=torch.float64, device=device)
|
||||
fake = torch.zeros([1], dtype=torch.float32, device=device)
|
||||
real = torch.ones([1], dtype=torch.float32, device=device)
|
||||
adv_real = torch.full([1], args.adv_label_smoothing, dtype=torch.float32,
|
||||
device=device)
|
||||
|
||||
for i, (input, target) in enumerate(loader):
|
||||
input = input.to(device, non_blocking=True)
|
||||
@ -314,45 +253,6 @@ def train(epoch, loader, model, dis2den, criterion, optimizer, scheduler,
|
||||
loss = criterion(output, target)
|
||||
epoch_loss[0] += loss.item()
|
||||
|
||||
if args.adv and epoch >= args.adv_start:
|
||||
noise_std = args.instance_noise.std()
|
||||
if noise_std > 0:
|
||||
noise = noise_std * torch.randn_like(output)
|
||||
output = output + noise.detach()
|
||||
noise = noise_std * torch.randn_like(target)
|
||||
target = target + noise.detach()
|
||||
del noise
|
||||
|
||||
if args.cgan:
|
||||
output = torch.cat([input, output], dim=1)
|
||||
target = torch.cat([input, target], dim=1)
|
||||
|
||||
# discriminator
|
||||
set_requires_grad(adv_model, True)
|
||||
|
||||
eval = adv_model([output.detach(), target])
|
||||
adv_loss_fake, adv_loss_real = adv_criterion(eval, [fake, adv_real])
|
||||
epoch_loss[3] += adv_loss_fake.item()
|
||||
epoch_loss[4] += adv_loss_real.item()
|
||||
adv_loss = 0.5 * (adv_loss_fake + adv_loss_real)
|
||||
epoch_loss[2] += adv_loss.item()
|
||||
|
||||
adv_optimizer.zero_grad()
|
||||
adv_loss.backward()
|
||||
adv_optimizer.step()
|
||||
|
||||
# generator adversarial loss
|
||||
set_requires_grad(adv_model, False)
|
||||
|
||||
eval_out = adv_model(output)
|
||||
loss_adv, = adv_criterion(eval_out, real)
|
||||
epoch_loss[1] += loss_adv.item()
|
||||
|
||||
ratio = loss.item() / (loss_adv.item() + 1e-8)
|
||||
frac = args.loss_fraction
|
||||
if epoch >= args.adv_start:
|
||||
loss = frac * loss + (1 - frac) * ratio * loss_adv
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
@ -364,14 +264,6 @@ def train(epoch, loader, model, dis2den, criterion, optimizer, scheduler,
|
||||
if rank == 0:
|
||||
logger.add_scalar('loss/batch/train', loss.item(),
|
||||
global_step=batch)
|
||||
if args.adv and epoch >= args.adv_start:
|
||||
logger.add_scalar('loss/batch/train/adv/G',
|
||||
loss_adv.item(), global_step=batch)
|
||||
logger.add_scalars('loss/batch/train/adv/D', {
|
||||
'total': adv_loss.item(),
|
||||
'fake': adv_loss_fake.item(),
|
||||
'real': adv_loss_real.item(),
|
||||
}, global_step=batch)
|
||||
|
||||
# gradients of the weights of the first and the last layer
|
||||
grads = list(p.grad for n, p in model.named_parameters()
|
||||
@ -380,60 +272,28 @@ def train(epoch, loader, model, dis2den, criterion, optimizer, scheduler,
|
||||
grads = [g.detach().norm().item() for g in grads]
|
||||
logger.add_scalar('grad/first', grads[0], global_step=batch)
|
||||
logger.add_scalar('grad/last', grads[-1], global_step=batch)
|
||||
if args.adv and epoch >= args.adv_start:
|
||||
grads = list(p.grad for n, p in adv_model.named_parameters()
|
||||
if '.weight' in n)
|
||||
grads = [grads[0], grads[-1]]
|
||||
grads = [g.detach().norm().item() for g in grads]
|
||||
logger.add_scalars('grad/adv/first', grads[0],
|
||||
global_step=batch)
|
||||
logger.add_scalars('grad/adv/last', grads[-1],
|
||||
global_step=batch)
|
||||
|
||||
if args.adv and epoch >= args.adv_start and noise_std > 0:
|
||||
logger.add_scalar('instance_noise', noise_std,
|
||||
global_step=batch)
|
||||
|
||||
dist.all_reduce(epoch_loss)
|
||||
epoch_loss /= len(loader) * world_size
|
||||
if rank == 0:
|
||||
logger.add_scalar('loss/epoch/train', epoch_loss[0],
|
||||
global_step=epoch+1)
|
||||
if args.adv and epoch >= args.adv_start:
|
||||
logger.add_scalar('loss/epoch/train/adv/G', epoch_loss[1],
|
||||
global_step=epoch+1)
|
||||
logger.add_scalars('loss/epoch/train/adv/D', {
|
||||
'total': epoch_loss[2],
|
||||
'fake': epoch_loss[3],
|
||||
'real': epoch_loss[4],
|
||||
}, global_step=epoch+1)
|
||||
|
||||
skip_chan = 0
|
||||
if args.adv and epoch >= args.adv_start and args.cgan:
|
||||
skip_chan = sum(args.in_chan)
|
||||
logger.add_figure('fig/epoch/train', plt_slices(
|
||||
input[-1],
|
||||
output[-1, skip_chan:],
|
||||
target[-1, skip_chan:],
|
||||
output[-1, skip_chan:] - target[-1, skip_chan:],
|
||||
input[-1], output[-1], target[-1], output[-1] - target[-1],
|
||||
title=['in', 'out', 'tgt', 'out - tgt'],
|
||||
), global_step=epoch+1)
|
||||
|
||||
return epoch_loss
|
||||
|
||||
|
||||
def validate(epoch, loader, model, dis2den, criterion, adv_model, adv_criterion,
|
||||
logger, device, args):
|
||||
def validate(epoch, loader, model, dis2den, criterion, logger, device, args):
|
||||
model.eval()
|
||||
if args.adv:
|
||||
adv_model.eval()
|
||||
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
|
||||
epoch_loss = torch.zeros(5, dtype=torch.float64, device=device)
|
||||
fake = torch.zeros([1], dtype=torch.float32, device=device)
|
||||
real = torch.ones([1], dtype=torch.float32, device=device)
|
||||
|
||||
with torch.no_grad():
|
||||
for input, target in loader:
|
||||
@ -452,46 +312,14 @@ def validate(epoch, loader, model, dis2den, criterion, adv_model, adv_criterion,
|
||||
loss = criterion(output, target)
|
||||
epoch_loss[0] += loss.item()
|
||||
|
||||
if args.adv and epoch >= args.adv_start:
|
||||
if args.cgan:
|
||||
output = torch.cat([input, output], dim=1)
|
||||
target = torch.cat([input, target], dim=1)
|
||||
|
||||
# discriminator
|
||||
eval = adv_model([output, target])
|
||||
adv_loss_fake, adv_loss_real = adv_criterion(eval, [fake, real])
|
||||
epoch_loss[3] += adv_loss_fake.item()
|
||||
epoch_loss[4] += adv_loss_real.item()
|
||||
adv_loss = 0.5 * (adv_loss_fake + adv_loss_real)
|
||||
epoch_loss[2] += adv_loss.item()
|
||||
|
||||
# generator adversarial loss
|
||||
eval_out, _ = adv_criterion.split_input(eval, [fake, real])
|
||||
loss_adv, = adv_criterion(eval_out, real)
|
||||
epoch_loss[1] += loss_adv.item()
|
||||
|
||||
dist.all_reduce(epoch_loss)
|
||||
epoch_loss /= len(loader) * world_size
|
||||
if rank == 0:
|
||||
logger.add_scalar('loss/epoch/val', epoch_loss[0],
|
||||
global_step=epoch+1)
|
||||
if args.adv and epoch >= args.adv_start:
|
||||
logger.add_scalar('loss/epoch/val/adv/G', epoch_loss[1],
|
||||
global_step=epoch+1)
|
||||
logger.add_scalars('loss/epoch/val/adv/D', {
|
||||
'total': epoch_loss[2],
|
||||
'fake': epoch_loss[3],
|
||||
'real': epoch_loss[4],
|
||||
}, global_step=epoch+1)
|
||||
|
||||
skip_chan = 0
|
||||
if args.adv and epoch >= args.adv_start and args.cgan:
|
||||
skip_chan = sum(args.in_chan)
|
||||
logger.add_figure('fig/epoch/val', plt_slices(
|
||||
input[-1],
|
||||
output[-1, skip_chan:],
|
||||
target[-1, skip_chan:],
|
||||
output[-1, skip_chan:] - target[-1, skip_chan:],
|
||||
input[-1], output[-1], target[-1], output[-1] - target[-1],
|
||||
title=['in', 'out', 'tgt', 'out - tgt'],
|
||||
), global_step=epoch+1)
|
||||
|
||||
|
@ -37,8 +37,8 @@ srun m2m.py train \
|
||||
--val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files" \
|
||||
--val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \
|
||||
--in-norms cosmology.dis --tgt-norms cosmology.dis --augment --crop 128 --pad 20 \
|
||||
--model VNet --adv-model UNet --cgan \
|
||||
--lr 0.0001 --adv-lr 0.0004 --batches 1 --loader-workers 0 \
|
||||
--model VNet \
|
||||
--lr 0.0001 --batches 1 --loader-workers 0 \
|
||||
--epochs 1024 --seed $RANDOM
|
||||
|
||||
|
||||
|
@ -37,8 +37,8 @@ srun m2m.py train \
|
||||
--train-in-patterns "$data_root_dir/$in_dir/$train_dirs/$in_files_1,$data_root_dir/$in_dir/$train_dirs/$in_files_2" \
|
||||
--train-tgt-patterns "$data_root_dir/$tgt_dir/$train_dirs/$tgt_files_1,$data_root_dir/$tgt_dir/$train_dirs/$tgt_files_2" \
|
||||
--in-norms cosmology.dis,cosmology.vel --tgt-norms cosmology.dis,cosmology.vel --augment --crop 88 --pad 20 --scale-factor 2 \
|
||||
--model VNet --adv-model PatchGAN --cgan \
|
||||
--lr 0.0001 --adv-lr 0.0004 --batches 1 --loader-workers 0 \
|
||||
--model VNet \
|
||||
--lr 0.0001 --batches 1 --loader-workers 0 \
|
||||
--epochs 1024 --seed $RANDOM
|
||||
|
||||
|
||||
|
@ -37,8 +37,8 @@ srun m2m.py train \
|
||||
--val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files" \
|
||||
--val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \
|
||||
--in-norms cosmology.vel --tgt-norms cosmology.vel --augment --crop 128 --pad 20 \
|
||||
--model VNet --adv-model UNet --cgan \
|
||||
--lr 0.0001 --adv-lr 0.0004 --batches 1 --loader-workers 0 \
|
||||
--model VNet \
|
||||
--lr 0.0001 --batches 1 --loader-workers 0 \
|
||||
--epochs 1024 --seed $RANDOM
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user