Skip to content

Commit

Permalink
add AMD GPU oversubscription check
Browse files Browse the repository at this point in the history
  • Loading branch information
matthiasdiener committed Aug 26, 2024
1 parent 0d680df commit 23dac96
Showing 1 changed file with 18 additions and 15 deletions.
33 changes: 18 additions & 15 deletions mirgecom/array_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,7 @@ def _check_gpu_oversubscription(actx: ArrayContext) -> None:
"""
Check whether multiple ranks are running on the same GPU on each node.
Only works with CUDA devices currently due to the use of the
PCI_DOMAIN_ID_NV extension.
Only works with CUDA or AMD devices currently.
"""
from mpi4py import MPI
import pyopencl as cl
Expand All @@ -188,30 +187,34 @@ def _check_gpu_oversubscription(actx: ArrayContext) -> None:

dev = actx.queue.device

# This check only works with Nvidia GPUs
from pyopencl.characterize import nv_compute_capability
if nv_compute_capability(dev) is None:
return

from mirgecom.mpi import shared_split_comm_world

with shared_split_comm_world() as node_comm:
if nv_compute_capability(dev) is not None:
try:
domain_id = hex(dev.pci_domain_id_nv)
except (cl._cl.LogicError, AttributeError):
from warnings import warn
warn("Cannot detect whether multiple ranks are running on the"
" same GPU because it requires Nvidia GPUs running with"
" pyopencl>2021.1.1 and (Nvidia CL or pocl>1.6).")
return

node_rank = node_comm.Get_rank()
" same GPU because it requires Nvidia GPUs running with"
" pyopencl>2021.1.1 and (Nvidia CL or pocl>1.6).")
raise

bus_id = hex(dev.pci_bus_id_nv)
slot_id = hex(dev.pci_slot_id_nv)

dev_id = (domain_id, bus_id, slot_id)

elif dev.platform.vendor.startswith("Advanced Micro"):
dev_id = (dev.topology_amd.bus,)
else:
from warnings import warn
warn("Cannot detect whether multiple ranks are running on the"
" same GPU.")
return

from mirgecom.mpi import shared_split_comm_world

with shared_split_comm_world() as node_comm:
node_rank = node_comm.Get_rank()

dev_ids = node_comm.gather(dev_id, root=0)

if node_rank == 0:
Expand Down

0 comments on commit 23dac96

Please sign in to comment.