map2map/map2map/train.py

425 lines
14 KiB
Python

import os
import socket
import time
import sys
from pprint import pprint
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributed as dist
from torch.multiprocessing import spawn
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from .data import FieldDataset
from .data.figures import plt_slices
from . import models
from .models import narrow_cast, resample, Lag2Eul
from .utils import import_attr, load_model_state_dict
# Name of the symlink that gpu_worker re-points at the newest
# 'state_{epoch}.pt' file after every epoch.  When args.load_state equals this
# name and the link does not exist yet, training starts from scratch instead
# of failing to load.
ckpt_link = 'checkpoint.pt'
def node_worker(args):
    """Launch one ``gpu_worker`` process per GPU visible on this node.

    The node count is read from the SLURM environment (the job step count is
    preferred over the job count) and stored, together with the number of
    GPUs per node and the resulting world size, as attributes of ``args``.

    Raises ``KeyError`` when neither node count variable (or
    ``SLURM_NODEID``) is set, and ``RuntimeError`` when no GPU is found.
    """
    for env_var in ('SLURM_STEP_NUM_NODES', 'SLURM_JOB_NUM_NODES'):
        if env_var in os.environ:
            args.nodes = int(os.environ[env_var])
            break
    else:
        raise KeyError('missing node counts in slurm env')

    args.gpus_per_node = torch.cuda.device_count()
    args.world_size = args.nodes * args.gpus_per_node

    node = int(os.environ['SLURM_NODEID'])

    if args.gpus_per_node < 1:
        raise RuntimeError('GPU not found on node {}'.format(node))

    # spawn passes the local process index as the first argument of gpu_worker
    spawn(gpu_worker, args=(node, args), nprocs=args.gpus_per_node)
def gpu_worker(local_rank, node, args):
    """Train the model on one GPU as one member of the process group.

    Builds the datasets and loaders, the DistributedDataParallel model, the
    criterion and two optimizers (one stepped on each of the two losses used
    by ``train``), optionally restores a saved state, and then loops over the
    epochs.  After every epoch rank 0 saves 'state_{epoch}.pt' and re-points
    the ``ckpt_link`` symlink at it.

    local_rank -- index of the GPU on this node (first argument from spawn)
    node -- index of this node
    args -- namespace of settings; must already carry ``gpus_per_node`` and
        ``world_size``; ``in_chan``, ``out_chan`` and ``dist_addr`` are added
    """
    device = torch.device('cuda', local_rank)
    # Select the GPU for this process.  torch.cuda.device(device) would only
    # construct a context manager and select nothing outside a `with` block.
    torch.cuda.set_device(device)

    rank = args.gpus_per_node * node + local_rank

    # Need randomness across processes, for sampler, augmentation, noise etc.
    # Note DDP broadcasts initial model states from rank 0
    torch.manual_seed(args.seed + rank)
    #torch.backends.cudnn.deterministic = True  # NOTE: test perf

    dist_init(rank, args)

    train_dataset = FieldDataset(
        in_patterns=args.train_in_patterns,
        tgt_patterns=args.train_tgt_patterns,
        in_norms=args.in_norms,
        tgt_norms=args.tgt_norms,
        callback_at=args.callback_at,
        augment=args.augment,
        aug_shift=args.aug_shift,
        aug_add=args.aug_add,
        aug_mul=args.aug_mul,
        crop=args.crop,
        crop_start=args.crop_start,
        crop_stop=args.crop_stop,
        crop_step=args.crop_step,
        pad=args.pad,
        scale_factor=args.scale_factor,
    )
    # Shuffling is delegated to the sampler, hence shuffle=False in the loader
    train_sampler = DistributedSampler(train_dataset, shuffle=True)
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batches,
        shuffle=False,
        sampler=train_sampler,
        num_workers=args.loader_workers,
        pin_memory=True,
    )

    if args.val:
        # Same geometry as the training set but without any augmentation
        val_dataset = FieldDataset(
            in_patterns=args.val_in_patterns,
            tgt_patterns=args.val_tgt_patterns,
            in_norms=args.in_norms,
            tgt_norms=args.tgt_norms,
            callback_at=args.callback_at,
            augment=False,
            aug_shift=None,
            aug_add=None,
            aug_mul=None,
            crop=args.crop,
            crop_start=args.crop_start,
            crop_stop=args.crop_stop,
            crop_step=args.crop_step,
            pad=args.pad,
            scale_factor=args.scale_factor,
        )
        val_sampler = DistributedSampler(val_dataset, shuffle=False)
        val_loader = DataLoader(
            val_dataset,
            batch_size=args.batches,
            shuffle=False,
            sampler=val_sampler,
            num_workers=args.loader_workers,
            pin_memory=True,
        )

    args.in_chan, args.out_chan = train_dataset.in_chan, train_dataset.tgt_chan

    model = import_attr(args.model, models.__name__, args.callback_at)
    model = model(sum(args.in_chan), sum(args.out_chan))
    model.to(device)
    model = DistributedDataParallel(model, device_ids=[device],
                                    process_group=dist.new_group())

    lag2eul = Lag2Eul()

    criterion = import_attr(args.criterion, nn.__name__, args.callback_at)
    criterion = criterion()
    criterion.to(device)

    # Both optimizers update the same model parameters; train() steps one or
    # the other depending on which loss is backpropagated.  `betas` is passed
    # unconditionally, so args.optimizer must name an optimizer accepting it.
    optimizer = import_attr(args.optimizer, optim.__name__, args.callback_at)
    lag_optimizer = optimizer(
        model.parameters(),
        lr=args.lr,
        #momentum=args.momentum,
        betas=(0.9, 0.999),
        weight_decay=args.weight_decay,
    )
    eul_optimizer = optimizer(
        model.parameters(),
        lr=args.lr,
        betas=(0.9, 0.999),
        weight_decay=args.weight_decay,
    )
    lag_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        lag_optimizer, factor=0.1, patience=10, verbose=True)
    eul_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        eul_optimizer, factor=0.1, patience=10, verbose=True)

    # Start from scratch when no state is requested, or when the default
    # checkpoint link is requested but has not been created yet
    if (args.load_state == ckpt_link and not os.path.isfile(ckpt_link)
            or not args.load_state):
        def init_weights(m):
            if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d,
                    nn.ConvTranspose1d, nn.ConvTranspose2d,
                    nn.ConvTranspose3d)):
                m.weight.data.normal_(0.0, args.init_weight_std)
            elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
                    nn.SyncBatchNorm, nn.LayerNorm, nn.GroupNorm,
                    nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)):
                # LayerNorm calls its flag `elementwise_affine`; reading
                # `m.affine` on it would raise AttributeError
                affine = getattr(m, 'affine',
                                 getattr(m, 'elementwise_affine', False))
                if affine:
                    # NOTE: dispersion from DCGAN, why?
                    m.weight.data.normal_(1.0, args.init_weight_std)
                    if m.bias is not None:
                        m.bias.data.fill_(0)

        if args.init_weight_std is not None:
            model.apply(init_weights)

        start_epoch = 0

        # min_loss is tracked on rank 0 only, the rank that writes checkpoints
        if rank == 0:
            min_loss = None
    else:
        state = torch.load(args.load_state, map_location=device)

        start_epoch = state['epoch']

        load_model_state_dict(model.module, state['model'],
                              strict=args.load_state_strict)

        torch.set_rng_state(state['rng'].cpu())  # move rng state back

        if rank == 0:
            min_loss = state['min_loss']

            print('state at epoch {} loaded from {}'.format(
                state['epoch'], args.load_state), flush=True)

        del state

    torch.backends.cudnn.benchmark = True  # NOTE: test perf

    logger = None
    if rank == 0:
        logger = SummaryWriter()

    if rank == 0:
        pprint(vars(args))
        sys.stdout.flush()

    for epoch in range(start_epoch, args.epochs):
        # Reseed the sampler so that every epoch sees a different shuffle
        train_sampler.set_epoch(epoch)

        train_loss = train(epoch, train_loader, model, lag2eul, criterion,
                           lag_optimizer, eul_optimizer,
                           lag_scheduler, eul_scheduler,
                           logger, device, args)
        epoch_loss = train_loss

        # The validation loss, when available, drives the schedulers and the
        # min_loss bookkeeping instead of the training loss
        if args.val:
            val_loss = validate(epoch, val_loader, model, lag2eul, criterion,
                                logger, device, args)
            epoch_loss = val_loss

        if args.reduce_lr_on_plateau:
            lag_scheduler.step(epoch_loss[0])
            eul_scheduler.step(epoch_loss[1])

        if rank == 0:
            logger.flush()

            # The two losses are compared through their product
            if (min_loss is None
                    or torch.prod(epoch_loss) < torch.prod(min_loss)):
                min_loss = epoch_loss

            state = {
                'epoch': epoch + 1,
                'model': model.module.state_dict(),
                'rng': torch.get_rng_state(),
                'min_loss': min_loss,
            }

            state_file = 'state_{}.pt'.format(epoch + 1)
            torch.save(state, state_file)
            del state

            tmp_link = '{}.pt'.format(time.time())
            os.symlink(state_file, tmp_link)  # workaround to overwrite
            os.rename(tmp_link, ckpt_link)

    dist.destroy_process_group()
def train(epoch, loader, model, lag2eul, criterion,
          lag_optimizer, eul_optimizer, lag_scheduler, eul_scheduler,
          logger, device, args):
    """Run one training epoch, alternating the loss that is optimized.

    On even batch indices the criterion applied to the model output and
    target directly (the "lag" loss) is backpropagated and ``lag_optimizer``
    is stepped; on odd batch indices the criterion applied to the result of
    ``lag2eul`` (the "eul" loss) is backpropagated and ``eul_optimizer`` is
    stepped.  The loss that is not optimized is still evaluated, without
    gradients, so that both are accumulated for every batch.

    ``lag_scheduler`` and ``eul_scheduler`` are accepted but not used here.

    Returns a 2-element float64 tensor on ``device`` holding the (lag, eul)
    losses averaged over the batches and over all ranks.
    """
    model.train()

    rank = dist.get_rank()
    world_size = dist.get_world_size()

    epoch_loss = torch.zeros(2, dtype=torch.float64, device=device)

    # Each gradient norm exists only after its own branch has run during this
    # call.  Without these defaults a log step that falls on the first batch
    # of any epoch after the first reads eul_grads before it is assigned.
    lag_grads = eul_grads = None

    for i, (input, target) in enumerate(loader):
        input = input.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        output = model(input)
        if epoch == 0 and i == 0 and rank == 0:
            print('input.shape =', input.shape)
            print('output.shape =', output.shape)
            print('target.shape =', target.shape, flush=True)

        if (hasattr(model.module, 'scale_factor')
                and model.module.scale_factor != 1):
            input = resample(input, model.module.scale_factor, narrow=False)
        input, output, target = narrow_cast(input, output, target)

        lag_out, lag_tgt = output, target

        if i % 2 == 0:
            # Optimize the lag loss; the eul loss is only monitored
            lag_loss = criterion(lag_out, lag_tgt)
            epoch_loss[0] += lag_loss.item()

            with torch.no_grad():
                eul_out, eul_tgt = lag2eul(lag_out, lag_tgt)
                eul_loss = criterion(eul_out, eul_tgt)
                epoch_loss[1] += eul_loss.item()

            lag_optimizer.zero_grad()
            lag_loss.backward()
            lag_optimizer.step()
            lag_grads = get_grads(model)
        else:
            # Optimize the eul loss; the lag loss is only monitored
            with torch.no_grad():
                lag_loss = criterion(lag_out, lag_tgt)
                epoch_loss[0] += lag_loss.item()

            eul_out, eul_tgt = lag2eul(lag_out, lag_tgt)
            eul_loss = criterion(eul_out, eul_tgt)
            epoch_loss[1] += eul_loss.item()

            eul_optimizer.zero_grad()
            eul_loss.backward()
            eul_optimizer.step()
            eul_grads = get_grads(model)

        # Global 1-based batch counter across epochs
        batch = epoch * len(loader) + i + 1
        if batch % args.log_interval == 0 and batch >= 2:
            # Every rank must take part in the reduction, not just rank 0
            dist.all_reduce(lag_loss)
            dist.all_reduce(eul_loss)
            lag_loss /= world_size
            eul_loss /= world_size
            if rank == 0:
                logger.add_scalar('loss/batch/train/lag', lag_loss.item(),
                                  global_step=batch)
                logger.add_scalar('loss/batch/train/eul', eul_loss.item(),
                                  global_step=batch)

                if lag_grads is not None:
                    logger.add_scalar('grad/lag/first', lag_grads[0],
                                      global_step=batch)
                    logger.add_scalar('grad/lag/last', lag_grads[-1],
                                      global_step=batch)
                if eul_grads is not None:
                    logger.add_scalar('grad/eul/first', eul_grads[0],
                                      global_step=batch)
                    logger.add_scalar('grad/eul/last', eul_grads[-1],
                                      global_step=batch)

    dist.all_reduce(epoch_loss)
    epoch_loss /= len(loader) * world_size
    if rank == 0:
        logger.add_scalar('loss/epoch/train/lag', epoch_loss[0],
                          global_step=epoch+1)
        logger.add_scalar('loss/epoch/train/eul', epoch_loss[1],
                          global_step=epoch+1)

        # Figure of the last sample of the last batch processed
        logger.add_figure('fig/epoch/train', plt_slices(
            input[-1], lag_out[-1], lag_tgt[-1], lag_out[-1] - lag_tgt[-1],
            eul_out[-1], eul_tgt[-1], eul_out[-1] - eul_tgt[-1],
            title=['in', 'lag_out', 'lag_tgt', 'lag_out - lag_tgt',
                   'eul_out', 'eul_tgt', 'eul_out - eul_tgt'],
        ), global_step=epoch+1)

    return epoch_loss
def validate(epoch, loader, model, lag2eul, criterion, logger, device, args):
    """Evaluate the model over ``loader`` with gradients disabled.

    For every batch the criterion is applied both to the model output and
    target directly (the "lag" loss) and to the result of ``lag2eul`` (the
    "eul" loss).  Rank 0 logs the two epoch losses and a figure built from
    the last sample of the last batch.

    Returns a 2-element float64 tensor on ``device`` holding the (lag, eul)
    losses averaged over the batches and over all ranks.
    """
    model.eval()

    world_size = dist.get_world_size()
    rank = dist.get_rank()
    step = epoch + 1

    loss_sum = torch.zeros(2, dtype=torch.float64, device=device)

    with torch.no_grad():
        for inp, tgt in loader:
            inp = inp.to(device, non_blocking=True)
            tgt = tgt.to(device, non_blocking=True)

            out = model(inp)

            # Bring the input to the output resolution when the model rescales
            if getattr(model.module, 'scale_factor', 1) != 1:
                inp = resample(inp, model.module.scale_factor, narrow=False)
            inp, out, tgt = narrow_cast(inp, out, tgt)

            lag_out, lag_tgt = out, tgt
            loss_sum[0] += criterion(lag_out, lag_tgt).item()

            eul_out, eul_tgt = lag2eul(lag_out, lag_tgt)
            loss_sum[1] += criterion(eul_out, eul_tgt).item()

    # Sum over ranks, then turn the sum into a mean over batches and ranks
    dist.all_reduce(loss_sum)
    loss_sum /= len(loader) * world_size

    if rank != 0:
        return loss_sum

    logger.add_scalar('loss/epoch/val/lag', loss_sum[0], global_step=step)
    logger.add_scalar('loss/epoch/val/eul', loss_sum[1], global_step=step)

    panels = [
        inp[-1], lag_out[-1], lag_tgt[-1], lag_out[-1] - lag_tgt[-1],
        eul_out[-1], eul_tgt[-1], eul_out[-1] - eul_tgt[-1],
    ]
    titles = ['in', 'lag_out', 'lag_tgt', 'lag_out - lag_tgt',
              'eul_out', 'eul_tgt', 'eul_out - eul_tgt']
    logger.add_figure('fig/epoch/val', plt_slices(*panels, title=titles),
                      global_step=step)

    return loss_sum
def dist_init(rank, args):
    """Initialize the process group, using a file to share the address.

    Rank 0 picks a free TCP port on its host, stores the resulting
    'tcp://host:port' address in ``args.dist_addr`` and publishes it in the
    file 'dist_addr' in the working directory.  The other ranks wait for
    that file and read the address from it.  After all ranks have joined the
    group, rank 0 removes the file.

    rank -- global rank of the calling process
    args -- namespace providing ``dist_backend`` and ``world_size``;
        ``dist_addr`` is set on it
    """
    dist_file = 'dist_addr'

    if rank == 0:
        addr = socket.gethostname()

        # Binding to port 0 lets the OS choose a free port; the socket is
        # closed again so that the process group can listen on that port
        with socket.socket() as s:
            s.bind((addr, 0))
            _, port = s.getsockname()

        args.dist_addr = 'tcp://{}:{}'.format(addr, port)

        # Write to a scratch file and rename it into place, so that the other
        # ranks can never observe the file created but not yet written
        tmp_file = '{}.{}.tmp'.format(dist_file, os.getpid())
        with open(tmp_file, mode='w') as f:
            f.write(args.dist_addr)
        os.replace(tmp_file, dist_file)
    else:
        # Poll until the address is available; an empty read is treated like
        # a missing file and retried
        while True:
            try:
                with open(dist_file, mode='r') as f:
                    args.dist_addr = f.read()
            except FileNotFoundError:
                args.dist_addr = ''

            if args.dist_addr:
                break

            time.sleep(1)

    dist.init_process_group(
        backend=args.dist_backend,
        init_method=args.dist_addr,
        world_size=args.world_size,
        rank=rank,
    )
    # Nobody may still need the file once every rank has passed this point
    dist.barrier()

    # NOTE(review): a file left behind by a run that died before this point
    # would be picked up by the non-zero ranks of the next run
    if rank == 0:
        os.remove(dist_file)
def set_requires_grad(module, requires_grad=False):
    """set the requires_grad flag of every parameter of the module
    """
    for weight in module.parameters():
        weight.requires_grad = requires_grad
def get_grads(model):
    """gradient norms of the weights of the first and the last layer

    Only parameters whose name contains '.weight' are considered, in the
    order given by ``model.named_parameters()``.  Returns a list of two
    floats: the norms of the gradients of the first and of the last such
    parameter.  A parameter that holds no gradient (``grad`` is None, e.g.
    because it is frozen or was not used) is reported as 0.0.

    Raises ``IndexError`` when no parameter name contains '.weight'.
    """
    weights = [p for n, p in model.named_parameters() if '.weight' in n]
    if not weights:
        raise IndexError('no parameter with ".weight" in its name')

    return [0.0 if p.grad is None else p.grad.detach().norm().item()
            for p in (weights[0], weights[-1])]