Skip to content

Commit

Permalink
Merge pull request #1589 from rapidsai/branch-23.06
Browse files Browse the repository at this point in the history
[RELEASE] Hotfix v23.06 raft
  • Loading branch information
raydouglass authored Jun 12, 2023
2 parents c931b61 + 6c81a41 commit af1515d
Showing 1 changed file with 74 additions and 5 deletions.
79 changes: 74 additions & 5 deletions python/raft-dask/raft_dask/common/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@
#

import logging
import os
import time
import uuid
import warnings
from collections import OrderedDict
from collections import Counter, OrderedDict

from dask.distributed import default_client
from dask_cuda.utils import nvml_device_index

from pylibraft.common.handle import Handle

Expand Down Expand Up @@ -155,7 +157,7 @@ def worker_info(self, workers):
Builds a dictionary of { (worker_address, worker_port) :
(worker_rank, worker_port ) }
"""
ranks = _func_worker_ranks(workers)
ranks = _func_worker_ranks(self.client)
ports = (
_func_ucp_ports(self.client, workers) if self.comms_p2p else None
)
Expand Down Expand Up @@ -686,8 +688,75 @@ def _func_ucp_ports(client, workers):
return client.run(_func_ucp_listener_port, workers=workers)


def _func_worker_ranks(client, workers=None):
    """
    Compute a global rank for each Dask worker.

    Every selected worker reports the NVML device index of its GPU. That
    index is then increased by a per-host offset (the number of workers
    found on the hosts seen before it), giving
    { worker_address : worker_rank }.

    Parameters
    ----------
    client (object): Dask client object.
    workers (list, optional): Addresses of the workers to rank. Defaults to
        None, in which case the function is run on every worker connected
        to the client.

    Returns
    -------
    dict: Worker addresses mapped to their global ranks.
    """
    # Keys of the result are the full worker addresses; they are reused below
    # both to derive the host part and as keys of the returned dictionary.
    ranks = client.run(_get_nvml_device_index, workers=workers)
    worker_ips = [_get_worker_ip(worker_address) for worker_address in ranks]
    worker_ip_offset_dict = _get_rank_offset_across_nodes(worker_ips)
    return _append_rank_offset(ranks, worker_ip_offset_dict)


def _get_nvml_device_index():
    """
    Resolve the NVML device index for the current process.

    Reads the 'CUDA_VISIBLE_DEVICES' environment variable and asks
    dask_cuda's nvml_device_index for the entry at position 0.

    Returns
    -------
    Whatever nvml_device_index returns for position 0 of the variable.
    """
    first_position = 0
    visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
    return nvml_device_index(first_position, visible_devices)


def _get_worker_ip(worker_address):
"""
Extract the worker IP address from the worker address string.
Parameters
----------
worker_address (str): Full address string of the worker
"""
return ":".join(worker_address.split(":")[0:2])


def _get_rank_offset_across_nodes(worker_ips):
"""
Get a dictionary of worker IP addresses mapped to the cumulative count of
their occurrences in the worker_ips list. The cumulative count serves as
the rank offset.
Parameters
----------
worker_ips (list): List of worker IP addresses.
"""
worker_count_dict = Counter(worker_ips)
worker_offset_dict = {}
current_offset = 0
for worker_ip, worker_count in worker_count_dict.items():
worker_offset_dict[worker_ip] = current_offset
current_offset += worker_count
return worker_offset_dict


def _append_rank_offset(rank_dict, worker_ip_offset_dict):
"""
For each worker address in the rank dictionary, add the
corresponding worker offset from the worker_ip_offset_dict
to the rank value.
Parameters
----------
rank_dict (dict): Dictionary of worker addresses mapped to their ranks.
worker_ip_offset_dict (dict): Dictionary of worker IP addresses
mapped to their offsets.
"""
return dict(list(zip(workers, range(len(workers)))))
for worker_ip, worker_offset in worker_ip_offset_dict.items():
for worker_address in rank_dict:
if worker_ip in worker_address:
rank_dict[worker_address] += worker_offset
return rank_dict

0 comments on commit af1515d

Please sign in to comment.