Add Lagrangian and Eulerian alternate training
This commit is contained in:
parent
337d65de68
commit
2d5234812b
138
map2map/train.py
138
map2map/train.py
@ -118,21 +118,29 @@ def gpu_worker(local_rank, node, args):
|
|||||||
model = DistributedDataParallel(model, device_ids=[device],
|
model = DistributedDataParallel(model, device_ids=[device],
|
||||||
process_group=dist.new_group())
|
process_group=dist.new_group())
|
||||||
|
|
||||||
dis2den = Lag2Eul()
|
lag2eul = Lag2Eul()
|
||||||
|
|
||||||
criterion = import_attr(args.criterion, nn.__name__, args.callback_at)
|
criterion = import_attr(args.criterion, nn.__name__, args.callback_at)
|
||||||
criterion = criterion()
|
criterion = criterion()
|
||||||
criterion.to(device)
|
criterion.to(device)
|
||||||
|
|
||||||
optimizer = import_attr(args.optimizer, optim.__name__, args.callback_at)
|
optimizer = import_attr(args.optimizer, optim.__name__, args.callback_at)
|
||||||
optimizer = optimizer(
|
lag_optimizer = optimizer(
|
||||||
model.parameters(),
|
model.parameters(),
|
||||||
lr=args.lr,
|
lr=args.lr,
|
||||||
#momentum=args.momentum,
|
#momentum=args.momentum,
|
||||||
betas=(0.5, 0.999),
|
betas=(0.9, 0.999),
|
||||||
weight_decay=args.weight_decay,
|
weight_decay=args.weight_decay,
|
||||||
)
|
)
|
||||||
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
|
eul_optimizer = optimizer(
|
||||||
|
model.parameters(),
|
||||||
|
lr=args.lr,
|
||||||
|
betas=(0.9, 0.999),
|
||||||
|
weight_decay=args.weight_decay,
|
||||||
|
)
|
||||||
|
lag_scheduler = optim.lr_scheduler.ReduceLROnPlateau(lag_optimizer,
|
||||||
|
factor=0.1, patience=10, verbose=True)
|
||||||
|
eul_scheduler = optim.lr_scheduler.ReduceLROnPlateau(eul_optimizer,
|
||||||
factor=0.1, patience=10, verbose=True)
|
factor=0.1, patience=10, verbose=True)
|
||||||
|
|
||||||
if (args.load_state == ckpt_link and not os.path.isfile(ckpt_link)
|
if (args.load_state == ckpt_link and not os.path.isfile(ckpt_link)
|
||||||
@ -187,23 +195,24 @@ def gpu_worker(local_rank, node, args):
|
|||||||
for epoch in range(start_epoch, args.epochs):
|
for epoch in range(start_epoch, args.epochs):
|
||||||
train_sampler.set_epoch(epoch)
|
train_sampler.set_epoch(epoch)
|
||||||
|
|
||||||
train_loss = train(epoch, train_loader,
|
train_loss = train(epoch, train_loader, model, lag2eul, criterion,
|
||||||
model, dis2den, criterion, optimizer, scheduler,
|
lag_optimizer, eul_optimizer, lag_scheduler, eul_scheduler,
|
||||||
logger, device, args)
|
logger, device, args)
|
||||||
epoch_loss = train_loss
|
epoch_loss = train_loss
|
||||||
|
|
||||||
if args.val:
|
if args.val:
|
||||||
val_loss = validate(epoch, val_loader, model, dis2den, criterion,
|
val_loss = validate(epoch, val_loader, model, lag2eul, criterion,
|
||||||
logger, device, args)
|
logger, device, args)
|
||||||
epoch_loss = val_loss
|
epoch_loss = val_loss
|
||||||
|
|
||||||
if args.reduce_lr_on_plateau:
|
if args.reduce_lr_on_plateau:
|
||||||
scheduler.step(epoch_loss[0])
|
lag_scheduler.step(epoch_loss[0])
|
||||||
|
eul_scheduler.step(epoch_loss[1])
|
||||||
|
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
logger.flush()
|
logger.flush()
|
||||||
|
|
||||||
if min_loss is None or epoch_loss[0] < min_loss[0]:
|
if min_loss is None or torch.prod(epoch_loss) < torch.prod(min_loss):
|
||||||
min_loss = epoch_loss
|
min_loss = epoch_loss
|
||||||
|
|
||||||
state = {
|
state = {
|
||||||
@ -224,14 +233,15 @@ def gpu_worker(local_rank, node, args):
|
|||||||
dist.destroy_process_group()
|
dist.destroy_process_group()
|
||||||
|
|
||||||
|
|
||||||
def train(epoch, loader, model, dis2den, criterion, optimizer, scheduler,
|
def train(epoch, loader, model, lag2eul, criterion,
|
||||||
|
lag_optimizer, eul_optimizer, lag_scheduler, eul_scheduler,
|
||||||
logger, device, args):
|
logger, device, args):
|
||||||
model.train()
|
model.train()
|
||||||
|
|
||||||
rank = dist.get_rank()
|
rank = dist.get_rank()
|
||||||
world_size = dist.get_world_size()
|
world_size = dist.get_world_size()
|
||||||
|
|
||||||
epoch_loss = torch.zeros(5, dtype=torch.float64, device=device)
|
epoch_loss = torch.zeros(2, dtype=torch.float64, device=device)
|
||||||
|
|
||||||
for i, (input, target) in enumerate(loader):
|
for i, (input, target) in enumerate(loader):
|
||||||
input = input.to(device, non_blocking=True)
|
input = input.to(device, non_blocking=True)
|
||||||
@ -248,52 +258,83 @@ def train(epoch, loader, model, dis2den, criterion, optimizer, scheduler,
|
|||||||
input = resample(input, model.module.scale_factor, narrow=False)
|
input = resample(input, model.module.scale_factor, narrow=False)
|
||||||
input, output, target = narrow_cast(input, output, target)
|
input, output, target = narrow_cast(input, output, target)
|
||||||
|
|
||||||
output, target = dis2den(output, target)
|
lag_out, lag_tgt = output, target
|
||||||
|
|
||||||
loss = criterion(output, target)
|
if i % 2 == 0:
|
||||||
epoch_loss[0] += loss.item()
|
lag_loss = criterion(lag_out, lag_tgt)
|
||||||
|
epoch_loss[0] += lag_loss.item()
|
||||||
|
|
||||||
optimizer.zero_grad()
|
with torch.no_grad():
|
||||||
loss.backward()
|
eul_out, eul_tgt = lag2eul(lag_out, lag_tgt)
|
||||||
optimizer.step()
|
|
||||||
|
eul_loss = criterion(eul_out, eul_tgt)
|
||||||
|
epoch_loss[1] += eul_loss.item()
|
||||||
|
|
||||||
|
lag_optimizer.zero_grad()
|
||||||
|
lag_loss.backward()
|
||||||
|
lag_optimizer.step()
|
||||||
|
lag_grads = get_grads(model)
|
||||||
|
else:
|
||||||
|
with torch.no_grad():
|
||||||
|
lag_loss = criterion(lag_out, lag_tgt)
|
||||||
|
epoch_loss[0] += lag_loss.item()
|
||||||
|
|
||||||
|
eul_out, eul_tgt = lag2eul(lag_out, lag_tgt)
|
||||||
|
|
||||||
|
eul_loss = criterion(eul_out, eul_tgt)
|
||||||
|
epoch_loss[1] += eul_loss.item()
|
||||||
|
|
||||||
|
eul_optimizer.zero_grad()
|
||||||
|
eul_loss.backward()
|
||||||
|
eul_optimizer.step()
|
||||||
|
eul_grads = get_grads(model)
|
||||||
|
|
||||||
batch = epoch * len(loader) + i + 1
|
batch = epoch * len(loader) + i + 1
|
||||||
if batch % args.log_interval == 0:
|
if batch % args.log_interval == 0 and batch >= 2:
|
||||||
dist.all_reduce(loss)
|
dist.all_reduce(lag_loss)
|
||||||
loss /= world_size
|
dist.all_reduce(eul_loss)
|
||||||
|
lag_loss /= world_size
|
||||||
|
eul_loss /= world_size
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
logger.add_scalar('loss/batch/train', loss.item(),
|
logger.add_scalar('loss/batch/train/lag', lag_loss.item(),
|
||||||
|
global_step=batch)
|
||||||
|
logger.add_scalar('loss/batch/train/eul', eul_loss.item(),
|
||||||
global_step=batch)
|
global_step=batch)
|
||||||
|
|
||||||
# gradients of the weights of the first and the last layer
|
logger.add_scalar('grad/lag/first', lag_grads[0],
|
||||||
grads = list(p.grad for n, p in model.named_parameters()
|
global_step=batch)
|
||||||
if '.weight' in n)
|
logger.add_scalar('grad/lag/last', lag_grads[-1],
|
||||||
grads = [grads[0], grads[-1]]
|
global_step=batch)
|
||||||
grads = [g.detach().norm().item() for g in grads]
|
logger.add_scalar('grad/eul/first', eul_grads[0],
|
||||||
logger.add_scalar('grad/first', grads[0], global_step=batch)
|
global_step=batch)
|
||||||
logger.add_scalar('grad/last', grads[-1], global_step=batch)
|
logger.add_scalar('grad/eul/last', eul_grads[-1],
|
||||||
|
global_step=batch)
|
||||||
|
|
||||||
dist.all_reduce(epoch_loss)
|
dist.all_reduce(epoch_loss)
|
||||||
epoch_loss /= len(loader) * world_size
|
epoch_loss /= len(loader) * world_size
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
logger.add_scalar('loss/epoch/train', epoch_loss[0],
|
logger.add_scalar('loss/epoch/train/lag', epoch_loss[0],
|
||||||
|
global_step=epoch+1)
|
||||||
|
logger.add_scalar('loss/epoch/train/eul', epoch_loss[1],
|
||||||
global_step=epoch+1)
|
global_step=epoch+1)
|
||||||
|
|
||||||
logger.add_figure('fig/epoch/train', plt_slices(
|
logger.add_figure('fig/epoch/train', plt_slices(
|
||||||
input[-1], output[-1], target[-1], output[-1] - target[-1],
|
input[-1], lag_out[-1], lag_tgt[-1], lag_out[-1] - lag_tgt[-1],
|
||||||
title=['in', 'out', 'tgt', 'out - tgt'],
|
eul_out[-1], eul_tgt[-1], eul_out[-1] - eul_tgt[-1],
|
||||||
|
title=['in', 'lag_out', 'lag_tgt', 'lag_out - lag_tgt',
|
||||||
|
'eul_out', 'eul_tgt', 'eul_out - eul_tgt'],
|
||||||
), global_step=epoch+1)
|
), global_step=epoch+1)
|
||||||
|
|
||||||
return epoch_loss
|
return epoch_loss
|
||||||
|
|
||||||
|
|
||||||
def validate(epoch, loader, model, dis2den, criterion, logger, device, args):
|
def validate(epoch, loader, model, lag2eul, criterion, logger, device, args):
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
rank = dist.get_rank()
|
rank = dist.get_rank()
|
||||||
world_size = dist.get_world_size()
|
world_size = dist.get_world_size()
|
||||||
|
|
||||||
epoch_loss = torch.zeros(5, dtype=torch.float64, device=device)
|
epoch_loss = torch.zeros(2, dtype=torch.float64, device=device)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for input, target in loader:
|
for input, target in loader:
|
||||||
@ -307,20 +348,29 @@ def validate(epoch, loader, model, dis2den, criterion, logger, device, args):
|
|||||||
input = resample(input, model.module.scale_factor, narrow=False)
|
input = resample(input, model.module.scale_factor, narrow=False)
|
||||||
input, output, target = narrow_cast(input, output, target)
|
input, output, target = narrow_cast(input, output, target)
|
||||||
|
|
||||||
output, target = dis2den(output, target)
|
lag_out, lag_tgt = output, target
|
||||||
|
|
||||||
loss = criterion(output, target)
|
lag_loss = criterion(lag_out, lag_tgt)
|
||||||
epoch_loss[0] += loss.item()
|
epoch_loss[0] += lag_loss.item()
|
||||||
|
|
||||||
|
eul_out, eul_tgt = lag2eul(lag_out, lag_tgt)
|
||||||
|
|
||||||
|
eul_loss = criterion(eul_out, eul_tgt)
|
||||||
|
epoch_loss[1] += eul_loss.item()
|
||||||
|
|
||||||
dist.all_reduce(epoch_loss)
|
dist.all_reduce(epoch_loss)
|
||||||
epoch_loss /= len(loader) * world_size
|
epoch_loss /= len(loader) * world_size
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
logger.add_scalar('loss/epoch/val', epoch_loss[0],
|
logger.add_scalar('loss/epoch/val/lag', epoch_loss[0],
|
||||||
|
global_step=epoch+1)
|
||||||
|
logger.add_scalar('loss/epoch/val/eul', epoch_loss[1],
|
||||||
global_step=epoch+1)
|
global_step=epoch+1)
|
||||||
|
|
||||||
logger.add_figure('fig/epoch/val', plt_slices(
|
logger.add_figure('fig/epoch/val', plt_slices(
|
||||||
input[-1], output[-1], target[-1], output[-1] - target[-1],
|
input[-1], lag_out[-1], lag_tgt[-1], lag_out[-1] - lag_tgt[-1],
|
||||||
title=['in', 'out', 'tgt', 'out - tgt'],
|
eul_out[-1], eul_tgt[-1], eul_out[-1] - eul_tgt[-1],
|
||||||
|
title=['in', 'lag_out', 'lag_tgt', 'lag_out - lag_tgt',
|
||||||
|
'eul_out', 'eul_tgt', 'eul_out - eul_tgt'],
|
||||||
), global_step=epoch+1)
|
), global_step=epoch+1)
|
||||||
|
|
||||||
return epoch_loss
|
return epoch_loss
|
||||||
@ -363,3 +413,13 @@ def dist_init(rank, args):
|
|||||||
def set_requires_grad(module, requires_grad=False):
    """Set the ``requires_grad`` flag on every parameter of ``module``.

    Args:
        module: object whose ``parameters()`` method yields the parameters
            to update.
        requires_grad: value assigned to each parameter's ``requires_grad``
            attribute; defaults to ``False``.
    """
    for weight in module.parameters():
        weight.requires_grad = requires_grad
|
||||||
|
|
||||||
|
|
||||||
|
def get_grads(model):
    """Return the gradient norms of the first and the last layer weights.

    Weight parameters are those whose name, as yielded by
    ``model.named_parameters()``, contains ``'.weight'``.

    Args:
        model: object whose ``named_parameters()`` method yields
            ``(name, parameter)`` pairs.

    Returns:
        A two-element list ``[first, last]`` of Python floats holding the
        norm of the gradient of the first and of the last weight parameter.
        A parameter whose ``grad`` is ``None`` contributes ``0.0``.

    Raises:
        IndexError: if no parameter name contains ``'.weight'``.
    """
    weights = [p for n, p in model.named_parameters() if '.weight' in n]
    ends = (weights[0], weights[-1])

    # ``grad`` is None until a backward pass has reached the parameter
    # (e.g. frozen layers, or a call before the first backward()); report
    # a zero norm instead of failing on ``None.detach()``.
    return [
        0.0 if p.grad is None else p.grad.detach().norm().item()
        for p in ends
    ]
|
||||||
|
Loading…
Reference in New Issue
Block a user