Add Lag2Eul to training
This commit is contained in:
parent
40220e9248
commit
607bcf3f4c
@ -17,7 +17,7 @@ from torch.utils.tensorboard import SummaryWriter
|
|||||||
from .data import FieldDataset
|
from .data import FieldDataset
|
||||||
from .data.figures import plt_slices
|
from .data.figures import plt_slices
|
||||||
from . import models
|
from . import models
|
||||||
from .models import (narrow_cast, resample,
|
from .models import (narrow_cast, resample, Lag2Eul
|
||||||
adv_model_wrapper, adv_criterion_wrapper,
|
adv_model_wrapper, adv_criterion_wrapper,
|
||||||
add_spectral_norm, rm_spectral_norm,
|
add_spectral_norm, rm_spectral_norm,
|
||||||
InstanceNoise)
|
InstanceNoise)
|
||||||
@ -121,6 +121,8 @@ def gpu_worker(local_rank, node, args):
|
|||||||
model = DistributedDataParallel(model, device_ids=[device],
|
model = DistributedDataParallel(model, device_ids=[device],
|
||||||
process_group=dist.new_group())
|
process_group=dist.new_group())
|
||||||
|
|
||||||
|
dis2den = Lag2Eul()
|
||||||
|
|
||||||
criterion = import_attr(args.criterion, nn.__name__, args.callback_at)
|
criterion = import_attr(args.criterion, nn.__name__, args.callback_at)
|
||||||
criterion = criterion()
|
criterion = criterion()
|
||||||
criterion.to(device)
|
criterion.to(device)
|
||||||
@ -229,14 +231,14 @@ def gpu_worker(local_rank, node, args):
|
|||||||
train_sampler.set_epoch(epoch)
|
train_sampler.set_epoch(epoch)
|
||||||
|
|
||||||
train_loss = train(epoch, train_loader,
|
train_loss = train(epoch, train_loader,
|
||||||
model, criterion, optimizer, scheduler,
|
model, dis2den, criterion, optimizer, scheduler,
|
||||||
adv_model, adv_criterion, adv_optimizer, adv_scheduler,
|
adv_model, adv_criterion, adv_optimizer, adv_scheduler,
|
||||||
logger, device, args)
|
logger, device, args)
|
||||||
epoch_loss = train_loss
|
epoch_loss = train_loss
|
||||||
|
|
||||||
if args.val:
|
if args.val:
|
||||||
val_loss = validate(epoch, val_loader,
|
val_loss = validate(epoch, val_loader,
|
||||||
model, criterion, adv_model, adv_criterion,
|
model, dis2den, criterion, adv_model, adv_criterion,
|
||||||
logger, device, args)
|
logger, device, args)
|
||||||
epoch_loss = val_loss
|
epoch_loss = val_loss
|
||||||
|
|
||||||
@ -272,7 +274,7 @@ def gpu_worker(local_rank, node, args):
|
|||||||
dist.destroy_process_group()
|
dist.destroy_process_group()
|
||||||
|
|
||||||
|
|
||||||
def train(epoch, loader, model, criterion, optimizer, scheduler,
|
def train(epoch, loader, model, dis2den, criterion, optimizer, scheduler,
|
||||||
adv_model, adv_criterion, adv_optimizer, adv_scheduler,
|
adv_model, adv_criterion, adv_optimizer, adv_scheduler,
|
||||||
logger, device, args):
|
logger, device, args):
|
||||||
model.train()
|
model.train()
|
||||||
@ -307,6 +309,8 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
|
|||||||
input = resample(input, model.module.scale_factor, narrow=False)
|
input = resample(input, model.module.scale_factor, narrow=False)
|
||||||
input, output, target = narrow_cast(input, output, target)
|
input, output, target = narrow_cast(input, output, target)
|
||||||
|
|
||||||
|
output, target = dis2den(output, target)
|
||||||
|
|
||||||
loss = criterion(output, target)
|
loss = criterion(output, target)
|
||||||
epoch_loss[0] += loss.item()
|
epoch_loss[0] += loss.item()
|
||||||
|
|
||||||
@ -418,7 +422,7 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
|
|||||||
return epoch_loss
|
return epoch_loss
|
||||||
|
|
||||||
|
|
||||||
def validate(epoch, loader, model, criterion, adv_model, adv_criterion,
|
def validate(epoch, loader, model, dis2den, criterion, adv_model, adv_criterion,
|
||||||
logger, device, args):
|
logger, device, args):
|
||||||
model.eval()
|
model.eval()
|
||||||
if args.adv:
|
if args.adv:
|
||||||
@ -443,6 +447,8 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion,
|
|||||||
input = resample(input, model.module.scale_factor, narrow=False)
|
input = resample(input, model.module.scale_factor, narrow=False)
|
||||||
input, output, target = narrow_cast(input, output, target)
|
input, output, target = narrow_cast(input, output, target)
|
||||||
|
|
||||||
|
output, target = dis2den(output, target)
|
||||||
|
|
||||||
loss = criterion(output, target)
|
loss = criterion(output, target)
|
||||||
epoch_loss[0] += loss.item()
|
epoch_loss[0] += loss.item()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user