Remove Lagrangian part

This commit is contained in:
Yin Li 2021-06-08 18:28:45 -04:00
parent 0d4ae3424e
commit 352482d475

View File

@ -14,7 +14,7 @@ from torch.utils.tensorboard import SummaryWriter
from .data import FieldDataset, DistFieldSampler from .data import FieldDataset, DistFieldSampler
from . import models from . import models
from .models import narrow_cast, resample, lag2eul from .models import narrow_cast, resample
from .utils import import_attr, load_model_state_dict, plt_slices, plt_power from .utils import import_attr, load_model_state_dict, plt_slices, plt_power
@ -267,36 +267,20 @@ def train(epoch, loader, model, criterion,
if batch <= 5 and rank == 0: if batch <= 5 and rank == 0:
print('narrowed shape :', output.shape) print('narrowed shape :', output.shape)
lag_out, lag_tgt = output, target loss = criterion(output, target)
eul_out, eul_tgt = lag2eul([lag_out, lag_tgt], **args.misc_kwargs) epoch_loss[0] += loss.detach()
if batch <= 5 and rank == 0:
print('Eulerian shape :', eul_out.shape, flush=True)
lag_loss = criterion(lag_out, lag_tgt)
eul_loss = criterion(eul_out, eul_tgt)
loss = lag_loss * eul_loss
epoch_loss[0] += lag_loss.detach()
epoch_loss[1] += eul_loss.detach()
epoch_loss[2] += loss.detach()
optimizer.zero_grad() optimizer.zero_grad()
torch.log(loss).backward() # NOTE actual loss is log(loss) loss.backward()
optimizer.step() optimizer.step()
grads = get_grads(model) grads = get_grads(model)
if batch % args.log_interval == 0: if batch % args.log_interval == 0:
dist.all_reduce(lag_loss)
dist.all_reduce(eul_loss)
dist.all_reduce(loss) dist.all_reduce(loss)
lag_loss /= world_size
eul_loss /= world_size
loss /= world_size loss /= world_size
if rank == 0: if rank == 0:
logger.add_scalar('loss/batch/train/lag', lag_loss.item(), logger.add_scalar('loss/batch/train', loss.item(),
global_step=batch)
logger.add_scalar('loss/batch/train/eul', eul_loss.item(),
global_step=batch)
logger.add_scalar('loss/batch/train/lxe', loss.item(),
global_step=batch) global_step=batch)
logger.add_scalar('grad/first', grads[0], global_step=batch) logger.add_scalar('grad/first', grads[0], global_step=batch)
@ -305,30 +289,28 @@ def train(epoch, loader, model, criterion,
dist.all_reduce(epoch_loss) dist.all_reduce(epoch_loss)
epoch_loss /= len(loader) * world_size epoch_loss /= len(loader) * world_size
if rank == 0: if rank == 0:
logger.add_scalar('loss/epoch/train/lag', epoch_loss[0], logger.add_scalar('loss/epoch/train', epoch_loss[0],
global_step=epoch+1)
logger.add_scalar('loss/epoch/train/eul', epoch_loss[1],
global_step=epoch+1)
logger.add_scalar('loss/epoch/train/lxe', epoch_loss[2],
global_step=epoch+1) global_step=epoch+1)
fig = plt_slices( fig = plt_slices(
input[-1], lag_out[-1], lag_tgt[-1], lag_out[-1] - lag_tgt[-1], input[-1], output[-1, skip_chan:], target[-1, skip_chan:],
eul_out[-1], eul_tgt[-1], eul_out[-1] - eul_tgt[-1], output[-1, skip_chan:] - target[-1, skip_chan:],
title=['in', 'lag_out', 'lag_tgt', 'lag_out - lag_tgt', title=['in', 'out', 'tgt', 'out - tgt'],
'eul_out', 'eul_tgt', 'eul_out - eul_tgt'],
**args.misc_kwargs, **args.misc_kwargs,
) )
logger.add_figure('fig/train', fig, global_step=epoch+1) logger.add_figure('fig/train', fig, global_step=epoch+1)
fig.clf() fig.clf()
fig = plt_power(input, lag_out, lag_tgt, label=['in', 'out', 'tgt'], fig = plt_power(
**args.misc_kwargs) input, output[:, skip_chan:], target[:, skip_chan:],
label=['in', 'out', 'tgt'],
**args.misc_kwargs,
)
logger.add_figure('fig/train/power/lag', fig, global_step=epoch+1) logger.add_figure('fig/train/power/lag', fig, global_step=epoch+1)
fig.clf() fig.clf()
#fig = plt_power(1.0, #fig = plt_power(1.0,
# dis=[input, lag_out, lag_tgt], # dis=[input, output[:, skip_chan:], target[:, skip_chan:]],
# label=['in', 'out', 'tgt'], # label=['in', 'out', 'tgt'],
# **args.misc_kwargs, # **args.misc_kwargs,
#) #)
@ -361,43 +343,34 @@ def validate(epoch, loader, model, criterion, logger, device, args):
input = resample(input, model.module.scale_factor, narrow=False) input = resample(input, model.module.scale_factor, narrow=False)
input, output, target = narrow_cast(input, output, target) input, output, target = narrow_cast(input, output, target)
lag_out, lag_tgt = output, target loss = criterion(output, target)
eul_out, eul_tgt = lag2eul([lag_out, lag_tgt], **args.misc_kwargs) epoch_loss[0] += loss.detach()
lag_loss = criterion(lag_out, lag_tgt)
eul_loss = criterion(eul_out, eul_tgt)
loss = lag_loss * eul_loss
epoch_loss[0] += lag_loss.detach()
epoch_loss[1] += eul_loss.detach()
epoch_loss[2] += loss.detach()
dist.all_reduce(epoch_loss) dist.all_reduce(epoch_loss)
epoch_loss /= len(loader) * world_size epoch_loss /= len(loader) * world_size
if rank == 0: if rank == 0:
logger.add_scalar('loss/epoch/val/lag', epoch_loss[0], logger.add_scalar('loss/epoch/val', epoch_loss[0],
global_step=epoch+1)
logger.add_scalar('loss/epoch/val/eul', epoch_loss[1],
global_step=epoch+1)
logger.add_scalar('loss/epoch/val/lxe', epoch_loss[2],
global_step=epoch+1) global_step=epoch+1)
fig = plt_slices( fig = plt_slices(
input[-1], lag_out[-1], lag_tgt[-1], lag_out[-1] - lag_tgt[-1], input[-1], output[-1, skip_chan:], target[-1, skip_chan:],
eul_out[-1], eul_tgt[-1], eul_out[-1] - eul_tgt[-1], output[-1, skip_chan:] - target[-1, skip_chan:],
title=['in', 'lag_out', 'lag_tgt', 'lag_out - lag_tgt', title=['in', 'out', 'tgt', 'out - tgt'],
'eul_out', 'eul_tgt', 'eul_out - eul_tgt'],
**args.misc_kwargs, **args.misc_kwargs,
) )
logger.add_figure('fig/val', fig, global_step=epoch+1) logger.add_figure('fig/val', fig, global_step=epoch+1)
fig.clf() fig.clf()
fig = plt_power(input, lag_out, lag_tgt, label=['in', 'out', 'tgt'], fig = plt_power(
**args.misc_kwargs) input, output[:, skip_chan:], target[:, skip_chan:],
label=['in', 'out', 'tgt'],
**args.misc_kwargs,
)
logger.add_figure('fig/val/power/lag', fig, global_step=epoch+1) logger.add_figure('fig/val/power/lag', fig, global_step=epoch+1)
fig.clf() fig.clf()
#fig = plt_power(1.0, #fig = plt_power(1.0,
# dis=[input, lag_out, lag_tgt], # dis=[input, output[:, skip_chan:], target[:, skip_chan:]],
# label=['in', 'out', 'tgt'], # label=['in', 'out', 'tgt'],
# **args.misc_kwargs, # **args.misc_kwargs,
#) #)