Skip to content

Commit

Permalink
Merge branch 'production' into production-pilot
Browse files Browse the repository at this point in the history
  • Loading branch information
MTCam committed Aug 30, 2024
2 parents e0e88bd + 4f4e0a5 commit d78e252
Showing 1 changed file with 20 additions and 13 deletions.
33 changes: 20 additions & 13 deletions mirgecom/array_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
THE SOFTWARE.
"""

from typing import Type, Dict, Any
from typing import Type, Dict, Any, Tuple
import os
import logging

Expand Down Expand Up @@ -183,8 +183,7 @@ def _check_gpu_oversubscription(actx: ArrayContext) -> None:
"""
Check whether multiple ranks are running on the same GPU on each node.
Only works with CUDA devices currently due to the use of the
PCI_DOMAIN_ID_NV extension.
Only works with CUDA or AMD devices currently.
"""
if not actx_class_is_distributed(type(actx)):
return
Expand All @@ -201,29 +200,37 @@ def _check_gpu_oversubscription(actx: ArrayContext) -> None:

dev = actx.queue.device

# This check only works with Nvidia GPUs
from pyopencl.characterize import nv_compute_capability
if nv_compute_capability(dev) is None:
# Only check GPU devices
if not (dev.type & cl.device_type.GPU):
return

from mirgecom.mpi import shared_split_comm_world

with shared_split_comm_world() as node_comm:
from pyopencl.characterize import nv_compute_capability
if nv_compute_capability(dev) is not None:
try:
domain_id = hex(dev.pci_domain_id_nv)
except (cl._cl.LogicError, AttributeError):
from warnings import warn
warn("Cannot detect whether multiple ranks are running on the"
" same GPU because it requires Nvidia GPUs running with"
" pyopencl>2021.1.1 and (Nvidia CL or pocl>1.6).")
return

node_rank = node_comm.Get_rank()
raise

bus_id = hex(dev.pci_bus_id_nv)
slot_id = hex(dev.pci_slot_id_nv)
dev_id: Tuple[Any, ...] = (domain_id, bus_id, slot_id)

dev_id = (domain_id, bus_id, slot_id)
elif dev.platform.vendor.startswith("Advanced Micro"):
dev_id = (dev.topology_amd.bus,)
else:
from warnings import warn
warn("Cannot detect whether multiple ranks are running on the"
" same GPU.")
return

from mirgecom.mpi import shared_split_comm_world

with shared_split_comm_world() as node_comm:
node_rank = node_comm.Get_rank()

dev_ids = node_comm.gather(dev_id, root=0)

Expand Down

0 comments on commit d78e252

Please sign in to comment.