Remove Lagrangian part
parent 0d4ae3424e
commit 352482d475
@@ -14,7 +14,7 @@ from torch.utils.tensorboard import SummaryWriter

 from .data import FieldDataset, DistFieldSampler
 from . import models
-from .models import narrow_cast, resample, lag2eul
+from .models import narrow_cast, resample
 from .utils import import_attr, load_model_state_dict, plt_slices, plt_power

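
Note: dropping lag2eul from the .models import removes the Lagrangian-to-Eulerian conversion used below. For orientation only, here is a rough sketch of what such a conversion does, assuming a periodic mesh and displacements in units of mesh cells; this is an illustration of the idea, not the map2map implementation:

    # Sketch only: particles start on a regular lattice, move by the
    # displacement field, and are mass-assigned to a mesh with
    # cloud-in-cell (CIC, i.e. trilinear) weights.
    import torch

    def lag2eul_sketch(dis, val=1.0):
        """dis: displacement field of shape (3, N, N, N), in cell units."""
        N = dis.shape[1]
        mesh = torch.zeros(N, N, N, dtype=dis.dtype)

        # Lagrangian lattice positions + displacement = Eulerian positions
        lattice = torch.stack(torch.meshgrid(
            *(torch.arange(N, dtype=dis.dtype),) * 3, indexing='ij'))
        pos = (lattice + dis).reshape(3, -1)

        floor = pos.floor()
        frac = pos - floor
        floor = floor.long()

        # scatter each particle onto its 8 neighboring cells with CIC weights
        for dx in (0, 1):
            for dy in (0, 1):
                for dz in (0, 1):
                    w = ((frac[0] if dx else 1 - frac[0])
                         * (frac[1] if dy else 1 - frac[1])
                         * (frac[2] if dz else 1 - frac[2]))
                    i = (floor[0] + dx) % N  # periodic wrap
                    j = (floor[1] + dy) % N
                    k = (floor[2] + dz) % N
                    mesh.index_put_((i, j, k), val * w, accumulate=True)
        return mesh

The real lag2eul also handles batches and multiple fields, plus the args.misc_kwargs plumbing visible in this diff.
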
@@ -267,36 +267,20 @@ def train(epoch, loader, model, criterion,
         if batch <= 5 and rank == 0:
             print('narrowed shape :', output.shape)

-        lag_out, lag_tgt = output, target
-        eul_out, eul_tgt = lag2eul([lag_out, lag_tgt], **args.misc_kwargs)
-        if batch <= 5 and rank == 0:
-            print('Eulerian shape :', eul_out.shape, flush=True)
-
-        lag_loss = criterion(lag_out, lag_tgt)
-        eul_loss = criterion(eul_out, eul_tgt)
-        loss = lag_loss * eul_loss
-        epoch_loss[0] += lag_loss.detach()
-        epoch_loss[1] += eul_loss.detach()
-        epoch_loss[2] += loss.detach()
+        loss = criterion(output, target)
+        epoch_loss[0] += loss.detach()

         optimizer.zero_grad()
-        torch.log(loss).backward()  # NOTE actual loss is log(loss)
+        loss.backward()
         optimizer.step()
         grads = get_grads(model)

         if batch % args.log_interval == 0:
-            dist.all_reduce(lag_loss)
-            dist.all_reduce(eul_loss)
             dist.all_reduce(loss)
-            lag_loss /= world_size
-            eul_loss /= world_size
             loss /= world_size
             if rank == 0:
-                logger.add_scalar('loss/batch/train/lag', lag_loss.item(),
-                                  global_step=batch)
-                logger.add_scalar('loss/batch/train/eul', eul_loss.item(),
-                                  global_step=batch)
-                logger.add_scalar('loss/batch/train/lxe', loss.item(),
+                logger.add_scalar('loss/batch/train', loss.item(),
                                   global_step=batch)

                 logger.add_scalar('grad/first', grads[0], global_step=batch)
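
Note on the removed objective: backpropagating through torch.log(loss) with loss = lag_loss * eul_loss optimizes log(lag_loss) + log(eul_loss), so each term's gradient is divided by its own value and neither loss can dominate just by having a larger scale. A toy check of that property (not from this repo):

    import torch

    lag_loss = torch.tensor(2.0, requires_grad=True)
    eul_loss = torch.tensor(200.0, requires_grad=True)

    loss = lag_loss * eul_loss
    torch.log(loss).backward()

    # d log(lag * eul) / d lag = 1 / lag, independent of eul's scale
    print(lag_loss.grad)  # tensor(0.5000)
    print(eul_loss.grad)  # tensor(0.0050)

After this commit a single criterion is trained directly, so the log trick (and the lag/eul/lxe bookkeeping) is no longer needed.
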
@@ -305,30 +289,28 @@ def train(epoch, loader, model, criterion,
     dist.all_reduce(epoch_loss)
     epoch_loss /= len(loader) * world_size
     if rank == 0:
-        logger.add_scalar('loss/epoch/train/lag', epoch_loss[0],
-                          global_step=epoch+1)
-        logger.add_scalar('loss/epoch/train/eul', epoch_loss[1],
-                          global_step=epoch+1)
-        logger.add_scalar('loss/epoch/train/lxe', epoch_loss[2],
+        logger.add_scalar('loss/epoch/train', epoch_loss[0],
                           global_step=epoch+1)

         fig = plt_slices(
-            input[-1], lag_out[-1], lag_tgt[-1], lag_out[-1] - lag_tgt[-1],
-            eul_out[-1], eul_tgt[-1], eul_out[-1] - eul_tgt[-1],
-            title=['in', 'lag_out', 'lag_tgt', 'lag_out - lag_tgt',
-                   'eul_out', 'eul_tgt', 'eul_out - eul_tgt'],
+            input[-1], output[-1, skip_chan:], target[-1, skip_chan:],
+            output[-1, skip_chan:] - target[-1, skip_chan:],
+            title=['in', 'out', 'tgt', 'out - tgt'],
             **args.misc_kwargs,
         )
         logger.add_figure('fig/train', fig, global_step=epoch+1)
         fig.clf()

-        fig = plt_power(input, lag_out, lag_tgt, label=['in', 'out', 'tgt'],
-                        **args.misc_kwargs)
+        fig = plt_power(
+            input, output[:, skip_chan:], target[:, skip_chan:],
+            label=['in', 'out', 'tgt'],
+            **args.misc_kwargs,
+        )
         logger.add_figure('fig/train/power/lag', fig, global_step=epoch+1)
         fig.clf()

         #fig = plt_power(1.0,
-        #    dis=[input, lag_out, lag_tgt],
+        #    dis=[input, output[:, skip_chan:], target[:, skip_chan:]],
         #    label=['in', 'out', 'tgt'],
         #    **args.misc_kwargs,
         #)
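
Note on the reduction pattern kept on both sides of this change: dist.all_reduce sums a tensor in place across ranks (the default op is SUM), so the subsequent division by world_size turns the per-rank sums into a global mean. A minimal helper capturing the pattern (hypothetical name, assumes torch.distributed is already initialized):

    import torch
    import torch.distributed as dist

    def global_mean(loss: torch.Tensor) -> torch.Tensor:
        loss = loss.detach().clone()   # don't clobber the autograd graph
        dist.all_reduce(loss)          # in-place SUM over all ranks
        loss /= dist.get_world_size()
        return loss
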
@@ -361,43 +343,34 @@ def validate(epoch, loader, model, criterion, logger, device, args):
             input = resample(input, model.module.scale_factor, narrow=False)
             input, output, target = narrow_cast(input, output, target)

-            lag_out, lag_tgt = output, target
-            eul_out, eul_tgt = lag2eul([lag_out, lag_tgt], **args.misc_kwargs)
-
-            lag_loss = criterion(lag_out, lag_tgt)
-            eul_loss = criterion(eul_out, eul_tgt)
-            loss = lag_loss * eul_loss
-            epoch_loss[0] += lag_loss.detach()
-            epoch_loss[1] += eul_loss.detach()
-            epoch_loss[2] += loss.detach()
+            loss = criterion(output, target)
+            epoch_loss[0] += loss.detach()

     dist.all_reduce(epoch_loss)
     epoch_loss /= len(loader) * world_size
     if rank == 0:
-        logger.add_scalar('loss/epoch/val/lag', epoch_loss[0],
-                          global_step=epoch+1)
-        logger.add_scalar('loss/epoch/val/eul', epoch_loss[1],
-                          global_step=epoch+1)
-        logger.add_scalar('loss/epoch/val/lxe', epoch_loss[2],
+        logger.add_scalar('loss/epoch/val', epoch_loss[0],
                           global_step=epoch+1)

         fig = plt_slices(
-            input[-1], lag_out[-1], lag_tgt[-1], lag_out[-1] - lag_tgt[-1],
-            eul_out[-1], eul_tgt[-1], eul_out[-1] - eul_tgt[-1],
-            title=['in', 'lag_out', 'lag_tgt', 'lag_out - lag_tgt',
-                   'eul_out', 'eul_tgt', 'eul_out - eul_tgt'],
+            input[-1], output[-1, skip_chan:], target[-1, skip_chan:],
+            output[-1, skip_chan:] - target[-1, skip_chan:],
+            title=['in', 'out', 'tgt', 'out - tgt'],
             **args.misc_kwargs,
         )
         logger.add_figure('fig/val', fig, global_step=epoch+1)
         fig.clf()

-        fig = plt_power(input, lag_out, lag_tgt, label=['in', 'out', 'tgt'],
-                        **args.misc_kwargs)
+        fig = plt_power(
+            input, output[:, skip_chan:], target[:, skip_chan:],
+            label=['in', 'out', 'tgt'],
+            **args.misc_kwargs,
+        )
         logger.add_figure('fig/val/power/lag', fig, global_step=epoch+1)
         fig.clf()

         #fig = plt_power(1.0,
-        #    dis=[input, lag_out, lag_tgt],
+        #    dis=[input, output[:, skip_chan:], target[:, skip_chan:]],
         #    label=['in', 'out', 'tgt'],
         #    **args.misc_kwargs,
         #)
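
For context, validate() mirrors the training step without the optimizer. The skeleton below is reconstructed around the context lines of the last hunk; the lines not shown in the diff (model.eval, the loop header, the epoch_loss allocation) are assumptions, not verbatim source. With only one loss term left, epoch_loss needs a single slot where three (lag, eul, lxe) were used before; skip_chan is likewise defined outside these hunks and presumably drops leading channels from the plots.

    # Reconstructed skeleton, not verbatim from the repo
    model.eval()
    epoch_loss = torch.zeros(1, dtype=torch.float64, device=device)
    with torch.no_grad():
        for input, target in loader:
            input, target = input.to(device), target.to(device)
            output = model(input)
            input = resample(input, model.module.scale_factor, narrow=False)
            input, output, target = narrow_cast(input, output, target)
            loss = criterion(output, target)
            epoch_loss[0] += loss.detach()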