From b67079bf728c126ff9a6ab547b26fc5ffc607038 Mon Sep 17 00:00:00 2001
From: Yin Li
Date: Thu, 13 Feb 2020 19:56:54 -0600
Subject: [PATCH] Add runtime address and port determination and share them
 via file

Together with the Slurm step node count, this makes it possible to
launch multiple training runs in one job
---
 map2map/train.py      | 87 +++++++++++++++++++++++++++++--------------
 scripts/dis2den.slurm |  4 --
 scripts/dis2dis.slurm |  4 --
 scripts/srsgan.slurm  |  4 --
 scripts/vel2vel.slurm |  4 --
 5 files changed, 60 insertions(+), 43 deletions(-)

diff --git a/map2map/train.py b/map2map/train.py
index d21431c..6a96cda 100644
--- a/map2map/train.py
+++ b/map2map/train.py
@@ -1,5 +1,8 @@
 import os
 import shutil
+import socket
+import time
+import sys
 from pprint import pprint
 import torch
 import torch.nn.functional as F
@@ -19,48 +22,28 @@ from .models import (narrow_like,
 from .state import load_model_state_dict
 
 
-def set_runtime_default_args(args):
-    args.val = args.val_in_patterns is not None and \
-            args.val_tgt_patterns is not None
-
-    args.adv = args.adv_model is not None
-
-    if args.adv:
-        if args.adv_lr is None:
-            args.adv_lr = args.lr
-        if args.adv_weight_decay is None:
-            args.adv_weight_decay = args.weight_decay
-
-
 def node_worker(args):
-    set_runtime_default_args(args)
-
-    torch.manual_seed(args.seed)  # NOTE: why here not in gpu_worker?
-    #torch.backends.cudnn.deterministic = True  # NOTE: test perf
-
     args.gpus_per_node = torch.cuda.device_count()
-    args.nodes = int(os.environ['SLURM_JOB_NUM_NODES'])
+    args.nodes = int(os.environ['SLURM_STEP_NUM_NODES'])
     args.world_size = args.gpus_per_node * args.nodes
 
     node = int(os.environ['SLURM_NODEID'])
-    if node == 0:
-        pprint(vars(args))
 
     spawn(gpu_worker, args=(node, args), nprocs=args.gpus_per_node)
 
 
 def gpu_worker(local_rank, node, args):
+    set_runtime_default_args(args)
+
     device = torch.device('cuda', local_rank)
     torch.cuda.device(device)
 
+    torch.manual_seed(args.seed)
+    #torch.backends.cudnn.deterministic = True  # NOTE: test perf
+
     rank = args.gpus_per_node * node + local_rank
 
-    dist.init_process_group(
-        backend=args.dist_backend,
-        init_method='env://',
-        world_size=args.world_size,
-        rank=rank,
-    )
+    dist_init(rank, args)
 
     train_dataset = FieldDataset(
         in_patterns=args.train_in_patterns,
@@ -210,6 +193,10 @@ def gpu_worker(local_rank, node, args):
     if rank == 0:
         logger = SummaryWriter()
 
+    if rank == 0:
+        pprint(vars(args))
+        sys.stdout.flush()
+
     for epoch in range(start_epoch, args.epochs):
         if not args.div_data:
             train_sampler.set_epoch(epoch)
@@ -449,6 +436,52 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion,
     return epoch_loss
 
 
+def set_runtime_default_args(args):
+    args.val = args.val_in_patterns is not None and \
+            args.val_tgt_patterns is not None
+
+    args.adv = args.adv_model is not None
+
+    if args.adv:
+        if args.adv_lr is None:
+            args.adv_lr = args.lr
+        if args.adv_weight_decay is None:
+            args.adv_weight_decay = args.weight_decay
+
+
+def dist_init(rank, args):
+    dist_file = 'dist_addr'
+
+    if rank == 0:
+        addr = socket.gethostname()
+
+        with socket.socket() as s:
+            s.bind((addr, 0))
+            _, port = s.getsockname()
+
+        args.dist_addr = 'tcp://{}:{}'.format(addr, port)
+
+        with open(dist_file, mode='w') as f:
+            f.write(args.dist_addr)
+
+    if rank != 0:
+        while not os.path.exists(dist_file):
+            time.sleep(1)
+
+        with open(dist_file, mode='r') as f:
+            args.dist_addr = f.read()
+
+    dist.init_process_group(
+        backend=args.dist_backend,
+        init_method=args.dist_addr,
+        world_size=args.world_size,
+        rank=rank,
+    )
+
+    if rank == 0:
+        os.remove(dist_file)
+
+
 def set_requires_grad(module, requires_grad=False):
     for param in module.parameters():
         param.requires_grad = requires_grad
diff --git a/scripts/dis2den.slurm b/scripts/dis2den.slurm
index 801255c..cbb89d1 100644
--- a/scripts/dis2den.slurm
+++ b/scripts/dis2den.slurm
@@ -18,10 +18,6 @@ module load gcc python3
 #source $HOME/anaconda3/bin/activate torch
 
 
-export MASTER_ADDR=$HOSTNAME
-export MASTER_PORT=60606
-
-
 data_root_dir="/mnt/ceph/users/yinli/Quijote"
 
 in_dir="linear"
diff --git a/scripts/dis2dis.slurm b/scripts/dis2dis.slurm
index e5f2dbc..0a473c3 100644
--- a/scripts/dis2dis.slurm
+++ b/scripts/dis2dis.slurm
@@ -18,10 +18,6 @@ module load gcc python3
 #source $HOME/anaconda3/bin/activate torch
 
 
-export MASTER_ADDR=$HOSTNAME
-export MASTER_PORT=60606
-
-
 data_root_dir="/mnt/ceph/users/yinli/Quijote"
 
 in_dir="linear"
diff --git a/scripts/srsgan.slurm b/scripts/srsgan.slurm
index dab061f..8557791 100644
--- a/scripts/srsgan.slurm
+++ b/scripts/srsgan.slurm
@@ -19,10 +19,6 @@ hostname; pwd; date
 
 source $HOME/anaconda3/bin/activate
 
-export MASTER_ADDR=$HOSTNAME
-export MASTER_PORT=60606
-
-
 data_root_dir="/scratch1/06431/yueyingn/dmo-50MPC-train"
 
 in_dir="low-resl"
diff --git a/scripts/vel2vel.slurm b/scripts/vel2vel.slurm
index 10162c5..62347ea 100644
--- a/scripts/vel2vel.slurm
+++ b/scripts/vel2vel.slurm
@@ -18,10 +18,6 @@ module load gcc python3
 #source $HOME/anaconda3/bin/activate torch
 
 
-export MASTER_ADDR=$HOSTNAME
-export MASTER_PORT=60606
-
-
 data_root_dir="/mnt/ceph/users/yinli/Quijote"
 
 in_dir="linear"
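
A note on the rendezvous scheme: dist_init has rank 0 bind port 0 (so the
OS assigns a free ephemeral port), publish tcp://host:port through a file
on shared storage, and delete the file once init_process_group returns;
the other ranks poll for the file and read the address. The standalone
sketch below illustrates the same pattern outside map2map and Slurm; it
is illustrative, not part of the patch. It assumes the gloo backend (so
it runs on CPU-only machines), a hard-coded world size of 2, and a
dist_addr file in the working directory; it also retries on an empty
read, since a rank could otherwise catch the file between creation and
write.

import os
import socket
import time

import torch.distributed as dist
import torch.multiprocessing as mp


def dist_init(rank, world_size, dist_file='dist_addr'):
    if rank == 0:
        # Bind port 0 so the OS assigns a free ephemeral port, then
        # publish "tcp://host:port" through a shared file.
        addr = socket.gethostname()
        with socket.socket() as s:
            s.bind((addr, 0))
            _, port = s.getsockname()
        dist_addr = 'tcp://{}:{}'.format(addr, port)
        with open(dist_file, mode='w') as f:
            f.write(dist_addr)
    else:
        # Poll until rank 0 has written the address; requiring a
        # non-empty read guards against a half-written file.
        dist_addr = ''
        while not dist_addr:
            time.sleep(1)
            if os.path.exists(dist_file):
                with open(dist_file, mode='r') as f:
                    dist_addr = f.read()

    # All ranks rendezvous at the published address; this blocks until
    # world_size processes have joined, so the file is safe to remove.
    dist.init_process_group('gloo', init_method=dist_addr,
                            world_size=world_size, rank=rank)

    if rank == 0:
        os.remove(dist_file)


def worker(rank, world_size):
    dist_init(rank, world_size)
    print('rank {} of {} initialized'.format(rank, dist.get_world_size()))


if __name__ == '__main__':
    world_size = 2
    mp.spawn(worker, args=(world_size,), nprocs=world_size)

One caveat inherent to the trick: the probed port is released before
init_process_group rebinds it, so another process could take it in
between; the patch accepts that small race in exchange for not
hard-coding MASTER_PORT.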