# map2map/map2map/train.py
import os
import socket
import time
import sys
from pprint import pprint
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributed as dist
from torch.multiprocessing import spawn
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from .data import FieldDataset, DistFieldSampler
from . import models
from .models import narrow_cast, resample, lag2eul
from .utils import import_attr, load_model_state_dict, plt_slices, plt_power

ckpt_link = 'checkpoint.pt'


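# node_worker() below is the per-node entry point: it reads node counts from
# the SLURM environment and spawns one gpu_worker process per GPU.
# A hypothetical launch, one task per node (exact script and flags depend on
# the cluster and on how the repo's CLI is wired up):
#
#   srun --nodes=2 --ntasks-per-node=1 python m2m.py train ...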
def node_worker(args):
    """Spawn one gpu_worker process per GPU on this node (SLURM entry point)."""
if 'SLURM_STEP_NUM_NODES' in os.environ:
args.nodes = int(os.environ['SLURM_STEP_NUM_NODES'])
elif 'SLURM_JOB_NUM_NODES' in os.environ:
args.nodes = int(os.environ['SLURM_JOB_NUM_NODES'])
else:
        raise KeyError('missing node count in SLURM env')
args.gpus_per_node = torch.cuda.device_count()
args.world_size = args.nodes * args.gpus_per_node
node = int(os.environ['SLURM_NODEID'])
if args.gpus_per_node < 1:
raise RuntimeError('GPU not found on node {}'.format(node))
spawn(gpu_worker, args=(node, args), nprocs=args.gpus_per_node)


def gpu_worker(local_rank, node, args):
    """Per-GPU worker: set up data, model, and DDP, then run training."""
    # setting CUDA_VISIBLE_DEVICES is recommended over torch.cuda.device();
    # each process then sees only its own GPU, always as cuda:0
    os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
    os.environ['CUDA_VISIBLE_DEVICES'] = str(local_rank)
    device = torch.device('cuda', 0)
rank = args.gpus_per_node * node + local_rank
# Need randomness across processes, for sampler, augmentation, noise etc.
# Note DDP broadcasts initial model states from rank 0
torch.manual_seed(args.seed + rank)
    # if enabling cudnn.deterministic, also disable cudnn.benchmark (set below)
    #torch.backends.cudnn.deterministic = True
dist_init(rank, args)
train_dataset = FieldDataset(
in_patterns=args.train_in_patterns,
tgt_patterns=args.train_tgt_patterns,
in_norms=args.in_norms,
tgt_norms=args.tgt_norms,
callback_at=args.callback_at,
augment=args.augment,
aug_shift=args.aug_shift,
aug_add=args.aug_add,
aug_mul=args.aug_mul,
crop=args.crop,
crop_start=args.crop_start,
crop_stop=args.crop_stop,
crop_step=args.crop_step,
in_pad=args.in_pad,
tgt_pad=args.tgt_pad,
scale_factor=args.scale_factor,
)
train_sampler = DistFieldSampler(train_dataset, shuffle=True,
div_data=args.div_data,
div_shuffle_dist=args.div_shuffle_dist)
train_loader = DataLoader(
train_dataset,
batch_size=args.batches,
shuffle=False,
sampler=train_sampler,
num_workers=args.loader_workers,
pin_memory=True,
)
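
    # note the batch size above is per process; with one process per GPU the
    # effective global batch size is args.batches * world_size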
if args.val:
val_dataset = FieldDataset(
in_patterns=args.val_in_patterns,
tgt_patterns=args.val_tgt_patterns,
in_norms=args.in_norms,
tgt_norms=args.tgt_norms,
callback_at=args.callback_at,
augment=False,
aug_shift=None,
aug_add=None,
aug_mul=None,
crop=args.crop,
crop_start=args.crop_start,
crop_stop=args.crop_stop,
crop_step=args.crop_step,
in_pad=args.in_pad,
tgt_pad=args.tgt_pad,
scale_factor=args.scale_factor,
)
val_sampler = DistFieldSampler(val_dataset, shuffle=False,
div_data=args.div_data,
div_shuffle_dist=args.div_shuffle_dist)
val_loader = DataLoader(
val_dataset,
batch_size=args.batches,
shuffle=False,
sampler=val_sampler,
num_workers=args.loader_workers,
pin_memory=True,
)
args.in_chan, args.out_chan = train_dataset.in_chan, train_dataset.tgt_chan
model = import_attr(args.model, models, callback_at=args.callback_at)
model = model(sum(args.in_chan), sum(args.out_chan),
scale_factor=args.scale_factor)
model.to(device)
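    # a dedicated process group for DDP, presumably to keep its gradient
    # collectives separate from the all_reduce calls on the default group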
model = DistributedDataParallel(model, device_ids=[device],
process_group=dist.new_group())
criterion = import_attr(args.criterion, nn, models,
callback_at=args.callback_at)
criterion = criterion()
criterion.to(device)
optimizer = import_attr(args.optimizer, optim, callback_at=args.callback_at)
optimizer = optimizer(
model.parameters(),
lr=args.lr,
**args.optimizer_args,
)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
optimizer, **args.scheduler_args)
if (args.load_state == ckpt_link and not os.path.isfile(ckpt_link)
or not args.load_state):
def init_weights(m):
if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d,
nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
m.weight.data.normal_(0.0, args.init_weight_std)
elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
nn.SyncBatchNorm, nn.LayerNorm, nn.GroupNorm,
nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)):
if m.affine:
# NOTE: dispersion from DCGAN, why?
m.weight.data.normal_(1.0, args.init_weight_std)
m.bias.data.fill_(0)
if args.init_weight_std is not None:
model.apply(init_weights)
start_epoch = 0
if rank == 0:
min_loss = None
else:
state = torch.load(args.load_state, map_location=device)
start_epoch = state['epoch']
load_model_state_dict(model.module, state['model'],
strict=args.load_state_strict)
        torch.set_rng_state(state['rng'].cpu())  # RNG state tensor must be on CPU
if rank == 0:
min_loss = state['min_loss']
print('state at epoch {} loaded from {}'.format(
state['epoch'], args.load_state), flush=True)
del state
torch.backends.cudnn.benchmark = True
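
    # detect_anomaly finds the op that produced NaN/Inf in the backward
    # pass, at a substantial runtime cost, so enable it only for debugging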
if args.detect_anomaly:
torch.autograd.set_detect_anomaly(True)
logger = None
if rank == 0:
logger = SummaryWriter()
if rank == 0:
pprint(vars(args))
sys.stdout.flush()
for epoch in range(start_epoch, args.epochs):
train_sampler.set_epoch(epoch)
train_loss = train(epoch, train_loader, model, criterion,
optimizer, scheduler, logger, device, args)
epoch_loss = train_loss
if args.val:
val_loss = validate(epoch, val_loader, model, criterion,
logger, device, args)
#epoch_loss = val_loss
if args.reduce_lr_on_plateau:
scheduler.step(epoch_loss[2])
if rank == 0:
logger.flush()
if min_loss is None or epoch_loss[2] < min_loss:
min_loss = epoch_loss[2]
state = {
'epoch': epoch + 1,
'model': model.module.state_dict(),
'rng': torch.get_rng_state(),
'min_loss': min_loss,
}
state_file = 'state_{}.pt'.format(epoch + 1)
torch.save(state, state_file)
del state
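            # os.symlink() cannot overwrite an existing link, so create a
            # uniquely named link first and os.rename() it over ckpt_link;
            # rename is atomic on POSIX, so the link never dangles mid-update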
tmp_link = '{}.pt'.format(time.time())
os.symlink(state_file, tmp_link) # workaround to overwrite
os.rename(tmp_link, ckpt_link)
dist.destroy_process_group()


def train(epoch, loader, model, criterion,
          optimizer, scheduler, logger, device, args):
    """Train for one epoch; return the global mean [lag, eul, lag*eul] losses."""
model.train()
rank = dist.get_rank()
world_size = dist.get_world_size()
epoch_loss = torch.zeros(3, dtype=torch.float64, device=device)
for i, (input, target) in enumerate(loader):
batch = epoch * len(loader) + i + 1
input = input.to(device, non_blocking=True)
target = target.to(device, non_blocking=True)
output = model(input)
if batch == 1 and rank == 0:
print('input shape :', input.shape)
print('output shape :', output.shape)
print('target shape :', target.shape)
if (hasattr(model.module, 'scale_factor')
and model.module.scale_factor != 1):
input = resample(input, model.module.scale_factor, narrow=False)
input, output, target = narrow_cast(input, output, target)
if batch == 1 and rank == 0:
print('narrowed shape :', output.shape, flush=True)
lag_out, lag_tgt = output, target
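        # lag2eul maps the Lagrangian (displacement) fields to their Eulerian
        # counterparts, so the loss constrains both representations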
eul_out, eul_tgt = lag2eul(lag_out, lag_tgt)
lag_loss = criterion(lag_out, lag_tgt)
eul_loss = criterion(eul_out, eul_tgt)
loss = lag_loss * eul_loss
epoch_loss[0] += lag_loss.item()
epoch_loss[1] += eul_loss.item()
epoch_loss[2] += loss.item()
optimizer.zero_grad()
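        # backprop through log(loss) = log(lag_loss) + log(eul_loss):
        # each factor's gradient is divided by its own magnitude, balancing
        # the Lagrangian and Eulerian terms regardless of their scales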
torch.log(loss).backward() # NOTE actual loss is log(loss)
optimizer.step()
grads = get_grads(model)
if batch % args.log_interval == 0:
dist.all_reduce(lag_loss)
dist.all_reduce(eul_loss)
dist.all_reduce(loss)
lag_loss /= world_size
eul_loss /= world_size
loss /= world_size
if rank == 0:
logger.add_scalar('loss/batch/train/lag', lag_loss.item(),
global_step=batch)
logger.add_scalar('loss/batch/train/eul', eul_loss.item(),
global_step=batch)
logger.add_scalar('loss/batch/train/lxe', loss.item(),
global_step=batch)
logger.add_scalar('grad/first', grads[0], global_step=batch)
logger.add_scalar('grad/last', grads[-1], global_step=batch)
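
    # every rank accumulated len(loader) per-batch losses, so the global
    # mean divides the all-reduced sum by len(loader) * world_size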
dist.all_reduce(epoch_loss)
epoch_loss /= len(loader) * world_size
if rank == 0:
logger.add_scalar('loss/epoch/train/lag', epoch_loss[0],
global_step=epoch+1)
logger.add_scalar('loss/epoch/train/eul', epoch_loss[1],
global_step=epoch+1)
logger.add_scalar('loss/epoch/train/lxe', epoch_loss[2],
global_step=epoch+1)
fig = plt_slices(
input[-1], lag_out[-1], lag_tgt[-1], lag_out[-1] - lag_tgt[-1],
eul_out[-1], eul_tgt[-1], eul_out[-1] - eul_tgt[-1],
title=['in', 'lag_out', 'lag_tgt', 'lag_out - lag_tgt',
'eul_out', 'eul_tgt', 'eul_out - eul_tgt'],
)
logger.add_figure('fig/train', fig, global_step=epoch+1)
fig.clf()
#fig = plt_power(input, lag_out, lag_tgt, label=['in', 'out', 'tgt'])
#logger.add_figure('fig/train/power/lag', fig, global_step=epoch+1)
#fig.clf()
#fig = plt_power(input, lag_out, lag_tgt, l2e=True,
# label=['in', 'out', 'tgt'])
#logger.add_figure('fig/train/power/eul', fig, global_step=epoch+1)
#fig.clf()
return epoch_loss


def validate(epoch, loader, model, criterion, logger, device, args):
    """Validate for one epoch; return the global mean [lag, eul, lag*eul] losses."""
model.eval()
rank = dist.get_rank()
world_size = dist.get_world_size()
epoch_loss = torch.zeros(3, dtype=torch.float64, device=device)
with torch.no_grad():
for input, target in loader:
input = input.to(device, non_blocking=True)
target = target.to(device, non_blocking=True)
output = model(input)
if (hasattr(model.module, 'scale_factor')
and model.module.scale_factor != 1):
input = resample(input, model.module.scale_factor, narrow=False)
input, output, target = narrow_cast(input, output, target)
lag_out, lag_tgt = output, target
eul_out, eul_tgt = lag2eul(lag_out, lag_tgt)
lag_loss = criterion(lag_out, lag_tgt)
eul_loss = criterion(eul_out, eul_tgt)
loss = lag_loss * eul_loss
epoch_loss[0] += lag_loss.item()
epoch_loss[1] += eul_loss.item()
epoch_loss[2] += loss.item()
dist.all_reduce(epoch_loss)
epoch_loss /= len(loader) * world_size
if rank == 0:
logger.add_scalar('loss/epoch/val/lag', epoch_loss[0],
global_step=epoch+1)
logger.add_scalar('loss/epoch/val/eul', epoch_loss[1],
global_step=epoch+1)
logger.add_scalar('loss/epoch/val/lxe', epoch_loss[2],
global_step=epoch+1)
fig = plt_slices(
input[-1], lag_out[-1], lag_tgt[-1], lag_out[-1] - lag_tgt[-1],
eul_out[-1], eul_tgt[-1], eul_out[-1] - eul_tgt[-1],
title=['in', 'lag_out', 'lag_tgt', 'lag_out - lag_tgt',
'eul_out', 'eul_tgt', 'eul_out - eul_tgt'],
)
logger.add_figure('fig/val', fig, global_step=epoch+1)
fig.clf()
#fig = plt_power(input, lag_out, lag_tgt, label=['in', 'out', 'tgt'])
#logger.add_figure('fig/val/power/lag', fig, global_step=epoch+1)
#fig.clf()
#fig = plt_power(input, lag_out, lag_tgt, l2e=True,
# label=['in', 'out', 'tgt'])
#logger.add_figure('fig/val/power/eul', fig, global_step=epoch+1)
#fig.clf()
return epoch_loss


def dist_init(rank, args):
    """Initialize the default process group via a file-based rendezvous."""
dist_file = 'dist_addr'
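    # file-based rendezvous: rank 0 binds an ephemeral port and writes the
    # tcp:// address to a file on the (assumed) shared filesystem, which
    # the other ranks poll before joining the process group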
if rank == 0:
addr = socket.gethostname()
with socket.socket() as s:
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
s.bind((addr, 0))
_, port = s.getsockname()
args.dist_addr = 'tcp://{}:{}'.format(addr, port)
with open(dist_file, mode='w') as f:
f.write(args.dist_addr)
if rank != 0:
while not os.path.exists(dist_file):
time.sleep(1)
with open(dist_file, mode='r') as f:
args.dist_addr = f.read()
dist.init_process_group(
backend=args.dist_backend,
init_method=args.dist_addr,
world_size=args.world_size,
rank=rank,
)
dist.barrier()
if rank == 0:
os.remove(dist_file)


def set_requires_grad(module, requires_grad=False):
    """Freeze or unfreeze all parameters of a module."""
for param in module.parameters():
param.requires_grad = requires_grad


def get_grads(model):
    """Return the gradient norms of the first and last weight parameters."""
    grads = [p.grad for n, p in model.named_parameters() if '.weight' in n]
    grads = [grads[0], grads[-1]]
grads = [g.detach().norm().item() for g in grads]
return grads