map2map/map2map/train.py
Yin Li 154376d95a Add tgt_pad, rename pad to in_pad
tgt_pad can be useful for scale_factor > 1
2020-09-12 18:24:36 -04:00

433 lines
14 KiB
Python

import os
import socket
import time
import sys
from pprint import pprint
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributed as dist
from torch.multiprocessing import spawn
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from .data import FieldDataset, DistFieldSampler
from .data.figures import plt_slices, plt_power
from . import models
from .models import narrow_cast, resample, lag2eul
from .utils import import_attr, load_model_state_dict
# name of the symlink that always points at the latest checkpoint;
# updated via symlink + rename at the end of each epoch in gpu_worker
ckpt_link = 'checkpoint.pt'
def node_worker(args):
    """Read the node layout from the SLURM environment and spawn one
    training process per GPU on this node.

    Mutates args in place, adding nodes, gpus_per_node and world_size.
    Raises KeyError when no SLURM node count is set, and RuntimeError
    when the node has no visible GPU.
    """
    for env_key in ('SLURM_STEP_NUM_NODES', 'SLURM_JOB_NUM_NODES'):
        if env_key in os.environ:
            args.nodes = int(os.environ[env_key])
            break
    else:
        raise KeyError('missing node counts in slurm env')

    args.gpus_per_node = torch.cuda.device_count()
    args.world_size = args.nodes * args.gpus_per_node

    node = int(os.environ['SLURM_NODEID'])
    if args.gpus_per_node < 1:
        raise RuntimeError('GPU not found on node {}'.format(node))

    spawn(gpu_worker, args=(node, args), nprocs=args.gpus_per_node)
def gpu_worker(local_rank, node, args):
    """Per-GPU training worker: set up device, data, model, optimizer and
    distributed state, then run the epoch loop with checkpointing.

    local_rank is the GPU index on this node, node the node index, and
    args the parsed command-line arguments (mutated in place with derived
    attributes such as in_chan/out_chan).
    """
    #device = torch.device('cuda', local_rank)
    #torch.cuda.device(device)  # env var recommended over this
    # restrict this process to a single GPU, then address it as cuda:0
    os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
    os.environ['CUDA_VISIBLE_DEVICES'] = str(local_rank)
    device = torch.device('cuda', 0)

    # global rank of this process across all nodes
    rank = args.gpus_per_node * node + local_rank

    # Need randomness across processes, for sampler, augmentation, noise etc.
    # Note DDP broadcasts initial model states from rank 0
    torch.manual_seed(args.seed + rank)
    #torch.backends.cudnn.deterministic = True  # NOTE: test perf

    dist_init(rank, args)

    train_dataset = FieldDataset(
        in_patterns=args.train_in_patterns,
        tgt_patterns=args.train_tgt_patterns,
        in_norms=args.in_norms,
        tgt_norms=args.tgt_norms,
        callback_at=args.callback_at,
        augment=args.augment,
        aug_shift=args.aug_shift,
        aug_add=args.aug_add,
        aug_mul=args.aug_mul,
        crop=args.crop,
        crop_start=args.crop_start,
        crop_stop=args.crop_stop,
        crop_step=args.crop_step,
        in_pad=args.in_pad,
        tgt_pad=args.tgt_pad,
        scale_factor=args.scale_factor,
    )
    train_sampler = DistFieldSampler(train_dataset, shuffle=True,
            div_data=args.div_data,
            div_shuffle_dist=args.div_shuffle_dist)
    # shuffle=False because shuffling is delegated to the sampler
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batches,
        shuffle=False,
        sampler=train_sampler,
        num_workers=args.loader_workers,
        pin_memory=True,
    )

    if args.val:
        # validation dataset mirrors the training one but without augmentation
        val_dataset = FieldDataset(
            in_patterns=args.val_in_patterns,
            tgt_patterns=args.val_tgt_patterns,
            in_norms=args.in_norms,
            tgt_norms=args.tgt_norms,
            callback_at=args.callback_at,
            augment=False,
            aug_shift=None,
            aug_add=None,
            aug_mul=None,
            crop=args.crop,
            crop_start=args.crop_start,
            crop_stop=args.crop_stop,
            crop_step=args.crop_step,
            in_pad=args.in_pad,
            tgt_pad=args.tgt_pad,
            scale_factor=args.scale_factor,
        )
        val_sampler = DistFieldSampler(val_dataset, shuffle=False,
                div_data=args.div_data,
                div_shuffle_dist=args.div_shuffle_dist)
        val_loader = DataLoader(
            val_dataset,
            batch_size=args.batches,
            shuffle=False,
            sampler=val_sampler,
            num_workers=args.loader_workers,
            pin_memory=True,
        )

    # channel counts are discovered from the data, not from the command line
    args.in_chan, args.out_chan = train_dataset.in_chan, train_dataset.tgt_chan

    model = import_attr(args.model, models, callback_at=args.callback_at)
    model = model(sum(args.in_chan), sum(args.out_chan),
            scale_factor=args.scale_factor)
    model.to(device)
    model = DistributedDataParallel(model, device_ids=[device],
            process_group=dist.new_group())

    criterion = import_attr(args.criterion, nn, models,
            callback_at=args.callback_at)
    criterion = criterion()
    criterion.to(device)

    optimizer = import_attr(args.optimizer, optim, callback_at=args.callback_at)
    optimizer = optimizer(
        model.parameters(),
        lr=args.lr,
        **args.optimizer_args,
    )
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, **args.scheduler_args)

    # fresh start: either no state requested, or the default checkpoint
    # link was requested but does not exist yet
    if (args.load_state == ckpt_link and not os.path.isfile(ckpt_link)
            or not args.load_state):
        def init_weights(m):
            # custom gaussian init for conv/linear weights
            if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d,
                nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
                m.weight.data.normal_(0.0, args.init_weight_std)
            elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
                nn.SyncBatchNorm, nn.LayerNorm, nn.GroupNorm,
                nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)):
                if m.affine:
                    # NOTE: dispersion from DCGAN, why?
                    m.weight.data.normal_(1.0, args.init_weight_std)
                    m.bias.data.fill_(0)

        if args.init_weight_std is not None:
            model.apply(init_weights)

        start_epoch = 0

        # min_loss is tracked on rank 0 only, for checkpoint selection
        if rank == 0:
            min_loss = None
    else:
        # resume from a saved state dict
        state = torch.load(args.load_state, map_location=device)

        start_epoch = state['epoch']
        load_model_state_dict(model.module, state['model'],
                strict=args.load_state_strict)
        torch.set_rng_state(state['rng'].cpu())  # move rng state back
        if rank == 0:
            min_loss = state['min_loss']
            print('state at epoch {} loaded from {}'.format(
                state['epoch'], args.load_state), flush=True)

        del state

    torch.backends.cudnn.benchmark = True  # NOTE: test perf

    # only rank 0 writes tensorboard logs
    logger = None
    if rank == 0:
        logger = SummaryWriter()

    if rank == 0:
        pprint(vars(args))
        sys.stdout.flush()

    for epoch in range(start_epoch, args.epochs):
        train_sampler.set_epoch(epoch)
        train_loss = train(epoch, train_loader, model, criterion,
                optimizer, scheduler, logger, device, args)
        epoch_loss = train_loss

        if args.val:
            val_loss = validate(epoch, val_loader, model, criterion,
                    logger, device, args)
            #epoch_loss = val_loss

        # scheduler steps on the combined (lag*eul) training loss
        if args.reduce_lr_on_plateau:
            scheduler.step(epoch_loss[2])

        if rank == 0:
            logger.flush()

            if min_loss is None or epoch_loss[2] < min_loss:
                min_loss = epoch_loss[2]

            state = {
                'epoch': epoch + 1,
                'model': model.module.state_dict(),
                'rng': torch.get_rng_state(),
                'min_loss': min_loss,
            }
            state_file = 'state_{}.pt'.format(epoch + 1)
            torch.save(state, state_file)
            del state

            # symlink to a temp name then rename, so the checkpoint link
            # is replaced atomically
            tmp_link = '{}.pt'.format(time.time())
            os.symlink(state_file, tmp_link)  # workaround to overwrite
            os.rename(tmp_link, ckpt_link)

    dist.destroy_process_group()
def train(epoch, loader, model, criterion,
        optimizer, scheduler, logger, device, args):
    """Train for one epoch and return the epoch-mean losses.

    Returns a length-3 float64 tensor of mean [lag, eul, lag*eul] losses,
    all-reduced over ranks.  Only rank 0 writes scalars and figures.
    """
    model.train()
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # running sums of [lagrangian, eulerian, combined] losses
    epoch_loss = torch.zeros(3, dtype=torch.float64, device=device)

    for i, (input, target) in enumerate(loader):
        batch = epoch * len(loader) + i + 1  # global 1-based batch counter

        input = input.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        output = model(input)
        if batch == 1 and rank == 0:
            print('input shape :', input.shape)
            print('output shape :', output.shape)
            print('target shape :', target.shape)

        # upsample input if the model changes resolution, then crop all
        # three fields to a common spatial extent before the loss
        if (hasattr(model.module, 'scale_factor')
                and model.module.scale_factor != 1):
            input = resample(input, model.module.scale_factor, narrow=False)
        input, output, target = narrow_cast(input, output, target)
        if batch == 1 and rank == 0:
            print('narrowed shape :', output.shape, flush=True)

        # evaluate the loss both on the raw (lagrangian) fields and on
        # their lag2eul transforms
        lag_out, lag_tgt = output, target
        eul_out, eul_tgt = lag2eul(lag_out, lag_tgt)

        lag_loss = criterion(lag_out, lag_tgt)
        eul_loss = criterion(eul_out, eul_tgt)
        loss = lag_loss * eul_loss
        epoch_loss[0] += lag_loss.item()
        epoch_loss[1] += eul_loss.item()
        epoch_loss[2] += loss.item()

        optimizer.zero_grad()
        torch.log(loss).backward()  # NOTE actual loss is log(loss)
        optimizer.step()
        grads = get_grads(model)  # first/last layer weight-grad norms

        if batch % args.log_interval == 0:
            # average the batch losses over all ranks before logging
            dist.all_reduce(lag_loss)
            dist.all_reduce(eul_loss)
            dist.all_reduce(loss)
            lag_loss /= world_size
            eul_loss /= world_size
            loss /= world_size
            if rank == 0:
                logger.add_scalar('loss/batch/train/lag', lag_loss.item(),
                        global_step=batch)
                logger.add_scalar('loss/batch/train/eul', eul_loss.item(),
                        global_step=batch)
                logger.add_scalar('loss/batch/train/lxe', loss.item(),
                        global_step=batch)

                logger.add_scalar('grad/first', grads[0], global_step=batch)
                logger.add_scalar('grad/last', grads[-1], global_step=batch)

    # epoch-mean losses averaged over batches and ranks
    dist.all_reduce(epoch_loss)
    epoch_loss /= len(loader) * world_size
    if rank == 0:
        logger.add_scalar('loss/epoch/train/lag', epoch_loss[0],
                global_step=epoch+1)
        logger.add_scalar('loss/epoch/train/eul', epoch_loss[1],
                global_step=epoch+1)
        logger.add_scalar('loss/epoch/train/lxe', epoch_loss[2],
                global_step=epoch+1)

        # slice figures of the last sample in the last batch
        fig = plt_slices(
            input[-1], lag_out[-1], lag_tgt[-1], lag_out[-1] - lag_tgt[-1],
            eul_out[-1], eul_tgt[-1], eul_out[-1] - eul_tgt[-1],
            title=['in', 'lag_out', 'lag_tgt', 'lag_out - lag_tgt',
                   'eul_out', 'eul_tgt', 'eul_out - eul_tgt'],
        )
        logger.add_figure('fig/train', fig, global_step=epoch+1)
        fig.clf()

        #fig = plt_power(input, lag_out, lag_tgt, label=['in', 'out', 'tgt'])
        #logger.add_figure('fig/train/power/lag', fig, global_step=epoch+1)
        #fig.clf()

        #fig = plt_power(input, lag_out, lag_tgt, l2e=True,
        #        label=['in', 'out', 'tgt'])
        #logger.add_figure('fig/train/power/eul', fig, global_step=epoch+1)
        #fig.clf()

    return epoch_loss
def validate(epoch, loader, model, criterion, logger, device, args):
    """Evaluate for one epoch and return the epoch-mean losses.

    Same loss structure as train() — a length-3 float64 tensor of mean
    [lag, eul, lag*eul] losses all-reduced over ranks — but runs under
    no_grad with the model in eval mode, and performs no optimization.
    """
    model.eval()
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # running sums of [lagrangian, eulerian, combined] losses
    epoch_loss = torch.zeros(3, dtype=torch.float64, device=device)

    with torch.no_grad():
        for input, target in loader:
            input = input.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)

            output = model(input)

            # match spatial sizes as in train(): upsample input if the
            # model changes resolution, then crop to a common extent
            if (hasattr(model.module, 'scale_factor')
                    and model.module.scale_factor != 1):
                input = resample(input, model.module.scale_factor, narrow=False)
            input, output, target = narrow_cast(input, output, target)

            lag_out, lag_tgt = output, target
            eul_out, eul_tgt = lag2eul(lag_out, lag_tgt)

            lag_loss = criterion(lag_out, lag_tgt)
            eul_loss = criterion(eul_out, eul_tgt)
            loss = lag_loss * eul_loss
            epoch_loss[0] += lag_loss.item()
            epoch_loss[1] += eul_loss.item()
            epoch_loss[2] += loss.item()

    # epoch-mean losses averaged over batches and ranks
    dist.all_reduce(epoch_loss)
    epoch_loss /= len(loader) * world_size
    if rank == 0:
        logger.add_scalar('loss/epoch/val/lag', epoch_loss[0],
                global_step=epoch+1)
        logger.add_scalar('loss/epoch/val/eul', epoch_loss[1],
                global_step=epoch+1)
        logger.add_scalar('loss/epoch/val/lxe', epoch_loss[2],
                global_step=epoch+1)

        # slice figures of the last sample in the last batch
        fig = plt_slices(
            input[-1], lag_out[-1], lag_tgt[-1], lag_out[-1] - lag_tgt[-1],
            eul_out[-1], eul_tgt[-1], eul_out[-1] - eul_tgt[-1],
            title=['in', 'lag_out', 'lag_tgt', 'lag_out - lag_tgt',
                   'eul_out', 'eul_tgt', 'eul_out - eul_tgt'],
        )
        logger.add_figure('fig/val', fig, global_step=epoch+1)
        fig.clf()

        #fig = plt_power(input, lag_out, lag_tgt, label=['in', 'out', 'tgt'])
        #logger.add_figure('fig/val/power/lag', fig, global_step=epoch+1)
        #fig.clf()

        #fig = plt_power(input, lag_out, lag_tgt, l2e=True,
        #        label=['in', 'out', 'tgt'])
        #logger.add_figure('fig/val/power/eul', fig, global_step=epoch+1)
        #fig.clf()

    return epoch_loss
def dist_init(rank, args):
    """Initialize the torch.distributed process group.

    Rank 0 picks a free TCP port on its host, publishes the rendezvous
    address through a shared file, and the other ranks poll for that file.
    Sets args.dist_addr as a side effect; rank 0 removes the file after
    the barrier, once every rank has connected.

    Fix: the address file is now written to a temp name and moved into
    place with os.replace.  Previously the non-zero ranks only waited for
    the file to *exist*, so they could open it between creation and write
    and read a partially written (e.g. empty) address.
    """
    dist_file = 'dist_addr'

    if rank == 0:
        addr = socket.gethostname()

        with socket.socket() as s:
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            s.bind((addr, 0))  # port 0: let the OS pick a free port
            _, port = s.getsockname()
            args.dist_addr = 'tcp://{}:{}'.format(addr, port)

        # write to a temp file then rename atomically, so other ranks can
        # never observe a partially written address file
        tmp_file = dist_file + '.tmp'
        with open(tmp_file, mode='w') as f:
            f.write(args.dist_addr)
        os.replace(tmp_file, dist_file)

    if rank != 0:
        while not os.path.exists(dist_file):
            time.sleep(1)

        with open(dist_file, mode='r') as f:
            args.dist_addr = f.read()

    dist.init_process_group(
        backend=args.dist_backend,
        init_method=args.dist_addr,
        world_size=args.world_size,
        rank=rank,
    )
    dist.barrier()

    # the barrier guarantees all ranks have already read the file
    if rank == 0:
        os.remove(dist_file)
def set_requires_grad(module, requires_grad=False):
    """Enable or disable gradient tracking on every parameter of module."""
    for p in module.parameters():
        p.requires_grad_(requires_grad)
def get_grads(model):
    """gradients of the weights of the first and the last layer
    """
    weight_grads = [param.grad for name, param in model.named_parameters()
                    if '.weight' in name]
    first, last = weight_grads[0], weight_grads[-1]
    return [first.detach().norm().item(), last.detach().norm().item()]