Fix bug when env SLURM_STEP_NUM_NODES is not defined

This commit is contained in:
Yin Li 2020-02-14 12:32:20 -05:00
parent 5cb4a1bbae
commit a46746287a

View File

@ -23,9 +23,14 @@ from .state import load_model_state_dict
def node_worker(args): def node_worker(args):
args.gpus_per_node = torch.cuda.device_count() if 'SLURM_STEP_NUM_NODES' in os.environ:
args.nodes = int(os.environ['SLURM_STEP_NUM_NODES']) args.nodes = int(os.environ['SLURM_STEP_NUM_NODES'])
args.world_size = args.gpus_per_node * args.nodes elif 'SLURM_JOB_NUM_NODES' in os.environ:
args.nodes = int(os.environ['SLURM_JOB_NUM_NODES'])
else:
raise KeyError('missing node counts in slurm env')
args.gpus_per_node = torch.cuda.device_count()
args.world_size = args.nodes * args.gpus_per_node
node = int(os.environ['SLURM_NODEID']) node = int(os.environ['SLURM_NODEID'])