map2map/map2map/train.py

import os
import shutil
import torch
from torch.multiprocessing import spawn
from torch.distributed import init_process_group, destroy_process_group, all_reduce
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from .data import FieldDataset
from .models import UNet, narrow_like
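
# node_worker() runs once per SLURM node and spawns one gpu_worker() process
# per local GPU; gpu_worker() then trains on a single device under torch.distributed.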
def node_worker(args):
    torch.manual_seed(args.seed)  # NOTE: why here not in gpu_worker?
    #torch.backends.cudnn.deterministic = True  # NOTE: test perf

    args.gpus_per_node = torch.cuda.device_count()
    args.nodes = int(os.environ['SLURM_JOB_NUM_NODES'])
    args.world_size = args.gpus_per_node * args.nodes

    node = int(os.environ['SLURM_NODEID'])
    if node == 0:
        print(args)
    args.node = node

    spawn(gpu_worker, args=(args,), nprocs=args.gpus_per_node)

def gpu_worker(local_rank, args):
    args.device = torch.device('cuda', local_rank)
    torch.cuda.set_device(args.device)  # bind this process to its own GPU

    args.rank = args.gpus_per_node * args.node + local_rank

    init_process_group(
        backend=args.dist_backend,
        init_method='env://',
        world_size=args.world_size,
        rank=args.rank,
    )

    train_dataset = FieldDataset(
        in_patterns=args.train_in_patterns,
        tgt_patterns=args.train_tgt_patterns,
        augment=args.augment,
        norms=args.norms,
        pad_or_crop=args.pad_or_crop,
    )
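    # DistributedSampler shards the samples across ranks, so shuffling is
    # delegated to the sampler and the DataLoader's own shuffle stays False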
    train_sampler = DistributedSampler(train_dataset, shuffle=True)
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batches,
        shuffle=False,
        sampler=train_sampler,
        num_workers=args.loader_workers,
        pin_memory=True,
    )

    val_dataset = FieldDataset(
        in_patterns=args.val_in_patterns,
        tgt_patterns=args.val_tgt_patterns,
        augment=False,
        norms=args.norms,
        pad_or_crop=args.pad_or_crop,
    )
    val_sampler = DistributedSampler(val_dataset, shuffle=False)
    val_loader = DataLoader(
        val_dataset,
        batch_size=args.batches,
        shuffle=False,
        sampler=val_sampler,
        num_workers=args.loader_workers,
        pin_memory=True,
    )
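    # DistributedDataParallel averages gradients across ranks during backward();
    # the criterion and optimizer classes are looked up by name from the args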
    model = UNet(args.in_channels, args.out_channels)
    model.to(args.device)
    model = DistributedDataParallel(model, device_ids=[args.device])

    criterion = torch.nn.__dict__[args.criterion]()
    criterion.to(args.device)

    optimizer = torch.optim.__dict__[args.optimizer](
        model.parameters(),
        lr=args.lr,
        #momentum=args.momentum,
        #weight_decay=args.weight_decay,
    )
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
            factor=0.1, verbose=True)
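    # optionally resume from a checkpoint, restoring model, optimizer,
    # scheduler, and RNG states; min_loss tracks the best validation loss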
    if args.load_state:
        state = torch.load(args.load_state, map_location=args.device)
        args.start_epoch = state['epoch']
        model.load_state_dict(state['model'])
        optimizer.load_state_dict(state['optimizer'])
        scheduler.load_state_dict(state['scheduler'])
        torch.set_rng_state(state['rng'].cpu())  # move rng state back
        if args.rank == 0:
            min_loss = state['min_loss']
            print('checkpoint at epoch {} loaded from {}'.format(
                state['epoch'], args.load_state))
        del state
    else:
        args.start_epoch = 0
        if args.rank == 0:
            min_loss = None

    torch.backends.cudnn.benchmark = True  # NOTE: test perf

    if args.rank == 0:
        args.logger = SummaryWriter()
        hparam = {k: v if isinstance(v, (int, float, str, bool, torch.Tensor))
                else str(v) for k, v in vars(args).items()}
        args.logger.add_hparams(hparam_dict=hparam, metric_dict={})
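
    # main loop: set_epoch() reshuffles the training shards each epoch, and
    # the scheduler steps on the epoch-averaged validation loss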
    for epoch in range(args.start_epoch, args.epochs):
        train_sampler.set_epoch(epoch)
        train(epoch, train_loader, model, criterion, optimizer, scheduler, args)

        val_loss = validate(epoch, val_loader, model, criterion, args)
        scheduler.step(val_loss)

        if args.rank == 0:
            args.logger.close()

            state = {
                'epoch': epoch + 1,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'rng': torch.get_rng_state(),
                'min_loss': min_loss,
            }
            filename = 'checkpoint.pth'
            torch.save(state, filename)
            del state

            if min_loss is None or val_loss < min_loss:
                min_loss = val_loss
                shutil.copyfile(filename, 'best_model.pth')

    destroy_process_group()
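
# one optimization step per batch; the loss is all-reduced across ranks every
# log_interval batches purely for logging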
def train(epoch, loader, model, criterion, optimizer, scheduler, args):
    model.train()

    for i, (input, target) in enumerate(loader):
        input = input.to(args.device, non_blocking=True)
        target = target.to(args.device, non_blocking=True)

        output = model(input)
        target = narrow_like(target, output)  # FIXME pad

        loss = criterion(output, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        #if scheduler is not None:  # for batch scheduler
        #    scheduler.step()

        batch = epoch * len(loader) + i + 1
        if batch % args.log_interval == 0:
            all_reduce(loss)
            loss /= args.world_size
            if args.rank == 0:
                args.logger.add_scalar('loss/train', loss.item(), global_step=batch)
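
# validation accumulates the loss over all batches, then averages it over
# batches and ranks before logging on rank 0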
def validate(epoch, loader, model, criterion, args):
    model.eval()
    loss = 0

    with torch.no_grad():
        for i, (input, target) in enumerate(loader):
            input = input.to(args.device, non_blocking=True)
            target = target.to(args.device, non_blocking=True)

            output = model(input)
            target = narrow_like(target, output)  # FIXME pad

            loss += criterion(output, target)

    all_reduce(loss)
    loss /= len(loader) * args.world_size
    if args.rank == 0:
        args.logger.add_scalar('loss/val', loss.item(), global_step=epoch+1)

    return loss.item()