import os
import shutil

import torch
from torch.multiprocessing import spawn
from torch.distributed import init_process_group, destroy_process_group, all_reduce
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from .data import FieldDataset
from .models import UNet, narrow_like


def node_worker(args):
    torch.manual_seed(args.seed)  # NOTE: why here not in gpu_worker?
    #torch.backends.cudnn.deterministic = True  # NOTE: test perf

    args.gpus_per_node = torch.cuda.device_count()
    args.nodes = int(os.environ['SLURM_JOB_NUM_NODES'])
    args.world_size = args.gpus_per_node * args.nodes

    node = int(os.environ['SLURM_NODEID'])
    if node == 0:
        print(args)
    args.node = node

    spawn(gpu_worker, args=(args,), nprocs=args.gpus_per_node)


def gpu_worker(local_rank, args):
    args.device = torch.device('cuda', local_rank)
    torch.cuda.set_device(args.device)  # make this GPU the default for the process

    args.rank = args.gpus_per_node * args.node + local_rank

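    # init_method='env://': torch.distributed reads MASTER_ADDR and MASTER_PORT
    # from the environment (typically exported by the SLURM launch script),
    # while world_size and rank are passed explicitly below.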
    init_process_group(
        backend=args.dist_backend,
        init_method='env://',
        world_size=args.world_size,
        rank=args.rank,
    )

    train_dataset = FieldDataset(
        in_patterns=args.train_in_patterns,
        tgt_patterns=args.train_tgt_patterns,
        augment=args.augment,
        norms=args.norms,
        pad_or_crop=args.pad_or_crop,
    )
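    # DistributedSampler gives each rank a distinct shard of the dataset and
    # handles shuffling itself (reseeded per epoch via set_epoch below),
    # so the DataLoader must keep shuffle=False.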
    train_sampler = DistributedSampler(train_dataset, shuffle=True)
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batches,
        shuffle=False,
        sampler=train_sampler,
        num_workers=args.loader_workers,
        pin_memory=True,
    )

    val_dataset = FieldDataset(
        in_patterns=args.val_in_patterns,
        tgt_patterns=args.val_tgt_patterns,
        augment=False,
        norms=args.norms,
        pad_or_crop=args.pad_or_crop,
    )
    val_sampler = DistributedSampler(val_dataset, shuffle=False)
    val_loader = DataLoader(
        val_dataset,
        batch_size=args.batches,
        shuffle=False,
        sampler=val_sampler,
        num_workers=args.loader_workers,
        pin_memory=True,
    )

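    # move the model to its GPU before wrapping it in DistributedDataParallel;
    # DDP then averages gradients across ranks during backward().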
    model = UNet(args.in_channels, args.out_channels)
    model.to(args.device)
    model = DistributedDataParallel(model, device_ids=[args.device])

    criterion = torch.nn.__dict__[args.criterion]()
    criterion.to(args.device)

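    # the loss and optimizer classes are looked up by name in torch.nn and
    # torch.optim (e.g. args.criterion='MSELoss', args.optimizer='Adam';
    # example values only, not fixed by this script).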
    optimizer = torch.optim.__dict__[args.optimizer](
        model.parameters(),
        lr=args.lr,
        #momentum=args.momentum,
        #weight_decay=args.weight_decay,
    )
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, factor=0.1, verbose=True)

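    # optionally resume from a checkpoint written by the per-epoch save below;
    # torch's CPU RNG state is restored so the random stream continues where it
    # left off.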
    if args.load_state:
        state = torch.load(args.load_state, map_location=args.device)
        args.start_epoch = state['epoch']
        model.load_state_dict(state['model'])
        optimizer.load_state_dict(state['optimizer'])
        scheduler.load_state_dict(state['scheduler'])
        torch.set_rng_state(state['rng'].cpu())  # move rng state back
        if args.rank == 0:
            min_loss = state['min_loss']
            print('checkpoint at epoch {} loaded from {}'.format(
                state['epoch'], args.load_state))
        del state
    else:
        args.start_epoch = 0
        if args.rank == 0:
            min_loss = None

    torch.backends.cudnn.benchmark = True  # NOTE: test perf

    if args.rank == 0:
        args.logger = SummaryWriter()
        hparam = {k: v if isinstance(v, (int, float, str, bool, torch.Tensor))
                  else str(v) for k, v in vars(args).items()}
        args.logger.add_hparams(hparam_dict=hparam, metric_dict={})

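    # main loop: reshuffle the sampler, train for one epoch, validate, step the
    # plateau scheduler on the validation loss, and checkpoint from rank 0.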
    for epoch in range(args.start_epoch, args.epochs):
        train_sampler.set_epoch(epoch)

        train(epoch, train_loader, model, criterion, optimizer, scheduler, args)

        val_loss = validate(epoch, val_loader, model, criterion, args)

        scheduler.step(val_loss)

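        # only rank 0 writes checkpoints; DDP keeps parameters synchronized
        # across ranks, so saving once per epoch is sufficient.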
        if args.rank == 0:
            args.logger.flush()  # persist TensorBoard events for this epoch

            state = {
                'epoch': epoch + 1,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'rng': torch.get_rng_state(),
                'min_loss': min_loss,
            }
            filename = 'checkpoint.pth'
            torch.save(state, filename)
            del state

            if min_loss is None or val_loss < min_loss:
                min_loss = val_loss
                shutil.copyfile(filename, 'best_model.pth')

    if args.rank == 0:
        args.logger.close()

    destroy_process_group()


def train(epoch, loader, model, criterion, optimizer, scheduler, args):
    model.train()

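    # non_blocking=True overlaps host-to-device copies with compute only when
    # the source tensors come from pinned memory (pin_memory=True in the loaders).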
    for i, (input, target) in enumerate(loader):
        input = input.to(args.device, non_blocking=True)
        target = target.to(args.device, non_blocking=True)

        output = model(input)
        target = narrow_like(target, output)  # FIXME pad

        loss = criterion(output, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        #if scheduler is not None:  # for batch scheduler
        #    scheduler.step()

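        # every log_interval batches, log the globally averaged loss:
        # all_reduce sums the tensor in place across ranks, so divide by
        # world_size before writing the scalar.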
        batch = epoch * len(loader) + i + 1
        if batch % args.log_interval == 0:
            all_reduce(loss)
            loss /= args.world_size
            if args.rank == 0:
                args.logger.add_scalar('loss/train', loss.item(), global_step=batch)


def validate(epoch, loader, model, criterion, args):
    model.eval()

    loss = 0

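    # gradients are not needed during validation; the loss is accumulated as a
    # tensor so it can be summed across ranks with all_reduce afterwards.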
    with torch.no_grad():
        for i, (input, target) in enumerate(loader):
            input = input.to(args.device, non_blocking=True)
            target = target.to(args.device, non_blocking=True)

            output = model(input)
            target = narrow_like(target, output)  # FIXME pad

            loss += criterion(output, target)

    all_reduce(loss)
    loss /= len(loader) * args.world_size
    if args.rank == 0:
        args.logger.add_scalar('loss/val', loss.item(), global_step=epoch + 1)

    return loss.item()