Remove Lagrangian part
parent 0d4ae3424e
commit 352482d475
@@ -14,7 +14,7 @@ from torch.utils.tensorboard import SummaryWriter
 from .data import FieldDataset, DistFieldSampler
 from . import models
-from .models import narrow_cast, resample, lag2eul
+from .models import narrow_cast, resample
 from .utils import import_attr, load_model_state_dict, plt_slices, plt_power
 
 
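Note: lag2eul, dropped from the imports above, is evidently this codebase's Lagrangian-to-Eulerian conversion; the hunks below show it turning displacement outputs and targets into density fields for an extra Eulerian loss term. A minimal sketch of the idea, using nearest-grid-point mass assignment (the repository's actual implementation likely uses a differentiable scheme such as cloud-in-cell, so treat this as illustrative only):

    import torch

    def lag2eul_sketch(dis):
        # Illustrative only; not the repository's lag2eul.
        # dis: (N, 3, D, H, W) displacements on a regular Lagrangian lattice.
        # Returns a (N, 1, D, H, W) Eulerian density field with unit mean.
        N, _, D, H, W = dis.shape
        lattice = torch.stack(torch.meshgrid(
            torch.arange(D), torch.arange(H), torch.arange(W),
            indexing='ij'), dim=0).to(dis)                     # (3, D, H, W)
        pos = (lattice + dis).round().long()                   # displaced particle positions
        i, j, k = pos[:, 0] % D, pos[:, 1] % H, pos[:, 2] % W  # periodic wrap
        lin = (i * H + j) * W + k                              # linear cell indices
        dens = torch.zeros(N, D * H * W, dtype=dis.dtype, device=dis.device)
        dens.scatter_add_(1, lin.flatten(1), torch.ones_like(dens))
        return dens.view(N, 1, D, H, W)

The rounding makes this version non-differentiable, which is presumably why the real lag2eul interpolates instead; it is shown only to clarify what the removed Eulerian branch was comparing.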
@@ -267,36 +267,20 @@ def train(epoch, loader, model, criterion,
         if batch <= 5 and rank == 0:
             print('narrowed shape :', output.shape)
 
-        lag_out, lag_tgt = output, target
-        eul_out, eul_tgt = lag2eul([lag_out, lag_tgt], **args.misc_kwargs)
-        if batch <= 5 and rank == 0:
-            print('Eulerian shape :', eul_out.shape, flush=True)
-
-        lag_loss = criterion(lag_out, lag_tgt)
-        eul_loss = criterion(eul_out, eul_tgt)
-        loss = lag_loss * eul_loss
-        epoch_loss[0] += lag_loss.detach()
-        epoch_loss[1] += eul_loss.detach()
-        epoch_loss[2] += loss.detach()
+        loss = criterion(output, target)
+        epoch_loss[0] += loss.detach()
 
         optimizer.zero_grad()
-        torch.log(loss).backward()  # NOTE actual loss is log(loss)
+        loss.backward()
         optimizer.step()
         grads = get_grads(model)
 
         if batch % args.log_interval == 0:
-            dist.all_reduce(lag_loss)
-            dist.all_reduce(eul_loss)
             dist.all_reduce(loss)
-            lag_loss /= world_size
-            eul_loss /= world_size
             loss /= world_size
             if rank == 0:
-                logger.add_scalar('loss/batch/train/lag', lag_loss.item(),
-                                  global_step=batch)
-                logger.add_scalar('loss/batch/train/eul', eul_loss.item(),
-                                  global_step=batch)
-                logger.add_scalar('loss/batch/train/lxe', loss.item(),
+                logger.add_scalar('loss/batch/train', loss.item(),
                                   global_step=batch)
 
                 logger.add_scalar('grad/first', grads[0], global_step=batch)
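The removed branch multiplied the two errors and backpropagated torch.log(loss), as the deleted NOTE comment says. Taking the log turns the product into a sum of logs, so each term's gradient is divided by its own value; neither the Lagrangian nor the Eulerian loss can dominate just by having a larger scale. A tiny self-contained check of that equivalence (values are illustrative):

    import torch

    lag_loss = torch.tensor(2.0, requires_grad=True)
    eul_loss = torch.tensor(500.0, requires_grad=True)

    # Backpropagating log(lag * eul) = log(lag) + log(eul) ...
    torch.log(lag_loss * eul_loss).backward()

    # ... gives each term a relative, scale-free gradient: 1/term.
    print(lag_loss.grad)  # tensor(0.5000)  == 1 / 2.0
    print(eul_loss.grad)  # tensor(0.0020)  == 1 / 500.0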
@@ -305,30 +289,28 @@ def train(epoch, loader, model, criterion,
     dist.all_reduce(epoch_loss)
     epoch_loss /= len(loader) * world_size
     if rank == 0:
-        logger.add_scalar('loss/epoch/train/lag', epoch_loss[0],
-                          global_step=epoch+1)
-        logger.add_scalar('loss/epoch/train/eul', epoch_loss[1],
-                          global_step=epoch+1)
-        logger.add_scalar('loss/epoch/train/lxe', epoch_loss[2],
+        logger.add_scalar('loss/epoch/train', epoch_loss[0],
                           global_step=epoch+1)
 
         fig = plt_slices(
-            input[-1], lag_out[-1], lag_tgt[-1], lag_out[-1] - lag_tgt[-1],
-            eul_out[-1], eul_tgt[-1], eul_out[-1] - eul_tgt[-1],
-            title=['in', 'lag_out', 'lag_tgt', 'lag_out - lag_tgt',
-                   'eul_out', 'eul_tgt', 'eul_out - eul_tgt'],
+            input[-1], output[-1, skip_chan:], target[-1, skip_chan:],
+            output[-1, skip_chan:] - target[-1, skip_chan:],
+            title=['in', 'out', 'tgt', 'out - tgt'],
            **args.misc_kwargs,
         )
         logger.add_figure('fig/train', fig, global_step=epoch+1)
         fig.clf()
 
-        fig = plt_power(input, lag_out, lag_tgt, label=['in', 'out', 'tgt'],
-                        **args.misc_kwargs)
+        fig = plt_power(
+            input, output[:, skip_chan:], target[:, skip_chan:],
+            label=['in', 'out', 'tgt'],
+            **args.misc_kwargs,
+        )
         logger.add_figure('fig/train/power/lag', fig, global_step=epoch+1)
         fig.clf()
 
         #fig = plt_power(1.0,
-        #    dis=[input, lag_out, lag_tgt],
+        #    dis=[input, output[:, skip_chan:], target[:, skip_chan:]],
         #    label=['in', 'out', 'tgt'],
         #    **args.misc_kwargs,
         #)
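In the updated plotting calls, output[:, skip_chan:] and target[:, skip_chan:] drop the first skip_chan channels before plotting, presumably channels carried through from the input that should not be compared against the target. A toy illustration of the slicing (shapes and names here are illustrative, not from the repository):

    import torch

    skip_chan = 1
    output = torch.randn(4, 3, 32, 32, 32)   # (batch, channel, D, H, W)
    kept = output[:, skip_chan:]              # drop the first skip_chan channels
    print(kept.shape)                         # torch.Size([4, 2, 32, 32, 32])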
@@ -361,43 +343,34 @@ def validate(epoch, loader, model, criterion, logger, device, args):
             input = resample(input, model.module.scale_factor, narrow=False)
             input, output, target = narrow_cast(input, output, target)
 
-            lag_out, lag_tgt = output, target
-            eul_out, eul_tgt = lag2eul([lag_out, lag_tgt], **args.misc_kwargs)
-
-            lag_loss = criterion(lag_out, lag_tgt)
-            eul_loss = criterion(eul_out, eul_tgt)
-            loss = lag_loss * eul_loss
-            epoch_loss[0] += lag_loss.detach()
-            epoch_loss[1] += eul_loss.detach()
-            epoch_loss[2] += loss.detach()
+            loss = criterion(output, target)
+            epoch_loss[0] += loss.detach()
 
     dist.all_reduce(epoch_loss)
     epoch_loss /= len(loader) * world_size
     if rank == 0:
-        logger.add_scalar('loss/epoch/val/lag', epoch_loss[0],
-                          global_step=epoch+1)
-        logger.add_scalar('loss/epoch/val/eul', epoch_loss[1],
-                          global_step=epoch+1)
-        logger.add_scalar('loss/epoch/val/lxe', epoch_loss[2],
+        logger.add_scalar('loss/epoch/val', epoch_loss[0],
                           global_step=epoch+1)
 
         fig = plt_slices(
-            input[-1], lag_out[-1], lag_tgt[-1], lag_out[-1] - lag_tgt[-1],
-            eul_out[-1], eul_tgt[-1], eul_out[-1] - eul_tgt[-1],
-            title=['in', 'lag_out', 'lag_tgt', 'lag_out - lag_tgt',
-                   'eul_out', 'eul_tgt', 'eul_out - eul_tgt'],
+            input[-1], output[-1, skip_chan:], target[-1, skip_chan:],
+            output[-1, skip_chan:] - target[-1, skip_chan:],
+            title=['in', 'out', 'tgt', 'out - tgt'],
            **args.misc_kwargs,
         )
         logger.add_figure('fig/val', fig, global_step=epoch+1)
         fig.clf()
 
-        fig = plt_power(input, lag_out, lag_tgt, label=['in', 'out', 'tgt'],
-                        **args.misc_kwargs)
+        fig = plt_power(
+            input, output[:, skip_chan:], target[:, skip_chan:],
+            label=['in', 'out', 'tgt'],
+            **args.misc_kwargs,
+        )
         logger.add_figure('fig/val/power/lag', fig, global_step=epoch+1)
         fig.clf()
 
         #fig = plt_power(1.0,
-        #    dis=[input, lag_out, lag_tgt],
+        #    dis=[input, output[:, skip_chan:], target[:, skip_chan:]],
         #    label=['in', 'out', 'tgt'],
         #    **args.misc_kwargs,
         #)
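After this commit, the per-batch update in train() reduces to a standard single-criterion step; validate() mirrors the same simplification without the optimizer. A minimal self-contained rendering of the new loop body, with a toy model and data standing in for the repository's (everything except the loss/optimizer pattern is illustrative):

    import torch
    from torch import nn, optim

    model = nn.Conv3d(1, 1, 3, padding=1)      # stand-in for the real model
    criterion = nn.MSELoss()                   # stand-in for the configured criterion
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    input = torch.randn(2, 1, 8, 8, 8)
    target = torch.randn(2, 1, 8, 8, 8)
    epoch_loss = torch.zeros(1)

    output = model(input)
    loss = criterion(output, target)
    epoch_loss[0] += loss.detach()             # only slot 0 is used now

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()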