Add lag x eul loss tracking to tensorboard

This commit is contained in:
Yin Li 2020-07-21 21:10:34 -07:00
parent 56f8fa932b
commit ec46f41ba5

View File

@ -300,6 +300,9 @@ def train(epoch, loader, model, lag2eul, criterion,
global_step=batch)
logger.add_scalar('loss/batch/train/eul', eul_loss.item(),
global_step=batch)
logger.add_scalar('loss/batch/train/lxe',
lag_loss.item() * eul_loss.item(),
global_step=batch)
logger.add_scalar('grad/lag/first', lag_grads[0],
global_step=batch)
@ -317,6 +320,8 @@ def train(epoch, loader, model, lag2eul, criterion,
global_step=epoch+1)
logger.add_scalar('loss/epoch/train/eul', epoch_loss[1],
global_step=epoch+1)
logger.add_scalar('loss/epoch/train/lxe', epoch_loss.prod(),
global_step=epoch+1)
logger.add_figure('fig/epoch/train', plt_slices(
input[-1], lag_out[-1], lag_tgt[-1], lag_out[-1] - lag_tgt[-1],
@ -365,6 +370,8 @@ def validate(epoch, loader, model, lag2eul, criterion, logger, device, args):
global_step=epoch+1)
logger.add_scalar('loss/epoch/val/eul', epoch_loss[1],
global_step=epoch+1)
logger.add_scalar('loss/epoch/val/lxe', epoch_loss.prod(),
global_step=epoch+1)
logger.add_figure('fig/epoch/val', plt_slices(
input[-1], lag_out[-1], lag_tgt[-1], lag_out[-1] - lag_tgt[-1],