diff --git a/map2map/train.py b/map2map/train.py index 75ae144..871e5b3 100644 --- a/map2map/train.py +++ b/map2map/train.py @@ -42,8 +42,12 @@ def node_worker(args): def gpu_worker(local_rank, node, args): - device = torch.device('cuda', local_rank) - torch.cuda.device(device) + #device = torch.device('cuda', local_rank) + #torch.cuda.device(device) # env var recommended over this + + os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' + os.environ['CUDA_VISIBLE_DEVICES'] = str(local_rank) + device = torch.device('cuda', 0) rank = args.gpus_per_node * node + local_rank