Add Lag2Eul to training
This commit is contained in:
parent
9d2cd5383b
commit
bab3f08aa5
@ -17,7 +17,7 @@ from torch.utils.tensorboard import SummaryWriter
|
||||
from .data import FieldDataset, GroupedRandomSampler
|
||||
from .data.figures import fig3d
|
||||
from . import models
|
||||
from .models import (narrow_like,
|
||||
from .models import (narrow_like, Lag2Eul,
|
||||
adv_model_wrapper, adv_criterion_wrapper,
|
||||
add_spectral_norm, rm_spectral_norm,
|
||||
InstanceNoise)
|
||||
@ -139,6 +139,8 @@ def gpu_worker(local_rank, node, args):
|
||||
model = DistributedDataParallel(model, device_ids=[device],
|
||||
process_group=dist.new_group())
|
||||
|
||||
dis2den = Lag2Eul()
|
||||
|
||||
criterion = import_attr(args.criterion, nn.__name__, args.callback_at)
|
||||
criterion = criterion()
|
||||
criterion.to(device)
|
||||
@ -248,14 +250,14 @@ def gpu_worker(local_rank, node, args):
|
||||
train_sampler.set_epoch(epoch)
|
||||
|
||||
train_loss = train(epoch, train_loader,
|
||||
model, criterion, optimizer, scheduler,
|
||||
model, dis2den, criterion, optimizer, scheduler,
|
||||
adv_model, adv_criterion, adv_optimizer, adv_scheduler,
|
||||
logger, device, args)
|
||||
epoch_loss = train_loss
|
||||
|
||||
if args.val:
|
||||
val_loss = validate(epoch, val_loader,
|
||||
model, criterion, adv_model, adv_criterion,
|
||||
model, dis2den, criterion, adv_model, adv_criterion,
|
||||
logger, device, args)
|
||||
epoch_loss = val_loss
|
||||
|
||||
@ -300,7 +302,7 @@ def gpu_worker(local_rank, node, args):
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
def train(epoch, loader, model, criterion, optimizer, scheduler,
|
||||
def train(epoch, loader, model, dis2den, criterion, optimizer, scheduler,
|
||||
adv_model, adv_criterion, adv_optimizer, adv_scheduler,
|
||||
logger, device, args):
|
||||
model.train()
|
||||
@ -332,6 +334,8 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
|
||||
scale_factor=model.scale_factor, mode='nearest')
|
||||
input = narrow_like(input, output)
|
||||
|
||||
output, target = dis2den(output, target)
|
||||
|
||||
loss = criterion(output, target)
|
||||
epoch_loss[0] += loss.item()
|
||||
|
||||
@ -447,7 +451,7 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
|
||||
return epoch_loss
|
||||
|
||||
|
||||
def validate(epoch, loader, model, criterion, adv_model, adv_criterion,
|
||||
def validate(epoch, loader, model, dis2den, criterion, adv_model, adv_criterion,
|
||||
logger, device, args):
|
||||
model.eval()
|
||||
if args.adv:
|
||||
@ -473,6 +477,8 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion,
|
||||
scale_factor=model.scale_factor, mode='nearest')
|
||||
input = narrow_like(input, output)
|
||||
|
||||
output, target = dis2den(output, target)
|
||||
|
||||
loss = criterion(output, target)
|
||||
epoch_loss[0] += loss.item()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user