add AMD GPU oversubscription check #1059

Merged
merged 5 commits, Aug 28, 2024
33 changes: 20 additions & 13 deletions mirgecom/array_context.py
@@ -33,7 +33,7 @@
 THE SOFTWARE.
 """
 
-from typing import Type, Dict, Any
+from typing import Type, Dict, Any, Tuple
 import os
 import logging
 
@@ -183,8 +183,7 @@ def _check_gpu_oversubscription(actx: ArrayContext) -> None:
     """
     Check whether multiple ranks are running on the same GPU on each node.
 
-    Only works with CUDA devices currently due to the use of the
-    PCI_DOMAIN_ID_NV extension.
+    Only works with CUDA or AMD devices currently.
     """
     if not actx_class_is_distributed(type(actx)):
         return
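The docstring now promises CUDA and AMD support; the vendor detection that backs this up arrives in the next hunk. As a standalone sketch of that detection (not part of this PR; it assumes a pyopencl installation with a visible GPU, and the vendor-specific attributes may raise if the corresponding OpenCL extension is unavailable), the same identifiers can be probed like this:

import pyopencl as cl
from pyopencl.characterize import nv_compute_capability

# Standalone probe (illustrative only): print the device identifiers that
# the oversubscription check builds its per-rank ID from.
ctx = cl.create_some_context()
dev = ctx.devices[0]

if not (dev.type & cl.device_type.GPU):
    print("Not a GPU device:", dev.name)
elif nv_compute_capability(dev) is not None:
    # Nvidia: PCI domain/bus/slot via the *_nv device queries
    print("Nvidia GPU:", hex(dev.pci_domain_id_nv),
          hex(dev.pci_bus_id_nv), hex(dev.pci_slot_id_nv))
elif dev.platform.vendor.startswith("Advanced Micro"):
    # AMD: PCI bus via the cl_amd_device_topology extension
    print("AMD GPU, PCI bus:", dev.topology_amd.bus)
else:
    print("Unrecognized GPU vendor:", dev.platform.vendor)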
@@ -201,29 +200,37 @@ def _check_gpu_oversubscription(actx: ArrayContext) -> None:
 
     dev = actx.queue.device
 
-    # This check only works with Nvidia GPUs
-    from pyopencl.characterize import nv_compute_capability
-    if nv_compute_capability(dev) is None:
+    # Only check GPU devices
+    if not (dev.type & cl.device_type.GPU):
         return
 
-    from mirgecom.mpi import shared_split_comm_world
-
-    with shared_split_comm_world() as node_comm:
+    from pyopencl.characterize import nv_compute_capability
+    if nv_compute_capability(dev) is not None:
         try:
             domain_id = hex(dev.pci_domain_id_nv)
         except (cl._cl.LogicError, AttributeError):
             from warnings import warn
             warn("Cannot detect whether multiple ranks are running on the"
                  " same GPU because it requires Nvidia GPUs running with"
                  " pyopencl>2021.1.1 and (Nvidia CL or pocl>1.6).")
-            return
-
-        node_rank = node_comm.Get_rank()
+            raise
 
         bus_id = hex(dev.pci_bus_id_nv)
         slot_id = hex(dev.pci_slot_id_nv)
+        dev_id: Tuple[Any, ...] = (domain_id, bus_id, slot_id)
 
-        dev_id = (domain_id, bus_id, slot_id)
+    elif dev.platform.vendor.startswith("Advanced Micro"):
+        dev_id = (dev.topology_amd.bus,)
+    else:
+        from warnings import warn
+        warn("Cannot detect whether multiple ranks are running on the"
+             " same GPU.")
+        return
+
+    from mirgecom.mpi import shared_split_comm_world
+
+    with shared_split_comm_world() as node_comm:
+        node_rank = node_comm.Get_rank()
 
         dev_ids = node_comm.gather(dev_id, root=0)
 
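The hunk cuts off before the code that consumes dev_ids, so the actual duplicate detection is not shown here. Purely as a hypothetical illustration of what a root-rank check over the gathered IDs could look like (the helper name and message are invented, not the PR's code):

from collections import Counter
from warnings import warn

def _warn_on_shared_gpus(dev_ids, node_rank):
    # Hypothetical helper, not from this PR. dev_ids is the list gathered on
    # the node-local root rank (it is None on every other rank), with one
    # device-ID tuple per rank on the node.
    if node_rank == 0:
        dups = {dev: n for dev, n in Counter(dev_ids).items() if n > 1}
        if dups:
            warn(f"Multiple ranks appear to share a GPU on this node: {dups}")

Since mpi4py's gather only returns the list on the root rank, only node_rank == 0 inspects it; the device-ID tuples are hashable, so a Counter is enough to spot oversubscription.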