Fix bug when env SLURM_STEP_NUM_NODES is not defined
This commit is contained in:
parent
5cb4a1bbae
commit
a46746287a
@ -23,9 +23,14 @@ from .state import load_model_state_dict
|
|||||||
|
|
||||||
|
|
||||||
def node_worker(args):
|
def node_worker(args):
|
||||||
|
if 'SLURM_STEP_NUM_NODES' in os.environ:
|
||||||
|
args.nodes = int(os.environ['SLURM_STEP_NUM_NODES'])
|
||||||
|
elif 'SLURM_JOB_NUM_NODES' in os.environ:
|
||||||
|
args.nodes = int(os.environ['SLURM_JOB_NUM_NODES'])
|
||||||
|
else:
|
||||||
|
raise KeyError('missing node counts in slurm env')
|
||||||
args.gpus_per_node = torch.cuda.device_count()
|
args.gpus_per_node = torch.cuda.device_count()
|
||||||
args.nodes = int(os.environ['SLURM_STEP_NUM_NODES'])
|
args.world_size = args.nodes * args.gpus_per_node
|
||||||
args.world_size = args.gpus_per_node * args.nodes
|
|
||||||
|
|
||||||
node = int(os.environ['SLURM_NODEID'])
|
node = int(os.environ['SLURM_NODEID'])
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user