From 2d5234812bbcd868abe1d15c0f77752bceffcf36 Mon Sep 17 00:00:00 2001
From: Yin Li
Date: Wed, 15 Jul 2020 02:26:16 -0400
Subject: [PATCH] Add lagrangian and eulerian alternate training

---
 map2map/train.py | 138 +++++++++++++++++++++++++++++++++--------------
 1 file changed, 99 insertions(+), 39 deletions(-)

diff --git a/map2map/train.py b/map2map/train.py
index 3b17ce1..f975fbe 100644
--- a/map2map/train.py
+++ b/map2map/train.py
@@ -118,21 +118,29 @@ def gpu_worker(local_rank, node, args):
         model = DistributedDataParallel(model, device_ids=[device],
                 process_group=dist.new_group())
 
-    dis2den = Lag2Eul()
+    lag2eul = Lag2Eul()
 
     criterion = import_attr(args.criterion, nn.__name__, args.callback_at)
     criterion = criterion()
     criterion.to(device)
 
     optimizer = import_attr(args.optimizer, optim.__name__, args.callback_at)
-    optimizer = optimizer(
+    lag_optimizer = optimizer(
         model.parameters(),
         lr=args.lr,
         #momentum=args.momentum,
-        betas=(0.5, 0.999),
+        betas=(0.9, 0.999),
         weight_decay=args.weight_decay,
     )
-    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
+    eul_optimizer = optimizer(
+        model.parameters(),
+        lr=args.lr,
+        betas=(0.9, 0.999),
+        weight_decay=args.weight_decay,
+    )
+    lag_scheduler = optim.lr_scheduler.ReduceLROnPlateau(lag_optimizer,
+        factor=0.1, patience=10, verbose=True)
+    eul_scheduler = optim.lr_scheduler.ReduceLROnPlateau(eul_optimizer,
         factor=0.1, patience=10, verbose=True)
 
     if (args.load_state == ckpt_link and not os.path.isfile(ckpt_link)
@@ -187,23 +195,24 @@ def gpu_worker(local_rank, node, args):
     for epoch in range(start_epoch, args.epochs):
         train_sampler.set_epoch(epoch)
 
-        train_loss = train(epoch, train_loader,
-            model, dis2den, criterion, optimizer, scheduler,
+        train_loss = train(epoch, train_loader, model, lag2eul, criterion,
+            lag_optimizer, eul_optimizer, lag_scheduler, eul_scheduler,
             logger, device, args)
         epoch_loss = train_loss
 
         if args.val:
-            val_loss = validate(epoch, val_loader, model, dis2den, criterion,
+            val_loss = validate(epoch, val_loader, model, lag2eul, criterion,
                 logger, device, args)
             epoch_loss = val_loss
 
         if args.reduce_lr_on_plateau:
-            scheduler.step(epoch_loss[0])
+            lag_scheduler.step(epoch_loss[0])
+            eul_scheduler.step(epoch_loss[1])
 
         if rank == 0:
             logger.flush()
 
-        if min_loss is None or epoch_loss[0] < min_loss[0]:
+        if min_loss is None or torch.prod(epoch_loss) < torch.prod(min_loss):
            min_loss = epoch_loss
 
         state = {
@@ -224,14 +233,15 @@ def gpu_worker(local_rank, node, args):
     dist.destroy_process_group()
 
 
-def train(epoch, loader, model, dis2den, criterion, optimizer, scheduler,
+def train(epoch, loader, model, lag2eul, criterion,
+        lag_optimizer, eul_optimizer, lag_scheduler, eul_scheduler,
         logger, device, args):
     model.train()
 
     rank = dist.get_rank()
     world_size = dist.get_world_size()
 
-    epoch_loss = torch.zeros(5, dtype=torch.float64, device=device)
+    epoch_loss = torch.zeros(2, dtype=torch.float64, device=device)
 
     for i, (input, target) in enumerate(loader):
         input = input.to(device, non_blocking=True)
@@ -248,52 +258,83 @@ def train(epoch, loader, model, dis2den, criterion, optimizer, scheduler,
             input = resample(input, model.module.scale_factor, narrow=False)
         input, output, target = narrow_cast(input, output, target)
 
-        output, target = dis2den(output, target)
+        lag_out, lag_tgt = output, target
 
-        loss = criterion(output, target)
-        epoch_loss[0] += loss.item()
+        if i % 2 == 0:
+            lag_loss = criterion(lag_out, lag_tgt)
+            epoch_loss[0] += lag_loss.item()
 
-        optimizer.zero_grad()
-        loss.backward()
-        optimizer.step()
+            with torch.no_grad():
+                eul_out, eul_tgt = lag2eul(lag_out, lag_tgt)
+
+                eul_loss = criterion(eul_out, eul_tgt)
+                epoch_loss[1] += eul_loss.item()
+
+            lag_optimizer.zero_grad()
+            lag_loss.backward()
+            lag_optimizer.step()
+            lag_grads = get_grads(model)
+        else:
+            with torch.no_grad():
+                lag_loss = criterion(lag_out, lag_tgt)
+                epoch_loss[0] += lag_loss.item()
+
+            eul_out, eul_tgt = lag2eul(lag_out, lag_tgt)
+
+            eul_loss = criterion(eul_out, eul_tgt)
+            epoch_loss[1] += eul_loss.item()
+
+            eul_optimizer.zero_grad()
+            eul_loss.backward()
+            eul_optimizer.step()
+            eul_grads = get_grads(model)
 
         batch = epoch * len(loader) + i + 1
-        if batch % args.log_interval == 0:
-            dist.all_reduce(loss)
-            loss /= world_size
+        if batch % args.log_interval == 0 and batch >= 2:
+            dist.all_reduce(lag_loss)
+            dist.all_reduce(eul_loss)
+            lag_loss /= world_size
+            eul_loss /= world_size
             if rank == 0:
-                logger.add_scalar('loss/batch/train', loss.item(),
+                logger.add_scalar('loss/batch/train/lag', lag_loss.item(),
+                    global_step=batch)
+                logger.add_scalar('loss/batch/train/eul', eul_loss.item(),
                     global_step=batch)
-                # gradients of the weights of the first and the last layer
-                grads = list(p.grad for n, p in model.named_parameters()
-                    if '.weight' in n)
-                grads = [grads[0], grads[-1]]
-                grads = [g.detach().norm().item() for g in grads]
-                logger.add_scalar('grad/first', grads[0], global_step=batch)
-                logger.add_scalar('grad/last', grads[-1], global_step=batch)
+                logger.add_scalar('grad/lag/first', lag_grads[0],
+                    global_step=batch)
+                logger.add_scalar('grad/lag/last', lag_grads[-1],
+                    global_step=batch)
+                logger.add_scalar('grad/eul/first', eul_grads[0],
+                    global_step=batch)
+                logger.add_scalar('grad/eul/last', eul_grads[-1],
+                    global_step=batch)
 
     dist.all_reduce(epoch_loss)
     epoch_loss /= len(loader) * world_size
 
     if rank == 0:
-        logger.add_scalar('loss/epoch/train', epoch_loss[0],
+        logger.add_scalar('loss/epoch/train/lag', epoch_loss[0],
+            global_step=epoch+1)
+        logger.add_scalar('loss/epoch/train/eul', epoch_loss[1],
            global_step=epoch+1)
         logger.add_figure('fig/epoch/train', plt_slices(
-            input[-1], output[-1], target[-1], output[-1] - target[-1],
-            title=['in', 'out', 'tgt', 'out - tgt'],
+            input[-1], lag_out[-1], lag_tgt[-1], lag_out[-1] - lag_tgt[-1],
+            eul_out[-1], eul_tgt[-1], eul_out[-1] - eul_tgt[-1],
+            title=['in', 'lag_out', 'lag_tgt', 'lag_out - lag_tgt',
+                'eul_out', 'eul_tgt', 'eul_out - eul_tgt'],
             ), global_step=epoch+1)
 
     return epoch_loss
 
 
-def validate(epoch, loader, model, dis2den, criterion, logger, device, args):
+def validate(epoch, loader, model, lag2eul, criterion, logger, device, args):
     model.eval()
 
     rank = dist.get_rank()
     world_size = dist.get_world_size()
 
-    epoch_loss = torch.zeros(5, dtype=torch.float64, device=device)
+    epoch_loss = torch.zeros(2, dtype=torch.float64, device=device)
 
     with torch.no_grad():
         for input, target in loader:
             input = input.to(device, non_blocking=True)
@@ -307,20 +348,29 @@ def validate(epoch, loader, model, dis2den, criterion, logger, device, args):
                input = resample(input, model.module.scale_factor, narrow=False)
             input, output, target = narrow_cast(input, output, target)
 
-            output, target = dis2den(output, target)
+            lag_out, lag_tgt = output, target
 
-            loss = criterion(output, target)
-            epoch_loss[0] += loss.item()
+            lag_loss = criterion(lag_out, lag_tgt)
+            epoch_loss[0] += lag_loss.item()
+
+            eul_out, eul_tgt = lag2eul(lag_out, lag_tgt)
+
+            eul_loss = criterion(eul_out, eul_tgt)
+            epoch_loss[1] += eul_loss.item()
 
     dist.all_reduce(epoch_loss)
     epoch_loss /= len(loader) * world_size
 
     if rank == 0:
-        logger.add_scalar('loss/epoch/val', epoch_loss[0],
+        logger.add_scalar('loss/epoch/val/lag', epoch_loss[0],
+            global_step=epoch+1)
+        logger.add_scalar('loss/epoch/val/eul', epoch_loss[1],
            global_step=epoch+1)
         logger.add_figure('fig/epoch/val', plt_slices(
-            input[-1], output[-1], target[-1], output[-1] - target[-1],
-            title=['in', 'out', 'tgt', 'out - tgt'],
+            input[-1], lag_out[-1], lag_tgt[-1], lag_out[-1] - lag_tgt[-1],
+            eul_out[-1], eul_tgt[-1], eul_out[-1] - eul_tgt[-1],
+            title=['in', 'lag_out', 'lag_tgt', 'lag_out - lag_tgt',
+                'eul_out', 'eul_tgt', 'eul_out - eul_tgt'],
             ), global_step=epoch+1)
 
     return epoch_loss
@@ -363,3 +413,13 @@ def dist_init(rank, args):
 def set_requires_grad(module, requires_grad=False):
     for param in module.parameters():
         param.requires_grad = requires_grad
+
+
+def get_grads(model):
+    """gradients of the weights of the first and the last layer
+    """
+    grads = list(p.grad for n, p in model.named_parameters()
+            if '.weight' in n)
+    grads = [grads[0], grads[-1]]
+    grads = [g.detach().norm().item() for g in grads]
+    return grads
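The core of the patch, stripped of the DistributedDataParallel and TensorBoard plumbing, is an alternating two-optimizer scheme: even-indexed batches back-propagate the Lagrangian (displacement) loss through lag_optimizer, odd-indexed batches back-propagate the Eulerian (density, via Lag2Eul, formerly dis2den) loss through eul_optimizer, and whichever loss is not being optimized is still evaluated under torch.no_grad() so both entries of epoch_loss stay populated for logging and for the two ReduceLROnPlateau schedulers. Below is a minimal sketch of that loop in plain PyTorch; toy_lag2eul, the linear model, the random loader, and the learning rate are placeholders invented for illustration, standing in for the repository's Lag2Eul module and data pipeline rather than reproducing them.

import torch
from torch import nn, optim

def toy_lag2eul(out, tgt):
    # Placeholder for map2map's Lag2Eul: any differentiable map works for the
    # sketch; the real module converts displacement-style fields to
    # density-style fields (hence the old variable name dis2den).
    return out.exp(), tgt.exp()

model = nn.Linear(8, 8)
criterion = nn.MSELoss()
lag_optimizer = optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999))
eul_optimizer = optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999))

loader = [(torch.randn(4, 8), torch.randn(4, 8)) for _ in range(6)]
epoch_loss = torch.zeros(2, dtype=torch.float64)

for i, (input, target) in enumerate(loader):
    output = model(input)
    lag_out, lag_tgt = output, target

    if i % 2 == 0:
        # Even batch: optimize the Lagrangian loss, monitor the Eulerian one.
        lag_loss = criterion(lag_out, lag_tgt)
        epoch_loss[0] += lag_loss.item()

        with torch.no_grad():
            eul_out, eul_tgt = toy_lag2eul(lag_out, lag_tgt)
            eul_loss = criterion(eul_out, eul_tgt)
            epoch_loss[1] += eul_loss.item()

        lag_optimizer.zero_grad()
        lag_loss.backward()
        lag_optimizer.step()
    else:
        # Odd batch: optimize the Eulerian loss, monitor the Lagrangian one.
        with torch.no_grad():
            lag_loss = criterion(lag_out, lag_tgt)
            epoch_loss[0] += lag_loss.item()

        eul_out, eul_tgt = toy_lag2eul(lag_out, lag_tgt)
        eul_loss = criterion(eul_out, eul_tgt)
        epoch_loss[1] += eul_loss.item()

        eul_optimizer.zero_grad()
        eul_loss.backward()
        eul_optimizer.step()

epoch_loss /= len(loader)
print(epoch_loss)  # [mean Lagrangian loss, mean Eulerian loss], as logged per epoch

Two smaller behavior changes ride along with the loop. The best-checkpoint test now compares torch.prod(epoch_loss) rather than epoch_loss[0] alone, so a new minimum requires the product of the Lagrangian and Eulerian losses to improve. Batch-level logging also gains a batch >= 2 guard because lag_grads and eul_grads are each assigned only on the batches that step the corresponding optimizer, so both exist only from the second batch of the first epoch onward.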