Determine the rendezvous address and port at runtime and share them via a file
Together with Slurm step node counts, this makes it possible to launch multiple training runs in one job
This commit is contained in:
parent
1818e11265
commit
b67079bf72
5 changed files with 60 additions and 43 deletions
|
@ -1,5 +1,8 @@
|
|||
import os
|
||||
import shutil
|
||||
import socket
|
||||
import time
|
||||
import sys
|
||||
from pprint import pprint
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
@ -19,48 +22,28 @@ from .models import (narrow_like,
|
|||
from .state import load_model_state_dict
|
||||
|
||||
|
||||
def set_runtime_default_args(args):
    """Fill in derived flags and fallback defaults on ``args`` in place.

    ``args.val`` becomes True only when both validation pattern arguments
    are set, and ``args.adv`` becomes True when an adversary model is set.
    When an adversary is used, its learning rate and weight decay fall back
    to the main ``lr`` and ``weight_decay`` if they were left as None.
    """
    # Validation needs both the input and the target patterns.
    args.val = not (args.val_in_patterns is None
                    or args.val_tgt_patterns is None)

    args.adv = args.adv_model is not None
    if not args.adv:
        return

    # Adversary hyperparameters inherit the main ones unless given.
    if args.adv_lr is None:
        args.adv_lr = args.lr
    if args.adv_weight_decay is None:
        args.adv_weight_decay = args.weight_decay
|
||||
|
||||
|
||||
def node_worker(args):
    """Per-node entry point: size the run from Slurm and spawn GPU workers.

    Sets ``args.gpus_per_node``, ``args.nodes`` and ``args.world_size`` in
    place, then spawns one ``gpu_worker`` process per visible CUDA device,
    passing each the node index and ``args``.

    Reads ``SLURM_STEP_NUM_NODES`` and ``SLURM_NODEID`` from the
    environment and raises ``KeyError`` if either is missing.
    """
    set_runtime_default_args(args)

    torch.manual_seed(args.seed)  # NOTE: why here not in gpu_worker?
    #torch.backends.cudnn.deterministic = True  # NOTE: test perf

    args.gpus_per_node = torch.cuda.device_count()
    # Use the node count of this job step rather than of the whole job, so
    # that several steps (trainings) can share one allocation.
    args.nodes = int(os.environ['SLURM_STEP_NUM_NODES'])
    # The world size is computed from this node's device count, i.e. it
    # presumes every node exposes the same number of GPUs.
    args.world_size = args.gpus_per_node * args.nodes

    node = int(os.environ['SLURM_NODEID'])
    if node == 0:
        pprint(vars(args))

    spawn(gpu_worker, args=(node, args), nprocs=args.gpus_per_node)
|
||||
|
||||
|
||||
def gpu_worker(local_rank, node, args):
|
||||
set_runtime_default_args(args)
|
||||
|
||||
device = torch.device('cuda', local_rank)
|
||||
torch.cuda.device(device)
|
||||
|
||||
torch.manual_seed(args.seed)
|
||||
#torch.backends.cudnn.deterministic = True # NOTE: test perf
|
||||
|
||||
rank = args.gpus_per_node * node + local_rank
|
||||
|
||||
dist.init_process_group(
|
||||
backend=args.dist_backend,
|
||||
init_method='env://',
|
||||
world_size=args.world_size,
|
||||
rank=rank,
|
||||
)
|
||||
dist_init(rank, args)
|
||||
|
||||
train_dataset = FieldDataset(
|
||||
in_patterns=args.train_in_patterns,
|
||||
|
@ -210,6 +193,10 @@ def gpu_worker(local_rank, node, args):
|
|||
if rank == 0:
|
||||
logger = SummaryWriter()
|
||||
|
||||
if rank == 0:
|
||||
pprint(vars(args))
|
||||
sys.stdout.flush()
|
||||
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
if not args.div_data:
|
||||
train_sampler.set_epoch(epoch)
|
||||
|
@ -449,6 +436,52 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion,
|
|||
return epoch_loss
|
||||
|
||||
|
||||
def set_runtime_default_args(args):
    """Fill in derived flags and fallback defaults on ``args`` in place.

    ``args.val`` becomes True only when both validation pattern arguments
    are set, and ``args.adv`` becomes True when an adversary model is set.
    When an adversary is used, its learning rate and weight decay fall back
    to the main ``lr`` and ``weight_decay`` if they were left as None.
    """
    # Validation needs both the input and the target patterns.
    args.val = not (args.val_in_patterns is None
                    or args.val_tgt_patterns is None)

    args.adv = args.adv_model is not None
    if not args.adv:
        return

    # Adversary hyperparameters inherit the main ones unless given.
    if args.adv_lr is None:
        args.adv_lr = args.lr
    if args.adv_weight_decay is None:
        args.adv_weight_decay = args.weight_decay
|
||||
|
||||
|
||||
def dist_init(rank, args):
    """Initialize the process group over an address chosen at run time.

    Rank 0 picks a free TCP port on its own host, stores the resulting
    ``tcp://host:port`` URL in ``args.dist_addr`` and publishes it in the
    file ``dist_addr`` in the current working directory.  Every other rank
    polls for that file once per second and reads the URL from it.  All
    ranks then call ``dist.init_process_group`` with that URL, and rank 0
    removes the file afterwards.

    Args:
        rank: global rank of the calling process.
        args: namespace providing ``dist_backend`` and ``world_size``;
            ``dist_addr`` is set on it as a side effect.
    """
    # NOTE(review): the file name is fixed and relative, so concurrent runs
    # sharing a working directory would see each other's file (including a
    # stale one left by a crashed run) -- confirm runs use distinct
    # directories.
    dist_file = 'dist_addr'

    if rank == 0:
        addr = socket.gethostname()

        # Binding to port 0 lets the OS choose a free port.  The socket is
        # closed again before the process group binds the port, so the port
        # is very likely, but not guaranteed, to still be free.
        with socket.socket() as s:
            s.bind((addr, 0))
            _, port = s.getsockname()

        args.dist_addr = 'tcp://{}:{}'.format(addr, port)

        # Publish atomically: opening ``dist_file`` directly with mode 'w'
        # would create it empty first, and a polling rank could then read an
        # empty or partial address.  Write a private temporary file and
        # rename it into place instead.
        tmp_file = '{}.{}.tmp'.format(dist_file, os.getpid())
        with open(tmp_file, mode='w') as f:
            f.write(args.dist_addr)
        os.replace(tmp_file, dist_file)
    else:
        while not os.path.exists(dist_file):
            time.sleep(1)

        with open(dist_file, mode='r') as f:
            args.dist_addr = f.read()

    dist.init_process_group(
        backend=args.dist_backend,
        init_method=args.dist_addr,
        world_size=args.world_size,
        rank=rank,
    )

    if rank == 0:
        os.remove(dist_file)
|
||||
|
||||
|
||||
def set_requires_grad(module, requires_grad=False):
|
||||
for param in module.parameters():
|
||||
param.requires_grad = requires_grad
|
||||
|
|
|
@ -18,10 +18,6 @@ module load gcc python3
|
|||
#source $HOME/anaconda3/bin/activate torch
|
||||
|
||||
|
||||
export MASTER_ADDR=$HOSTNAME
|
||||
export MASTER_PORT=60606
|
||||
|
||||
|
||||
data_root_dir="/mnt/ceph/users/yinli/Quijote"
|
||||
|
||||
in_dir="linear"
|
||||
|
|
|
@ -18,10 +18,6 @@ module load gcc python3
|
|||
#source $HOME/anaconda3/bin/activate torch
|
||||
|
||||
|
||||
export MASTER_ADDR=$HOSTNAME
|
||||
export MASTER_PORT=60606
|
||||
|
||||
|
||||
data_root_dir="/mnt/ceph/users/yinli/Quijote"
|
||||
|
||||
in_dir="linear"
|
||||
|
|
|
@ -19,10 +19,6 @@ hostname; pwd; date
|
|||
source $HOME/anaconda3/bin/activate
|
||||
|
||||
|
||||
export MASTER_ADDR=$HOSTNAME
|
||||
export MASTER_PORT=60606
|
||||
|
||||
|
||||
data_root_dir="/scratch1/06431/yueyingn/dmo-50MPC-train"
|
||||
|
||||
in_dir="low-resl"
|
||||
|
|
|
@ -18,10 +18,6 @@ module load gcc python3
|
|||
#source $HOME/anaconda3/bin/activate torch
|
||||
|
||||
|
||||
export MASTER_ADDR=$HOSTNAME
|
||||
export MASTER_PORT=60606
|
||||
|
||||
|
||||
data_root_dir="/mnt/ceph/users/yinli/Quijote"
|
||||
|
||||
in_dir="linear"
|
||||
|
|
Loading…
Reference in a new issue