Add loglxe loss as log of product of Lagrangian and Eulerian losses
This commit is contained in:
parent
13edf3b96d
commit
9c8b331bf5
1 changed files with 33 additions and 64 deletions
|
@ -133,20 +133,13 @@ def gpu_worker(local_rank, node, args):
|
|||
criterion.to(device)
|
||||
|
||||
optimizer = import_attr(args.optimizer, optim.__name__, args.callback_at)
|
||||
lag_optimizer = optimizer(
|
||||
optimizer = optimizer(
|
||||
model.parameters(),
|
||||
lr=args.lr,
|
||||
**args.optimizer_args,
|
||||
)
|
||||
eul_optimizer = optimizer(
|
||||
model.parameters(),
|
||||
lr=args.lr,
|
||||
**args.optimizer_args,
|
||||
)
|
||||
lag_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
|
||||
lag_optimizer, **args.scheduler_args)
|
||||
eul_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
|
||||
eul_optimizer, **args.scheduler_args)
|
||||
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
|
||||
optimizer, **args.scheduler_args)
|
||||
|
||||
if (args.load_state == ckpt_link and not os.path.isfile(ckpt_link)
|
||||
or not args.load_state):
|
||||
|
@ -201,8 +194,7 @@ def gpu_worker(local_rank, node, args):
|
|||
train_sampler.set_epoch(epoch)
|
||||
|
||||
train_loss = train(epoch, train_loader, model, lag2eul, criterion,
|
||||
lag_optimizer, eul_optimizer, lag_scheduler, eul_scheduler,
|
||||
logger, device, args)
|
||||
optimizer, scheduler, logger, device, args)
|
||||
epoch_loss = train_loss
|
||||
|
||||
if args.val:
|
||||
|
@ -211,14 +203,13 @@ def gpu_worker(local_rank, node, args):
|
|||
#epoch_loss = val_loss
|
||||
|
||||
if args.reduce_lr_on_plateau:
|
||||
lag_scheduler.step(epoch_loss[0])
|
||||
eul_scheduler.step(epoch_loss[1])
|
||||
scheduler.step(epoch_loss[2])
|
||||
|
||||
if rank == 0:
|
||||
logger.flush()
|
||||
|
||||
if min_loss is None or torch.prod(epoch_loss) < torch.prod(min_loss):
|
||||
min_loss = epoch_loss
|
||||
if min_loss is None or epoch_loss[2] < min_loss:
|
||||
min_loss = epoch_loss[2]
|
||||
|
||||
state = {
|
||||
'epoch': epoch + 1,
|
||||
|
@ -239,14 +230,13 @@ def gpu_worker(local_rank, node, args):
|
|||
|
||||
|
||||
def train(epoch, loader, model, lag2eul, criterion,
|
||||
lag_optimizer, eul_optimizer, lag_scheduler, eul_scheduler,
|
||||
logger, device, args):
|
||||
optimizer, scheduler, logger, device, args):
|
||||
model.train()
|
||||
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
|
||||
epoch_loss = torch.zeros(2, dtype=torch.float64, device=device)
|
||||
epoch_loss = torch.zeros(3, dtype=torch.float64, device=device)
|
||||
|
||||
for i, (input, target) in enumerate(loader):
|
||||
input = input.to(device, non_blocking=True)
|
||||
|
@ -264,59 +254,38 @@ def train(epoch, loader, model, lag2eul, criterion,
|
|||
input, output, target = narrow_cast(input, output, target)
|
||||
|
||||
lag_out, lag_tgt = output, target
|
||||
eul_out, eul_tgt = lag2eul(lag_out, lag_tgt)
|
||||
|
||||
if i % 2 == 0:
|
||||
lag_loss = criterion(lag_out, lag_tgt)
|
||||
epoch_loss[0] += lag_loss.item()
|
||||
lag_loss = criterion(lag_out, lag_tgt)
|
||||
eul_loss = criterion(eul_out, eul_tgt)
|
||||
loss = lag_loss * eul_loss
|
||||
epoch_loss[0] += lag_loss.item()
|
||||
epoch_loss[1] += eul_loss.item()
|
||||
epoch_loss[2] += loss.item()
|
||||
|
||||
with torch.no_grad():
|
||||
eul_out, eul_tgt = lag2eul(lag_out, lag_tgt)
|
||||
|
||||
eul_loss = criterion(eul_out, eul_tgt)
|
||||
epoch_loss[1] += eul_loss.item()
|
||||
|
||||
lag_optimizer.zero_grad()
|
||||
lag_loss.backward()
|
||||
lag_optimizer.step()
|
||||
lag_grads = get_grads(model)
|
||||
else:
|
||||
with torch.no_grad():
|
||||
lag_loss = criterion(lag_out, lag_tgt)
|
||||
epoch_loss[0] += lag_loss.item()
|
||||
|
||||
eul_out, eul_tgt = lag2eul(lag_out, lag_tgt)
|
||||
|
||||
eul_loss = criterion(eul_out, eul_tgt)
|
||||
epoch_loss[1] += eul_loss.item()
|
||||
|
||||
eul_optimizer.zero_grad()
|
||||
eul_loss.backward()
|
||||
eul_optimizer.step()
|
||||
eul_grads = get_grads(model)
|
||||
optimizer.zero_grad()
|
||||
torch.log(loss).backward() # NOTE actual loss is log(loss)
|
||||
optimizer.step()
|
||||
grads = get_grads(model)
|
||||
|
||||
batch = epoch * len(loader) + i + 1
|
||||
if batch % args.log_interval == 0 and batch >= 2:
|
||||
if batch % args.log_interval == 0:
|
||||
dist.all_reduce(lag_loss)
|
||||
dist.all_reduce(eul_loss)
|
||||
dist.all_reduce(loss)
|
||||
lag_loss /= world_size
|
||||
eul_loss /= world_size
|
||||
loss /= world_size
|
||||
if rank == 0:
|
||||
logger.add_scalar('loss/batch/train/lag', lag_loss.item(),
|
||||
global_step=batch)
|
||||
logger.add_scalar('loss/batch/train/eul', eul_loss.item(),
|
||||
global_step=batch)
|
||||
logger.add_scalar('loss/batch/train/lxe',
|
||||
lag_loss.item() * eul_loss.item(),
|
||||
logger.add_scalar('loss/batch/train/lxe', loss.item(),
|
||||
global_step=batch)
|
||||
|
||||
logger.add_scalar('grad/lag/first', lag_grads[0],
|
||||
global_step=batch)
|
||||
logger.add_scalar('grad/lag/last', lag_grads[-1],
|
||||
global_step=batch)
|
||||
logger.add_scalar('grad/eul/first', eul_grads[0],
|
||||
global_step=batch)
|
||||
logger.add_scalar('grad/eul/last', eul_grads[-1],
|
||||
global_step=batch)
|
||||
logger.add_scalar('grad/first', grads[0], global_step=batch)
|
||||
logger.add_scalar('grad/last', grads[-1], global_step=batch)
|
||||
|
||||
dist.all_reduce(epoch_loss)
|
||||
epoch_loss /= len(loader) * world_size
|
||||
|
@ -325,7 +294,7 @@ def train(epoch, loader, model, lag2eul, criterion,
|
|||
global_step=epoch+1)
|
||||
logger.add_scalar('loss/epoch/train/eul', epoch_loss[1],
|
||||
global_step=epoch+1)
|
||||
logger.add_scalar('loss/epoch/train/lxe', epoch_loss.prod(),
|
||||
logger.add_scalar('loss/epoch/train/lxe', epoch_loss[2],
|
||||
global_step=epoch+1)
|
||||
|
||||
fig = plt_slices(
|
||||
|
@ -346,7 +315,7 @@ def validate(epoch, loader, model, lag2eul, criterion, logger, device, args):
|
|||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
|
||||
epoch_loss = torch.zeros(2, dtype=torch.float64, device=device)
|
||||
epoch_loss = torch.zeros(3, dtype=torch.float64, device=device)
|
||||
|
||||
with torch.no_grad():
|
||||
for input, target in loader:
|
||||
|
@ -361,14 +330,14 @@ def validate(epoch, loader, model, lag2eul, criterion, logger, device, args):
|
|||
input, output, target = narrow_cast(input, output, target)
|
||||
|
||||
lag_out, lag_tgt = output, target
|
||||
|
||||
lag_loss = criterion(lag_out, lag_tgt)
|
||||
epoch_loss[0] += lag_loss.item()
|
||||
|
||||
eul_out, eul_tgt = lag2eul(lag_out, lag_tgt)
|
||||
|
||||
lag_loss = criterion(lag_out, lag_tgt)
|
||||
eul_loss = criterion(eul_out, eul_tgt)
|
||||
loss = lag_loss * eul_loss
|
||||
epoch_loss[0] += lag_loss.item()
|
||||
epoch_loss[1] += eul_loss.item()
|
||||
epoch_loss[2] += loss.item()
|
||||
|
||||
dist.all_reduce(epoch_loss)
|
||||
epoch_loss /= len(loader) * world_size
|
||||
|
@ -377,7 +346,7 @@ def validate(epoch, loader, model, lag2eul, criterion, logger, device, args):
|
|||
global_step=epoch+1)
|
||||
logger.add_scalar('loss/epoch/val/eul', epoch_loss[1],
|
||||
global_step=epoch+1)
|
||||
logger.add_scalar('loss/epoch/val/lxe', epoch_loss.prod(),
|
||||
logger.add_scalar('loss/epoch/val/lxe', epoch_loss[2],
|
||||
global_step=epoch+1)
|
||||
|
||||
fig = plt_slices(
|
||||
|
|
Loading…
Reference in a new issue