diff --git a/map2map/train.py b/map2map/train.py index 83ffddc..b2aeb80 100644 --- a/map2map/train.py +++ b/map2map/train.py @@ -5,7 +5,6 @@ import sys from pprint import pprint import torch import torch.nn as nn -import torch.nn.functional as F import torch.optim as optim import torch.distributed as dist from torch.multiprocessing import spawn @@ -323,10 +322,10 @@ def train(epoch, loader, model, criterion, logger.add_figure('fig/train', fig, global_step=epoch+1) fig.clf() - #fig = plt_power(input, lag_out, lag_tgt, label=['in', 'out', 'tgt'], - # **args.misc_kwargs) - #logger.add_figure('fig/train/power/lag', fig, global_step=epoch+1) - #fig.clf() + fig = plt_power(input, lag_out, lag_tgt, label=['in', 'out', 'tgt'], + **args.misc_kwargs) + logger.add_figure('fig/train/power/lag', fig, global_step=epoch+1) + fig.clf() #fig = plt_power(1.0, # dis=[input, lag_out, lag_tgt], @@ -392,10 +391,10 @@ def validate(epoch, loader, model, criterion, logger, device, args): logger.add_figure('fig/val', fig, global_step=epoch+1) fig.clf() - #fig = plt_power(input, lag_out, lag_tgt, label=['in', 'out', 'tgt'], - # **args.misc_kwargs) - #logger.add_figure('fig/val/power/lag', fig, global_step=epoch+1) - #fig.clf() + fig = plt_power(input, lag_out, lag_tgt, label=['in', 'out', 'tgt'], + **args.misc_kwargs) + logger.add_figure('fig/val/power/lag', fig, global_step=epoch+1) + fig.clf() #fig = plt_power(1.0, # dis=[input, lag_out, lag_tgt], diff --git a/scripts/example-train.slurm b/scripts/example-train.slurm index d6cc894..df3439f 100644 --- a/scripts/example-train.slurm +++ b/scripts/example-train.slurm @@ -2,12 +2,10 @@ #SBATCH --job-name=R2D2 #SBATCH --output=%x-%j.out - #SBATCH --partition=gpu_partition -#SBATCH --gres=gpu:4 - -#SBATCH --exclusive #SBATCH --nodes=2 +#SBATCH --gres=gpu:4 +#SBATCH --exclusive #SBATCH --time=1-00:00:00