diff --git a/mirgecom/array_context.py b/mirgecom/array_context.py index 0182b693b..3f20ab885 100644 --- a/mirgecom/array_context.py +++ b/mirgecom/array_context.py @@ -33,7 +33,7 @@ THE SOFTWARE. """ -from typing import Type, Dict, Any +from typing import Type, Dict, Any, Tuple import os import logging @@ -183,8 +183,7 @@ def _check_gpu_oversubscription(actx: ArrayContext) -> None: """ Check whether multiple ranks are running on the same GPU on each node. - Only works with CUDA devices currently due to the use of the - PCI_DOMAIN_ID_NV extension. + Only works with CUDA or AMD devices currently. """ if not actx_class_is_distributed(type(actx)): return @@ -201,14 +200,12 @@ def _check_gpu_oversubscription(actx: ArrayContext) -> None: dev = actx.queue.device - # This check only works with Nvidia GPUs - from pyopencl.characterize import nv_compute_capability - if nv_compute_capability(dev) is None: + # Only check GPU devices + if not (dev.type & cl.device_type.GPU): return - from mirgecom.mpi import shared_split_comm_world - - with shared_split_comm_world() as node_comm: + from pyopencl.characterize import nv_compute_capability + if nv_compute_capability(dev) is not None: try: domain_id = hex(dev.pci_domain_id_nv) except (cl._cl.LogicError, AttributeError): @@ -216,14 +213,24 @@ def _check_gpu_oversubscription(actx: ArrayContext) -> None: warn("Cannot detect whether multiple ranks are running on the" " same GPU because it requires Nvidia GPUs running with" " pyopencl>2021.1.1 and (Nvidia CL or pocl>1.6).") - return - - node_rank = node_comm.Get_rank() + raise bus_id = hex(dev.pci_bus_id_nv) slot_id = hex(dev.pci_slot_id_nv) + dev_id: Tuple[Any, ...] = (domain_id, bus_id, slot_id) - dev_id = (domain_id, bus_id, slot_id) + elif dev.platform.vendor.startswith("Advanced Micro"): + dev_id = (dev.topology_amd.bus,) + else: + from warnings import warn + warn("Cannot detect whether multiple ranks are running on the" + " same GPU.") + return + + from mirgecom.mpi import shared_split_comm_world + + with shared_split_comm_world() as node_comm: + node_rank = node_comm.Get_rank() dev_ids = node_comm.gather(dev_id, root=0)