From 607bcf3f4c2db4681e7f0eb1b79e51ffcd4c183e Mon Sep 17 00:00:00 2001 From: Yin Li Date: Fri, 10 Jul 2020 12:50:07 -0400 Subject: [PATCH] Add Lag2Eul to training --- map2map/train.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/map2map/train.py b/map2map/train.py index 8e1e010..ea59b47 100644 --- a/map2map/train.py +++ b/map2map/train.py @@ -17,7 +17,7 @@ from torch.utils.tensorboard import SummaryWriter from .data import FieldDataset from .data.figures import plt_slices from . import models -from .models import (narrow_cast, resample, +from .models import (narrow_cast, resample, Lag2Eul adv_model_wrapper, adv_criterion_wrapper, add_spectral_norm, rm_spectral_norm, InstanceNoise) @@ -121,6 +121,8 @@ def gpu_worker(local_rank, node, args): model = DistributedDataParallel(model, device_ids=[device], process_group=dist.new_group()) + dis2den = Lag2Eul() + criterion = import_attr(args.criterion, nn.__name__, args.callback_at) criterion = criterion() criterion.to(device) @@ -229,14 +231,14 @@ def gpu_worker(local_rank, node, args): train_sampler.set_epoch(epoch) train_loss = train(epoch, train_loader, - model, criterion, optimizer, scheduler, + model, dis2den, criterion, optimizer, scheduler, adv_model, adv_criterion, adv_optimizer, adv_scheduler, logger, device, args) epoch_loss = train_loss if args.val: val_loss = validate(epoch, val_loader, - model, criterion, adv_model, adv_criterion, + model, dis2den, criterion, adv_model, adv_criterion, logger, device, args) epoch_loss = val_loss @@ -272,7 +274,7 @@ def gpu_worker(local_rank, node, args): dist.destroy_process_group() -def train(epoch, loader, model, criterion, optimizer, scheduler, +def train(epoch, loader, model, dis2den, criterion, optimizer, scheduler, adv_model, adv_criterion, adv_optimizer, adv_scheduler, logger, device, args): model.train() @@ -307,6 +309,8 @@ def train(epoch, loader, model, criterion, optimizer, scheduler, input = resample(input, model.module.scale_factor, narrow=False) input, output, target = narrow_cast(input, output, target) + output, target = dis2den(output, target) + loss = criterion(output, target) epoch_loss[0] += loss.item() @@ -418,7 +422,7 @@ def train(epoch, loader, model, criterion, optimizer, scheduler, return epoch_loss -def validate(epoch, loader, model, criterion, adv_model, adv_criterion, +def validate(epoch, loader, model, dis2den, criterion, adv_model, adv_criterion, logger, device, args): model.eval() if args.adv: @@ -443,6 +447,8 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion, input = resample(input, model.module.scale_factor, narrow=False) input, output, target = narrow_cast(input, output, target) + output, target = dis2den(output, target) + loss = criterion(output, target) epoch_loss[0] += loss.item()