Fix bug when env SLURM_STEP_NUM_NODES is not defined
This commit is contained in:
parent
5cb4a1bbae
commit
a46746287a
@ -23,9 +23,14 @@ from .state import load_model_state_dict
|
||||
|
||||
|
||||
def node_worker(args):
|
||||
args.gpus_per_node = torch.cuda.device_count()
|
||||
if 'SLURM_STEP_NUM_NODES' in os.environ:
|
||||
args.nodes = int(os.environ['SLURM_STEP_NUM_NODES'])
|
||||
args.world_size = args.gpus_per_node * args.nodes
|
||||
elif 'SLURM_JOB_NUM_NODES' in os.environ:
|
||||
args.nodes = int(os.environ['SLURM_JOB_NUM_NODES'])
|
||||
else:
|
||||
raise KeyError('missing node counts in slurm env')
|
||||
args.gpus_per_node = torch.cuda.device_count()
|
||||
args.world_size = args.nodes * args.gpus_per_node
|
||||
|
||||
node = int(os.environ['SLURM_NODEID'])
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user