From 352482d4752fb9d69eb429dd7226b235a650f901 Mon Sep 17 00:00:00 2001
From: Yin Li
Date: Tue, 8 Jun 2021 18:28:45 -0400
Subject: [PATCH] Remove Lagrangian part

---
 map2map/train.py | 81 ++++++++++++++++--------------------------------
 1 file changed, 27 insertions(+), 54 deletions(-)

diff --git a/map2map/train.py b/map2map/train.py
index b2aeb80..d9690d8 100644
--- a/map2map/train.py
+++ b/map2map/train.py
@@ -14,7 +14,7 @@ from torch.utils.tensorboard import SummaryWriter
 
 from .data import FieldDataset, DistFieldSampler
 from . import models
-from .models import narrow_cast, resample, lag2eul
+from .models import narrow_cast, resample
 from .utils import import_attr, load_model_state_dict, plt_slices, plt_power
 
 
@@ -267,36 +267,20 @@ def train(epoch, loader, model, criterion,
         if batch <= 5 and rank == 0:
             print('narrowed shape :', output.shape)
 
-        lag_out, lag_tgt = output, target
-        eul_out, eul_tgt = lag2eul([lag_out, lag_tgt], **args.misc_kwargs)
-        if batch <= 5 and rank == 0:
-            print('Eulerian shape :', eul_out.shape, flush=True)
+        loss = criterion(output, target)
+        epoch_loss[0] += loss.detach()
 
-        lag_loss = criterion(lag_out, lag_tgt)
-        eul_loss = criterion(eul_out, eul_tgt)
-        loss = lag_loss * eul_loss
-        epoch_loss[0] += lag_loss.detach()
-        epoch_loss[1] += eul_loss.detach()
-        epoch_loss[2] += loss.detach()
         optimizer.zero_grad()
-        torch.log(loss).backward()  # NOTE actual loss is log(loss)
+        loss.backward()
         optimizer.step()
         grads = get_grads(model)
 
         if batch % args.log_interval == 0:
-            dist.all_reduce(lag_loss)
-            dist.all_reduce(eul_loss)
             dist.all_reduce(loss)
-            lag_loss /= world_size
-            eul_loss /= world_size
             loss /= world_size
             if rank == 0:
-                logger.add_scalar('loss/batch/train/lag', lag_loss.item(),
-                                  global_step=batch)
-                logger.add_scalar('loss/batch/train/eul', eul_loss.item(),
-                                  global_step=batch)
-                logger.add_scalar('loss/batch/train/lxe', loss.item(),
+                logger.add_scalar('loss/batch/train', loss.item(),
                                   global_step=batch)
                 logger.add_scalar('grad/first', grads[0],
                                   global_step=batch)
@@ -305,30 +289,28 @@ def train(epoch, loader, model, criterion,
     dist.all_reduce(epoch_loss)
     epoch_loss /= len(loader) * world_size
     if rank == 0:
-        logger.add_scalar('loss/epoch/train/lag', epoch_loss[0],
-                          global_step=epoch+1)
-        logger.add_scalar('loss/epoch/train/eul', epoch_loss[1],
-                          global_step=epoch+1)
-        logger.add_scalar('loss/epoch/train/lxe', epoch_loss[2],
+        logger.add_scalar('loss/epoch/train', epoch_loss[0],
                           global_step=epoch+1)
 
         fig = plt_slices(
-            input[-1], lag_out[-1], lag_tgt[-1], lag_out[-1] - lag_tgt[-1],
-            eul_out[-1], eul_tgt[-1], eul_out[-1] - eul_tgt[-1],
-            title=['in', 'lag_out', 'lag_tgt', 'lag_out - lag_tgt',
-                   'eul_out', 'eul_tgt', 'eul_out - eul_tgt'],
+            input[-1], output[-1, skip_chan:], target[-1, skip_chan:],
+            output[-1, skip_chan:] - target[-1, skip_chan:],
+            title=['in', 'out', 'tgt', 'out - tgt'],
             **args.misc_kwargs,
         )
         logger.add_figure('fig/train', fig, global_step=epoch+1)
         fig.clf()
 
-        fig = plt_power(input, lag_out, lag_tgt, label=['in', 'out', 'tgt'],
-                        **args.misc_kwargs)
+        fig = plt_power(
+            input, output[:, skip_chan:], target[:, skip_chan:],
+            label=['in', 'out', 'tgt'],
+            **args.misc_kwargs,
+        )
         logger.add_figure('fig/train/power/lag', fig, global_step=epoch+1)
         fig.clf()
 
         #fig = plt_power(1.0,
-        #    dis=[input, lag_out, lag_tgt],
+        #    dis=[input, output[:, skip_chan:], target[:, skip_chan:]],
         #    label=['in', 'out', 'tgt'],
         #    **args.misc_kwargs,
         #)
@@ -361,43 +343,34 @@ def validate(epoch, loader, model, criterion, logger, device, args):
                 input = resample(input, model.module.scale_factor, narrow=False)
             input, output, target = narrow_cast(input, output, target)
 
-            lag_out, lag_tgt = output, target
-            eul_out, eul_tgt = lag2eul([lag_out, lag_tgt], **args.misc_kwargs)
-
-            lag_loss = criterion(lag_out, lag_tgt)
-            eul_loss = criterion(eul_out, eul_tgt)
-            loss = lag_loss * eul_loss
-            epoch_loss[0] += lag_loss.detach()
-            epoch_loss[1] += eul_loss.detach()
-            epoch_loss[2] += loss.detach()
+            loss = criterion(output, target)
+            epoch_loss[0] += loss.detach()
 
     dist.all_reduce(epoch_loss)
     epoch_loss /= len(loader) * world_size
     if rank == 0:
-        logger.add_scalar('loss/epoch/val/lag', epoch_loss[0],
-                          global_step=epoch+1)
-        logger.add_scalar('loss/epoch/val/eul', epoch_loss[1],
-                          global_step=epoch+1)
-        logger.add_scalar('loss/epoch/val/lxe', epoch_loss[2],
+        logger.add_scalar('loss/epoch/val', epoch_loss[0],
                           global_step=epoch+1)
 
         fig = plt_slices(
-            input[-1], lag_out[-1], lag_tgt[-1], lag_out[-1] - lag_tgt[-1],
-            eul_out[-1], eul_tgt[-1], eul_out[-1] - eul_tgt[-1],
-            title=['in', 'lag_out', 'lag_tgt', 'lag_out - lag_tgt',
-                   'eul_out', 'eul_tgt', 'eul_out - eul_tgt'],
+            input[-1], output[-1, skip_chan:], target[-1, skip_chan:],
+            output[-1, skip_chan:] - target[-1, skip_chan:],
+            title=['in', 'out', 'tgt', 'out - tgt'],
            **args.misc_kwargs,
         )
         logger.add_figure('fig/val', fig, global_step=epoch+1)
         fig.clf()
 
-        fig = plt_power(input, lag_out, lag_tgt, label=['in', 'out', 'tgt'],
-                        **args.misc_kwargs)
+        fig = plt_power(
+            input, output[:, skip_chan:], target[:, skip_chan:],
+            label=['in', 'out', 'tgt'],
+            **args.misc_kwargs,
+        )
         logger.add_figure('fig/val/power/lag', fig, global_step=epoch+1)
         fig.clf()
 
         #fig = plt_power(1.0,
-        #    dis=[input, lag_out, lag_tgt],
+        #    dis=[input, output[:, skip_chan:], target[:, skip_chan:]],
         #    label=['in', 'out', 'tgt'],
         #    **args.misc_kwargs,
         #)
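
Note for readers skimming the diff: this patch drops the lag2eul transform and the
multiplicative Lagrangian-times-Eulerian objective, so the training step reduces to a
standard single-criterion loop with a plain loss.backward() in place of the former
torch.log(lag_loss * eul_loss).backward(). The sketch below is a minimal standalone
illustration of the new control flow, not repository code: the toy Conv3d model, the
random fields, and the MSE criterion are stand-ins invented for the example; only the
loss/backward/step structure follows the patched train.py.

```python
# Minimal sketch of the simplified training step after this patch.
# Toy model and data are stand-ins; the control flow mirrors train.py.
import torch
import torch.nn as nn

model = nn.Conv3d(3, 3, 1)            # stand-in for the map2map model
criterion = nn.MSELoss()              # stand-in for args.criterion
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

input = torch.randn(2, 3, 8, 8, 8)    # toy 3D fields, batch of 2
target = torch.randn(2, 3, 8, 8, 8)

output = model(input)

# Before: loss = lag_loss * eul_loss, trained via torch.log(loss).backward().
# After: a single criterion on (output, target), trained via loss.backward().
loss = criterion(output, target)

optimizer.zero_grad()
loss.backward()
optimizer.step()

print(loss.item())  # one scalar per batch, logged as 'loss/batch/train'
```

A side effect of the single loss, visible in the diff, is that only slot 0 of the
epoch_loss buffer is still used, and the TensorBoard tags lose their lag/eul/lxe
suffixes.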