Add lag x eul loss tracking to tensorboard
This commit is contained in:
parent
56f8fa932b
commit
ec46f41ba5
@ -297,9 +297,12 @@ def train(epoch, loader, model, lag2eul, criterion,
|
|||||||
eul_loss /= world_size
|
eul_loss /= world_size
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
logger.add_scalar('loss/batch/train/lag', lag_loss.item(),
|
logger.add_scalar('loss/batch/train/lag', lag_loss.item(),
|
||||||
global_step=batch)
|
global_step=batch)
|
||||||
logger.add_scalar('loss/batch/train/eul', eul_loss.item(),
|
logger.add_scalar('loss/batch/train/eul', eul_loss.item(),
|
||||||
global_step=batch)
|
global_step=batch)
|
||||||
|
logger.add_scalar('loss/batch/train/lxe',
|
||||||
|
lag_loss.item() * eul_loss.item(),
|
||||||
|
global_step=batch)
|
||||||
|
|
||||||
logger.add_scalar('grad/lag/first', lag_grads[0],
|
logger.add_scalar('grad/lag/first', lag_grads[0],
|
||||||
global_step=batch)
|
global_step=batch)
|
||||||
@ -314,9 +317,11 @@ def train(epoch, loader, model, lag2eul, criterion,
|
|||||||
epoch_loss /= len(loader) * world_size
|
epoch_loss /= len(loader) * world_size
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
logger.add_scalar('loss/epoch/train/lag', epoch_loss[0],
|
logger.add_scalar('loss/epoch/train/lag', epoch_loss[0],
|
||||||
global_step=epoch+1)
|
global_step=epoch+1)
|
||||||
logger.add_scalar('loss/epoch/train/eul', epoch_loss[1],
|
logger.add_scalar('loss/epoch/train/eul', epoch_loss[1],
|
||||||
global_step=epoch+1)
|
global_step=epoch+1)
|
||||||
|
logger.add_scalar('loss/epoch/train/lxe', epoch_loss.prod(),
|
||||||
|
global_step=epoch+1)
|
||||||
|
|
||||||
logger.add_figure('fig/epoch/train', plt_slices(
|
logger.add_figure('fig/epoch/train', plt_slices(
|
||||||
input[-1], lag_out[-1], lag_tgt[-1], lag_out[-1] - lag_tgt[-1],
|
input[-1], lag_out[-1], lag_tgt[-1], lag_out[-1] - lag_tgt[-1],
|
||||||
@ -362,9 +367,11 @@ def validate(epoch, loader, model, lag2eul, criterion, logger, device, args):
|
|||||||
epoch_loss /= len(loader) * world_size
|
epoch_loss /= len(loader) * world_size
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
logger.add_scalar('loss/epoch/val/lag', epoch_loss[0],
|
logger.add_scalar('loss/epoch/val/lag', epoch_loss[0],
|
||||||
global_step=epoch+1)
|
global_step=epoch+1)
|
||||||
logger.add_scalar('loss/epoch/val/eul', epoch_loss[1],
|
logger.add_scalar('loss/epoch/val/eul', epoch_loss[1],
|
||||||
global_step=epoch+1)
|
global_step=epoch+1)
|
||||||
|
logger.add_scalar('loss/epoch/val/lxe', epoch_loss.prod(),
|
||||||
|
global_step=epoch+1)
|
||||||
|
|
||||||
logger.add_figure('fig/epoch/val', plt_slices(
|
logger.add_figure('fig/epoch/val', plt_slices(
|
||||||
input[-1], lag_out[-1], lag_tgt[-1], lag_out[-1] - lag_tgt[-1],
|
input[-1], lag_out[-1], lag_tgt[-1], lag_out[-1] - lag_tgt[-1],
|
||||||
|
Loading…
Reference in New Issue
Block a user