Skip to content

Commit

Permalink
[distributed][misc] add specialized method for cuda platform (#7249)
Browse files Browse the repository at this point in the history
  • Loading branch information
youkaichao authored Aug 7, 2024
1 parent 66d617e commit 639159b
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 53 deletions.
8 changes: 6 additions & 2 deletions vllm/distributed/device_communicators/custom_all_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
gpu_p2p_access_check)
from vllm.distributed.parallel_state import in_the_same_node_as
from vllm.logger import init_logger
from vllm.utils import cuda_device_count_stateless, is_full_nvlink
from vllm.platforms import current_platform
from vllm.utils import cuda_device_count_stateless

try:
assert ops.is_custom_op_supported("_C_custom_ar::meta_size")
Expand Down Expand Up @@ -113,7 +114,10 @@ def __init__(self,
# test nvlink first, this will filter out most of the cases
# where custom allreduce is not supported
# this checks hardware and driver support for NVLink
full_nvlink = is_full_nvlink(physical_device_ids)
assert current_platform.is_cuda()
from vllm.platforms.cuda import CudaPlatform
cuda_platform: CudaPlatform = current_platform
full_nvlink = cuda_platform.is_full_nvlink(physical_device_ids)
if world_size > 2 and not full_nvlink:
logger.warning(
"Custom allreduce is disabled because it's not supported on"
Expand Down
37 changes: 36 additions & 1 deletion vllm/platforms/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,21 @@

import os
from functools import lru_cache, wraps
from typing import Tuple
from typing import List, Tuple

import pynvml

from vllm.logger import init_logger

from .interface import Platform, PlatformEnum

logger = init_logger(__name__)

# NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
# all the related functions work on real physical device ids.
# the major benefit of using NVML is that it will not initialize CUDA


def with_nvml_context(fn):

Expand Down Expand Up @@ -47,3 +56,29 @@ class CudaPlatform(Platform):
def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
physical_device_id = device_id_to_physical_device_id(device_id)
return get_physical_device_capability(physical_device_id)

@staticmethod
@with_nvml_context
def is_full_nvlink(physical_device_ids: List[int]) -> bool:
"""
query if the set of gpus are fully connected by nvlink (1 hop)
"""
handles = [
pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids
]
for i, handle in enumerate(handles):
for j, peer_handle in enumerate(handles):
if i < j:
try:
p2p_status = pynvml.nvmlDeviceGetP2PStatus(
handle, peer_handle,
pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
if p2p_status != pynvml.NVML_P2P_STATUS_OK:
return False
except pynvml.NVMLError as error:
logger.error(
"NVLink detection failed. This is normal if your"
" machine has no NVLink equipped.",
exc_info=error)
return False
return True
50 changes: 0 additions & 50 deletions vllm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1034,56 +1034,6 @@ def cuda_device_count_stateless() -> int:
return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES)


# NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
# all the related functions work on real physical device ids.
# the major benefit of using NVML is that it will not initialize CUDA

try:
import pynvml
except ImportError:
# For non-NV devices
pynvml = None


def with_nvml_context(fn):

@wraps(fn)
def wrapper(*args, **kwargs):
if pynvml is not None:
pynvml.nvmlInit()
try:
return fn(*args, **kwargs)
finally:
if pynvml is not None:
pynvml.nvmlShutdown()

return wrapper


@with_nvml_context
def is_full_nvlink(device_ids: List[int]) -> bool:
"""
query if the set of gpus are fully connected by nvlink (1 hop)
"""
handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in device_ids]
for i, handle in enumerate(handles):
for j, peer_handle in enumerate(handles):
if i < j:
try:
p2p_status = pynvml.nvmlDeviceGetP2PStatus(
handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
if p2p_status != pynvml.NVML_P2P_STATUS_OK:
return False
except pynvml.NVMLError as error:
logger.error(
"NVLink detection failed. This is normal if your"
" machine has no NVLink equipped.",
exc_info=error)
return False
return True


#From: https://stackoverflow.com/a/4104188/2749989
def run_once(f):

Expand Down

0 comments on commit 639159b

Please sign in to comment.