Add runtime address and port determination and share them via file

Together with Slurm step node counts, this makes it possible to launch
multiple training runs in one job.
Yin Li 2020-02-13 19:56:54 -06:00
parent 1818e11265
commit b67079bf72
5 changed files with 60 additions and 43 deletions

File 1 of 5 — the training module (Python)

@@ -1,5 +1,8 @@
 import os
 import shutil
+import socket
+import time
+import sys
 from pprint import pprint

 import torch
 import torch.nn.functional as F
@@ -19,48 +22,28 @@ from .models import (narrow_like,
 from .state import load_model_state_dict


-def set_runtime_default_args(args):
-    args.val = args.val_in_patterns is not None and \
-            args.val_tgt_patterns is not None
-
-    args.adv = args.adv_model is not None
-
-    if args.adv:
-        if args.adv_lr is None:
-            args.adv_lr = args.lr
-        if args.adv_weight_decay is None:
-            args.adv_weight_decay = args.weight_decay
-
-
 def node_worker(args):
-    set_runtime_default_args(args)
-
-    torch.manual_seed(args.seed)  # NOTE: why here not in gpu_worker?
-    #torch.backends.cudnn.deterministic = True  # NOTE: test perf
-
     args.gpus_per_node = torch.cuda.device_count()
-    args.nodes = int(os.environ['SLURM_JOB_NUM_NODES'])
+    args.nodes = int(os.environ['SLURM_STEP_NUM_NODES'])
     args.world_size = args.gpus_per_node * args.nodes

     node = int(os.environ['SLURM_NODEID'])
-    if node == 0:
-        pprint(vars(args))

     spawn(gpu_worker, args=(node, args), nprocs=args.gpus_per_node)


 def gpu_worker(local_rank, node, args):
+    set_runtime_default_args(args)
+
     device = torch.device('cuda', local_rank)
     torch.cuda.device(device)

+    torch.manual_seed(args.seed)
+    #torch.backends.cudnn.deterministic = True  # NOTE: test perf
+
     rank = args.gpus_per_node * node + local_rank

-    dist.init_process_group(
-        backend=args.dist_backend,
-        init_method='env://',
-        world_size=args.world_size,
-        rank=rank,
-    )
+    dist_init(rank, args)

     train_dataset = FieldDataset(
         in_patterns=args.train_in_patterns,
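The SLURM_JOB_NUM_NODES → SLURM_STEP_NUM_NODES switch above is the half of the commit that enables several training runs per job: the job variable counts every node in the allocation, while the step variable counts only the nodes given to the current srun step. A minimal sketch of the difference, assuming a hypothetical 4-node job split into two 2-node steps (the .get fallback is added by me so it also runs outside Slurm):

import os

# Hypothetical allocation: sbatch -N4, then two concurrent steps
#     srun -N2 python train.py &
#     srun -N2 python train.py &
# Inside either step Slurm sets
#     SLURM_JOB_NUM_NODES  = 4   (the whole allocation)
#     SLURM_STEP_NUM_NODES = 2   (this step only)
# so sizing the process group from the job count would leave each
# training waiting forever for ranks that belong to the other one.
nodes = int(os.environ.get('SLURM_STEP_NUM_NODES', '1'))
gpus_per_node = 8  # stand-in for torch.cuda.device_count()
world_size = gpus_per_node * nodes
print(nodes, world_size)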
@@ -210,6 +193,10 @@ def gpu_worker(local_rank, node, args):
     if rank == 0:
         logger = SummaryWriter()

+    if rank == 0:
+        pprint(vars(args))
+        sys.stdout.flush()
+
     for epoch in range(start_epoch, args.epochs):
         if not args.div_data:
             train_sampler.set_epoch(epoch)
@@ -449,6 +436,52 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion,
     return epoch_loss


+def set_runtime_default_args(args):
+    args.val = args.val_in_patterns is not None and \
+            args.val_tgt_patterns is not None
+
+    args.adv = args.adv_model is not None
+
+    if args.adv:
+        if args.adv_lr is None:
+            args.adv_lr = args.lr
+        if args.adv_weight_decay is None:
+            args.adv_weight_decay = args.weight_decay
+
+
+def dist_init(rank, args):
+    dist_file = 'dist_addr'
+
+    if rank == 0:
+        addr = socket.gethostname()
+
+        with socket.socket() as s:
+            s.bind((addr, 0))
+            _, port = s.getsockname()
+            args.dist_addr = 'tcp://{}:{}'.format(addr, port)
+
+        with open(dist_file, mode='w') as f:
+            f.write(args.dist_addr)
+
+    if rank != 0:
+        while not os.path.exists(dist_file):
+            time.sleep(1)
+
+        with open(dist_file, mode='r') as f:
+            args.dist_addr = f.read()
+
+    dist.init_process_group(
+        backend=args.dist_backend,
+        init_method=args.dist_addr,
+        world_size=args.world_size,
+        rank=rank,
+    )
+
+    if rank == 0:
+        os.remove(dist_file)
+
+
 def set_requires_grad(module, requires_grad=False):
     for param in module.parameters():
         param.requires_grad = requires_grad
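dist_init replaces the fixed MASTER_PORT=60606 with an ephemeral port: binding a socket to port 0 makes the kernel pick any free port, and the resulting tcp:// address is then published to the other ranks through the dist_addr file. A standalone sketch of just that port trick (find_free_port is a hypothetical helper, not part of the commit):

import socket

def find_free_port(host=None):
    # Binding to port 0 asks the kernel for an unused port; read it
    # back before the socket closes.  The port is only held while the
    # socket is open, so a small race remains between closing it here
    # and torch.distributed binding the same port again.
    host = host or socket.gethostname()
    with socket.socket() as s:
        s.bind((host, 0))
        return s.getsockname()[1]

print(find_free_port())

Two caveats of the file-based rendezvous are worth noting: the rank-0 write and the other ranks' polling are synchronized only by the one-second sleep, so on a shared filesystem a rank could in principle read a file that is visible but not yet complete; and since dist_file is a relative path, all ranks must share a working directory, while two trainings launched from the same directory at once would contend for the same 'dist_addr' file.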

File 2 of 5 — Slurm job script

@@ -18,10 +18,6 @@ module load gcc python3
 #source $HOME/anaconda3/bin/activate torch

-export MASTER_ADDR=$HOSTNAME
-export MASTER_PORT=60606
-
-
 data_root_dir="/mnt/ceph/users/yinli/Quijote"
 in_dir="linear"

File 3 of 5 — Slurm job script

@@ -18,10 +18,6 @@ module load gcc python3
 #source $HOME/anaconda3/bin/activate torch

-export MASTER_ADDR=$HOSTNAME
-export MASTER_PORT=60606
-
-
 data_root_dir="/mnt/ceph/users/yinli/Quijote"
 in_dir="linear"

File 4 of 5 — Slurm job script

@@ -19,10 +19,6 @@ hostname; pwd; date
 source $HOME/anaconda3/bin/activate

-export MASTER_ADDR=$HOSTNAME
-export MASTER_PORT=60606
-
-
 data_root_dir="/scratch1/06431/yueyingn/dmo-50MPC-train"
 in_dir="low-resl"

File 5 of 5 — Slurm job script

@@ -18,10 +18,6 @@ module load gcc python3
 #source $HOME/anaconda3/bin/activate torch

-export MASTER_ADDR=$HOSTNAME
-export MASTER_PORT=60606
-
-
 data_root_dir="/mnt/ceph/users/yinli/Quijote"
 in_dir="linear"