diff --git a/map2map/args.py b/map2map/args.py index df4c242..88d2d48 100644 --- a/map2map/args.py +++ b/map2map/args.py @@ -106,26 +106,6 @@ def add_train_args(parser): help='multiplicative data augmentation, (log-normal) std, ' 'same factor for all fields') - parser.add_argument('--adv-model', type=str, - help='enable adversary model from .models') - parser.add_argument('--adv-model-spectral-norm', action='store_true', - help='enable spectral normalization on the adversary model') - parser.add_argument('--adv-criterion', default='BCEWithLogitsLoss', type=str, - help='adversarial criterion from torch.nn') - parser.add_argument('--cgan', action='store_true', - help='enable conditional GAN') - parser.add_argument('--adv-start', default=0, type=int, - help='epoch to start adversarial training') - parser.add_argument('--adv-label-smoothing', default=1, type=float, - help='label of real samples for the adversary model, ' - 'e.g. 0.9 for label smoothing and 1 to disable') - parser.add_argument('--loss-fraction', default=0.5, type=float, - help='final fraction of loss (vs adv-loss)') - parser.add_argument('--instance-noise', default=0, type=float, - help='noise added to the adversary inputs to stabilize training') - parser.add_argument('--instance-noise-batches', default=1e4, type=float, - help='noise annealing duration') - parser.add_argument('--optimizer', default='Adam', type=str, help='optimizer from torch.optim') parser.add_argument('--lr', type=float, required=True, @@ -134,10 +114,6 @@ def add_train_args(parser): # help='momentum') parser.add_argument('--weight-decay', default=0, type=float, help='weight decay') - parser.add_argument('--adv-lr', type=float, - help='initial adversary learning rate') - parser.add_argument('--adv-weight-decay', type=float, - help='adversary weight decay') parser.add_argument('--reduce-lr-on-plateau', action='store_true', help='Enable ReduceLROnPlateau learning rate scheduler') parser.add_argument('--init-weight-std', type=float, @@ -187,19 +163,6 @@ def set_train_args(args): args.val = args.val_in_patterns is not None and \ args.val_tgt_patterns is not None - args.adv = args.adv_model is not None - - if args.adv: - if args.adv_lr is None: - args.adv_lr = args.lr - if args.adv_weight_decay is None: - args.adv_weight_decay = args.weight_decay - - if args.cgan and not args.adv: - args.cgan =False - warnings.warn('Disabling cgan given adversary is disabled', - RuntimeWarning) - def set_test_args(args): set_common_args(args) diff --git a/map2map/train.py b/map2map/train.py index 56d028e..3b17ce1 100644 --- a/map2map/train.py +++ b/map2map/train.py @@ -17,10 +17,7 @@ from torch.utils.tensorboard import SummaryWriter from .data import FieldDataset from .data.figures import plt_slices from . import models -from .models import (narrow_like, resample, Lag2Eul, - adv_model_wrapper, adv_criterion_wrapper, - add_spectral_norm, rm_spectral_norm, - InstanceNoise) +from .models import narrow_cast, resample, Lag2Eul from .utils import import_attr, load_model_state_dict @@ -138,33 +135,6 @@ def gpu_worker(local_rank, node, args): scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=10, verbose=True) - adv_model = adv_criterion = adv_optimizer = adv_scheduler = None - if args.adv: - adv_model = import_attr(args.adv_model, models.__name__, args.callback_at) - adv_model = adv_model_wrapper(adv_model) - adv_model = adv_model(sum(args.in_chan + args.out_chan) - if args.cgan else sum(args.out_chan), 1) - if args.adv_model_spectral_norm: - add_spectral_norm(adv_model) - adv_model.to(device) - adv_model = DistributedDataParallel(adv_model, device_ids=[device], - process_group=dist.new_group()) - - adv_criterion = import_attr(args.adv_criterion, nn.__name__, args.callback_at) - adv_criterion = adv_criterion_wrapper(adv_criterion) - adv_criterion = adv_criterion() - adv_criterion.to(device) - - adv_optimizer = import_attr(args.optimizer, optim.__name__, args.callback_at) - adv_optimizer = adv_optimizer( - adv_model.parameters(), - lr=args.adv_lr, - betas=(0.5, 0.999), - weight_decay=args.adv_weight_decay, - ) - adv_scheduler = optim.lr_scheduler.ReduceLROnPlateau(adv_optimizer, - factor=0.1, patience=10, verbose=True) - if (args.load_state == ckpt_link and not os.path.isfile(ckpt_link) or not args.load_state): def init_weights(m): @@ -182,9 +152,6 @@ def gpu_worker(local_rank, node, args): if args.init_weight_std is not None: model.apply(init_weights) - if args.adv: - adv_model.apply(init_weights) - start_epoch = 0 if rank == 0: @@ -197,16 +164,10 @@ def gpu_worker(local_rank, node, args): load_model_state_dict(model.module, state['model'], strict=args.load_state_strict) - if args.adv and 'adv_model' in state: - load_model_state_dict(adv_model.module, state['adv_model'], - strict=args.load_state_strict) - torch.set_rng_state(state['rng'].cpu()) # move rng state back if rank == 0: min_loss = state['min_loss'] - if args.adv and 'adv_model' not in state: - min_loss = None # restarting with adversary wipes the record print('state at epoch {} loaded from {}'.format( state['epoch'], args.load_state), flush=True) @@ -223,35 +184,26 @@ def gpu_worker(local_rank, node, args): pprint(vars(args)) sys.stdout.flush() - if args.adv: - args.instance_noise = InstanceNoise(args.instance_noise, - args.instance_noise_batches) - for epoch in range(start_epoch, args.epochs): train_sampler.set_epoch(epoch) train_loss = train(epoch, train_loader, model, dis2den, criterion, optimizer, scheduler, - adv_model, adv_criterion, adv_optimizer, adv_scheduler, logger, device, args) epoch_loss = train_loss if args.val: - val_loss = validate(epoch, val_loader, - model, dis2den, criterion, adv_model, adv_criterion, + val_loss = validate(epoch, val_loader, model, dis2den, criterion, logger, device, args) epoch_loss = val_loss - if args.reduce_lr_on_plateau and epoch >= args.adv_start: + if args.reduce_lr_on_plateau: scheduler.step(epoch_loss[0]) - if args.adv: - adv_scheduler.step(epoch_loss[0]) if rank == 0: logger.flush() - if ((min_loss is None or epoch_loss[0] < min_loss[0]) - and epoch >= args.adv_start): + if min_loss is None or epoch_loss[0] < min_loss[0]: min_loss = epoch_loss state = { @@ -260,8 +212,6 @@ def gpu_worker(local_rank, node, args): 'rng': torch.get_rng_state(), 'min_loss': min_loss, } - if args.adv: - state['adv_model'] = adv_model.module.state_dict() state_file = 'state_{}.pt'.format(epoch + 1) torch.save(state, state_file) @@ -275,24 +225,13 @@ def gpu_worker(local_rank, node, args): def train(epoch, loader, model, dis2den, criterion, optimizer, scheduler, - adv_model, adv_criterion, adv_optimizer, adv_scheduler, logger, device, args): model.train() - if args.adv: - adv_model.train() rank = dist.get_rank() world_size = dist.get_world_size() - # loss, loss_adv, adv_loss, adv_loss_fake, adv_loss_real - # loss: generator (model) supervised loss - # loss_adv: generator (model) adversarial loss - # adv_loss: discriminator (adv_model) loss epoch_loss = torch.zeros(5, dtype=torch.float64, device=device) - fake = torch.zeros([1], dtype=torch.float32, device=device) - real = torch.ones([1], dtype=torch.float32, device=device) - adv_real = torch.full([1], args.adv_label_smoothing, dtype=torch.float32, - device=device) for i, (input, target) in enumerate(loader): input = input.to(device, non_blocking=True) @@ -314,45 +253,6 @@ def train(epoch, loader, model, dis2den, criterion, optimizer, scheduler, loss = criterion(output, target) epoch_loss[0] += loss.item() - if args.adv and epoch >= args.adv_start: - noise_std = args.instance_noise.std() - if noise_std > 0: - noise = noise_std * torch.randn_like(output) - output = output + noise.detach() - noise = noise_std * torch.randn_like(target) - target = target + noise.detach() - del noise - - if args.cgan: - output = torch.cat([input, output], dim=1) - target = torch.cat([input, target], dim=1) - - # discriminator - set_requires_grad(adv_model, True) - - eval = adv_model([output.detach(), target]) - adv_loss_fake, adv_loss_real = adv_criterion(eval, [fake, adv_real]) - epoch_loss[3] += adv_loss_fake.item() - epoch_loss[4] += adv_loss_real.item() - adv_loss = 0.5 * (adv_loss_fake + adv_loss_real) - epoch_loss[2] += adv_loss.item() - - adv_optimizer.zero_grad() - adv_loss.backward() - adv_optimizer.step() - - # generator adversarial loss - set_requires_grad(adv_model, False) - - eval_out = adv_model(output) - loss_adv, = adv_criterion(eval_out, real) - epoch_loss[1] += loss_adv.item() - - ratio = loss.item() / (loss_adv.item() + 1e-8) - frac = args.loss_fraction - if epoch >= args.adv_start: - loss = frac * loss + (1 - frac) * ratio * loss_adv - optimizer.zero_grad() loss.backward() optimizer.step() @@ -364,14 +264,6 @@ def train(epoch, loader, model, dis2den, criterion, optimizer, scheduler, if rank == 0: logger.add_scalar('loss/batch/train', loss.item(), global_step=batch) - if args.adv and epoch >= args.adv_start: - logger.add_scalar('loss/batch/train/adv/G', - loss_adv.item(), global_step=batch) - logger.add_scalars('loss/batch/train/adv/D', { - 'total': adv_loss.item(), - 'fake': adv_loss_fake.item(), - 'real': adv_loss_real.item(), - }, global_step=batch) # gradients of the weights of the first and the last layer grads = list(p.grad for n, p in model.named_parameters() @@ -380,60 +272,28 @@ def train(epoch, loader, model, dis2den, criterion, optimizer, scheduler, grads = [g.detach().norm().item() for g in grads] logger.add_scalar('grad/first', grads[0], global_step=batch) logger.add_scalar('grad/last', grads[-1], global_step=batch) - if args.adv and epoch >= args.adv_start: - grads = list(p.grad for n, p in adv_model.named_parameters() - if '.weight' in n) - grads = [grads[0], grads[-1]] - grads = [g.detach().norm().item() for g in grads] - logger.add_scalars('grad/adv/first', grads[0], - global_step=batch) - logger.add_scalars('grad/adv/last', grads[-1], - global_step=batch) - - if args.adv and epoch >= args.adv_start and noise_std > 0: - logger.add_scalar('instance_noise', noise_std, - global_step=batch) dist.all_reduce(epoch_loss) epoch_loss /= len(loader) * world_size if rank == 0: logger.add_scalar('loss/epoch/train', epoch_loss[0], global_step=epoch+1) - if args.adv and epoch >= args.adv_start: - logger.add_scalar('loss/epoch/train/adv/G', epoch_loss[1], - global_step=epoch+1) - logger.add_scalars('loss/epoch/train/adv/D', { - 'total': epoch_loss[2], - 'fake': epoch_loss[3], - 'real': epoch_loss[4], - }, global_step=epoch+1) - skip_chan = 0 - if args.adv and epoch >= args.adv_start and args.cgan: - skip_chan = sum(args.in_chan) logger.add_figure('fig/epoch/train', plt_slices( - input[-1], - output[-1, skip_chan:], - target[-1, skip_chan:], - output[-1, skip_chan:] - target[-1, skip_chan:], + input[-1], output[-1], target[-1], output[-1] - target[-1], title=['in', 'out', 'tgt', 'out - tgt'], ), global_step=epoch+1) return epoch_loss -def validate(epoch, loader, model, dis2den, criterion, adv_model, adv_criterion, - logger, device, args): +def validate(epoch, loader, model, dis2den, criterion, logger, device, args): model.eval() - if args.adv: - adv_model.eval() rank = dist.get_rank() world_size = dist.get_world_size() epoch_loss = torch.zeros(5, dtype=torch.float64, device=device) - fake = torch.zeros([1], dtype=torch.float32, device=device) - real = torch.ones([1], dtype=torch.float32, device=device) with torch.no_grad(): for input, target in loader: @@ -452,46 +312,14 @@ def validate(epoch, loader, model, dis2den, criterion, adv_model, adv_criterion, loss = criterion(output, target) epoch_loss[0] += loss.item() - if args.adv and epoch >= args.adv_start: - if args.cgan: - output = torch.cat([input, output], dim=1) - target = torch.cat([input, target], dim=1) - - # discriminator - eval = adv_model([output, target]) - adv_loss_fake, adv_loss_real = adv_criterion(eval, [fake, real]) - epoch_loss[3] += adv_loss_fake.item() - epoch_loss[4] += adv_loss_real.item() - adv_loss = 0.5 * (adv_loss_fake + adv_loss_real) - epoch_loss[2] += adv_loss.item() - - # generator adversarial loss - eval_out, _ = adv_criterion.split_input(eval, [fake, real]) - loss_adv, = adv_criterion(eval_out, real) - epoch_loss[1] += loss_adv.item() - dist.all_reduce(epoch_loss) epoch_loss /= len(loader) * world_size if rank == 0: logger.add_scalar('loss/epoch/val', epoch_loss[0], global_step=epoch+1) - if args.adv and epoch >= args.adv_start: - logger.add_scalar('loss/epoch/val/adv/G', epoch_loss[1], - global_step=epoch+1) - logger.add_scalars('loss/epoch/val/adv/D', { - 'total': epoch_loss[2], - 'fake': epoch_loss[3], - 'real': epoch_loss[4], - }, global_step=epoch+1) - skip_chan = 0 - if args.adv and epoch >= args.adv_start and args.cgan: - skip_chan = sum(args.in_chan) logger.add_figure('fig/epoch/val', plt_slices( - input[-1], - output[-1, skip_chan:], - target[-1, skip_chan:], - output[-1, skip_chan:] - target[-1, skip_chan:], + input[-1], output[-1], target[-1], output[-1] - target[-1], title=['in', 'out', 'tgt', 'out - tgt'], ), global_step=epoch+1) diff --git a/scripts/dis2dis.slurm b/scripts/dis2dis.slurm index ed39a15..cf5def1 100644 --- a/scripts/dis2dis.slurm +++ b/scripts/dis2dis.slurm @@ -37,8 +37,8 @@ srun m2m.py train \ --val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files" \ --val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \ --in-norms cosmology.dis --tgt-norms cosmology.dis --augment --crop 128 --pad 20 \ - --model VNet --adv-model UNet --cgan \ - --lr 0.0001 --adv-lr 0.0004 --batches 1 --loader-workers 0 \ + --model VNet \ + --lr 0.0001 --batches 1 --loader-workers 0 \ --epochs 1024 --seed $RANDOM diff --git a/scripts/srsgan.slurm b/scripts/srsgan.slurm index 5de73cc..6480f44 100644 --- a/scripts/srsgan.slurm +++ b/scripts/srsgan.slurm @@ -37,8 +37,8 @@ srun m2m.py train \ --train-in-patterns "$data_root_dir/$in_dir/$train_dirs/$in_files_1,$data_root_dir/$in_dir/$train_dirs/$in_files_2" \ --train-tgt-patterns "$data_root_dir/$tgt_dir/$train_dirs/$tgt_files_1,$data_root_dir/$tgt_dir/$train_dirs/$tgt_files_2" \ --in-norms cosmology.dis,cosmology.vel --tgt-norms cosmology.dis,cosmology.vel --augment --crop 88 --pad 20 --scale-factor 2 \ - --model VNet --adv-model PatchGAN --cgan \ - --lr 0.0001 --adv-lr 0.0004 --batches 1 --loader-workers 0 \ + --model VNet \ + --lr 0.0001 --batches 1 --loader-workers 0 \ --epochs 1024 --seed $RANDOM diff --git a/scripts/vel2vel.slurm b/scripts/vel2vel.slurm index ac046fd..14e96a0 100644 --- a/scripts/vel2vel.slurm +++ b/scripts/vel2vel.slurm @@ -37,8 +37,8 @@ srun m2m.py train \ --val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files" \ --val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \ --in-norms cosmology.vel --tgt-norms cosmology.vel --augment --crop 128 --pad 20 \ - --model VNet --adv-model UNet --cgan \ - --lr 0.0001 --adv-lr 0.0004 --batches 1 --loader-workers 0 \ + --model VNet \ + --lr 0.0001 --batches 1 --loader-workers 0 \ --epochs 1024 --seed $RANDOM