diff --git a/ci/test_python.sh b/ci/test_python.sh index 827eb84c9..ca4140bae 100755 --- a/ci/test_python.sh +++ b/ci/test_python.sh @@ -45,7 +45,7 @@ DASK_CUDA_WAIT_WORKERS_MIN_TIMEOUT=20 \ UCXPY_IFNAME=eth0 \ UCX_WARN_UNUSED_ENV_VARS=n \ UCX_MEMTYPE_CACHE=n \ -timeout 40m pytest \ +timeout 60m pytest \ -vv \ --durations=0 \ --capture=no \ diff --git a/dask_cuda/benchmarks/local_cudf_groupby.py b/dask_cuda/benchmarks/local_cudf_groupby.py index 4e9dea94e..2f07e3df7 100644 --- a/dask_cuda/benchmarks/local_cudf_groupby.py +++ b/dask_cuda/benchmarks/local_cudf_groupby.py @@ -139,7 +139,7 @@ def pretty_print_results(args, address_to_index, p2p_bw, results): key="Device memory limit", value=f"{format_bytes(args.device_memory_limit)}" ) print_key_value(key="RMM Pool", value=f"{not args.disable_rmm_pool}") - if args.protocol == "ucx": + if args.protocol in ["ucx", "ucxx"]: print_key_value(key="TCP", value=f"{args.enable_tcp_over_ucx}") print_key_value(key="InfiniBand", value=f"{args.enable_infiniband}") print_key_value(key="NVLink", value=f"{args.enable_nvlink}") diff --git a/dask_cuda/benchmarks/local_cudf_merge.py b/dask_cuda/benchmarks/local_cudf_merge.py index f26a26ae9..ba3a9d56d 100644 --- a/dask_cuda/benchmarks/local_cudf_merge.py +++ b/dask_cuda/benchmarks/local_cudf_merge.py @@ -217,7 +217,7 @@ def pretty_print_results(args, address_to_index, p2p_bw, results): ) print_key_value(key="RMM Pool", value=f"{not args.disable_rmm_pool}") print_key_value(key="Frac-match", value=f"{args.frac_match}") - if args.protocol == "ucx": + if args.protocol in ["ucx", "ucxx"]: print_key_value(key="TCP", value=f"{args.enable_tcp_over_ucx}") print_key_value(key="InfiniBand", value=f"{args.enable_infiniband}") print_key_value(key="NVLink", value=f"{args.enable_nvlink}") diff --git a/dask_cuda/benchmarks/local_cudf_shuffle.py b/dask_cuda/benchmarks/local_cudf_shuffle.py index 51ba48f93..a3492b664 100644 --- a/dask_cuda/benchmarks/local_cudf_shuffle.py +++ b/dask_cuda/benchmarks/local_cudf_shuffle.py @@ -146,7 +146,7 @@ def pretty_print_results(args, address_to_index, p2p_bw, results): key="Device memory limit", value=f"{format_bytes(args.device_memory_limit)}" ) print_key_value(key="RMM Pool", value=f"{not args.disable_rmm_pool}") - if args.protocol == "ucx": + if args.protocol in ["ucx", "ucxx"]: print_key_value(key="TCP", value=f"{args.enable_tcp_over_ucx}") print_key_value(key="InfiniBand", value=f"{args.enable_infiniband}") print_key_value(key="NVLink", value=f"{args.enable_nvlink}") diff --git a/dask_cuda/benchmarks/local_cupy.py b/dask_cuda/benchmarks/local_cupy.py index 1c1d12d30..22c51556f 100644 --- a/dask_cuda/benchmarks/local_cupy.py +++ b/dask_cuda/benchmarks/local_cupy.py @@ -193,7 +193,7 @@ def pretty_print_results(args, address_to_index, p2p_bw, results): ) print_key_value(key="RMM Pool", value=f"{not args.disable_rmm_pool}") print_key_value(key="Protocol", value=f"{args.protocol}") - if args.protocol == "ucx": + if args.protocol in ["ucx", "ucxx"]: print_key_value(key="TCP", value=f"{args.enable_tcp_over_ucx}") print_key_value(key="InfiniBand", value=f"{args.enable_infiniband}") print_key_value(key="NVLink", value=f"{args.enable_nvlink}") diff --git a/dask_cuda/benchmarks/local_cupy_map_overlap.py b/dask_cuda/benchmarks/local_cupy_map_overlap.py index f40318559..8250c9f9f 100644 --- a/dask_cuda/benchmarks/local_cupy_map_overlap.py +++ b/dask_cuda/benchmarks/local_cupy_map_overlap.py @@ -78,7 +78,7 @@ def pretty_print_results(args, address_to_index, p2p_bw, results): ) print_key_value(key="RMM Pool", value=f"{not args.disable_rmm_pool}") print_key_value(key="Protocol", value=f"{args.protocol}") - if args.protocol == "ucx": + if args.protocol in ["ucx", "ucxx"]: print_key_value(key="TCP", value=f"{args.enable_tcp_over_ucx}") print_key_value(key="InfiniBand", value=f"{args.enable_infiniband}") print_key_value(key="NVLink", value=f"{args.enable_nvlink}") diff --git a/dask_cuda/benchmarks/utils.py b/dask_cuda/benchmarks/utils.py index d3ce666b2..51fae7201 100644 --- a/dask_cuda/benchmarks/utils.py +++ b/dask_cuda/benchmarks/utils.py @@ -73,7 +73,7 @@ def parse_benchmark_args(description="Generic dask-cuda Benchmark", args_list=[] cluster_args.add_argument( "-p", "--protocol", - choices=["tcp", "ucx"], + choices=["tcp", "ucx", "ucxx"], default="tcp", type=str, help="The communication protocol to use.", diff --git a/dask_cuda/initialize.py b/dask_cuda/initialize.py index 0b9c92a59..571a46a55 100644 --- a/dask_cuda/initialize.py +++ b/dask_cuda/initialize.py @@ -5,7 +5,6 @@ import numba.cuda import dask -import distributed.comm.ucx from distributed.diagnostics.nvml import get_device_index_and_uuid, has_cuda_context from .utils import get_ucx_config @@ -23,12 +22,21 @@ def _create_cuda_context_handler(): numba.cuda.current_context() -def _create_cuda_context(): +def _create_cuda_context(protocol="ucx"): + if protocol not in ["ucx", "ucxx"]: + return try: # Added here to ensure the parent `LocalCUDACluster` process creates the CUDA # context directly from the UCX module, thus avoiding a similar warning there. try: - distributed.comm.ucx.init_once() + if protocol == "ucx": + import distributed.comm.ucx + + distributed.comm.ucx.init_once() + elif protocol == "ucxx": + import distributed_ucxx.ucxx + + distributed_ucxx.ucxx.init_once() except ModuleNotFoundError: # UCX initialization has to be delegated to Distributed, it will take care # of setting correct environment variables and importing `ucp` after that. @@ -39,20 +47,35 @@ def _create_cuda_context(): os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")[0] ) ctx = has_cuda_context() - if ( - ctx.has_context - and not distributed.comm.ucx.cuda_context_created.has_context - ): - distributed.comm.ucx._warn_existing_cuda_context(ctx, os.getpid()) + if protocol == "ucx": + if ( + ctx.has_context + and not distributed.comm.ucx.cuda_context_created.has_context + ): + distributed.comm.ucx._warn_existing_cuda_context(ctx, os.getpid()) + elif protocol == "ucxx": + if ( + ctx.has_context + and not distributed_ucxx.ucxx.cuda_context_created.has_context + ): + distributed_ucxx.ucxx._warn_existing_cuda_context(ctx, os.getpid()) _create_cuda_context_handler() - if not distributed.comm.ucx.cuda_context_created.has_context: - ctx = has_cuda_context() - if ctx.has_context and ctx.device_info != cuda_visible_device: - distributed.comm.ucx._warn_cuda_context_wrong_device( - cuda_visible_device, ctx.device_info, os.getpid() - ) + if protocol == "ucx": + if not distributed.comm.ucx.cuda_context_created.has_context: + ctx = has_cuda_context() + if ctx.has_context and ctx.device_info != cuda_visible_device: + distributed.comm.ucx._warn_cuda_context_wrong_device( + cuda_visible_device, ctx.device_info, os.getpid() + ) + elif protocol == "ucxx": + if not distributed_ucxx.ucxx.cuda_context_created.has_context: + ctx = has_cuda_context() + if ctx.has_context and ctx.device_info != cuda_visible_device: + distributed_ucxx.ucxx._warn_cuda_context_wrong_device( + cuda_visible_device, ctx.device_info, os.getpid() + ) except Exception: logger.error("Unable to start CUDA Context", exc_info=True) @@ -64,6 +87,7 @@ def initialize( enable_infiniband=None, enable_nvlink=None, enable_rdmacm=None, + protocol="ucx", ): """Create CUDA context and initialize UCX-Py, depending on user parameters. @@ -118,7 +142,7 @@ def initialize( dask.config.set({"distributed.comm.ucx": ucx_config}) if create_cuda_context: - _create_cuda_context() + _create_cuda_context(protocol=protocol) @click.command() @@ -127,6 +151,12 @@ def initialize( default=False, help="Create CUDA context", ) +@click.option( + "--protocol", + default=None, + type=str, + help="Communication protocol, such as: 'tcp', 'tls', 'ucx' or 'ucxx'.", +) @click.option( "--enable-tcp-over-ucx/--disable-tcp-over-ucx", default=False, @@ -150,10 +180,11 @@ def initialize( def dask_setup( service, create_cuda_context, + protocol, enable_tcp_over_ucx, enable_infiniband, enable_nvlink, enable_rdmacm, ): if create_cuda_context: - _create_cuda_context() + _create_cuda_context(protocol=protocol) diff --git a/dask_cuda/local_cuda_cluster.py b/dask_cuda/local_cuda_cluster.py index d0ea92748..7a5c8c13d 100644 --- a/dask_cuda/local_cuda_cluster.py +++ b/dask_cuda/local_cuda_cluster.py @@ -319,8 +319,11 @@ def __init__( if enable_tcp_over_ucx or enable_infiniband or enable_nvlink: if protocol is None: protocol = "ucx" - elif protocol != "ucx": - raise TypeError("Enabling InfiniBand or NVLink requires protocol='ucx'") + elif protocol not in ["ucx", "ucxx"]: + raise TypeError( + "Enabling InfiniBand or NVLink requires protocol='ucx' or " + "protocol='ucxx'" + ) self.host = kwargs.get("host", None) @@ -371,7 +374,7 @@ def __init__( ) + ["dask_cuda.initialize"] self.new_spec["options"]["preload_argv"] = self.new_spec["options"].get( "preload_argv", [] - ) + ["--create-cuda-context"] + ) + ["--create-cuda-context", "--protocol", protocol] self.cuda_visible_devices = CUDA_VISIBLE_DEVICES self.scale(n_workers) diff --git a/dask_cuda/tests/test_dgx.py b/dask_cuda/tests/test_dgx.py index ece399d45..1fd6d0ebb 100644 --- a/dask_cuda/tests/test_dgx.py +++ b/dask_cuda/tests/test_dgx.py @@ -73,10 +73,13 @@ def test_default(): assert not p.exitcode -def _test_tcp_over_ucx(): - ucp = pytest.importorskip("ucp") +def _test_tcp_over_ucx(protocol): + if protocol == "ucx": + ucp = pytest.importorskip("ucp") + elif protocol == "ucxx": + ucp = pytest.importorskip("ucxx") - with LocalCUDACluster(enable_tcp_over_ucx=True) as cluster: + with LocalCUDACluster(protocol=protocol, enable_tcp_over_ucx=True) as cluster: with Client(cluster) as client: res = da.from_array(numpy.arange(10000), chunks=(1000,)) res = res.sum().compute() @@ -93,10 +96,17 @@ def check_ucx_options(): assert all(client.run(check_ucx_options).values()) -def test_tcp_over_ucx(): - ucp = pytest.importorskip("ucp") # NOQA: F841 +@pytest.mark.parametrize( + "protocol", + ["ucx", "ucxx"], +) +def test_tcp_over_ucx(protocol): + if protocol == "ucx": + pytest.importorskip("ucp") + elif protocol == "ucxx": + pytest.importorskip("ucxx") - p = mp.Process(target=_test_tcp_over_ucx) + p = mp.Process(target=_test_tcp_over_ucx, args=(protocol,)) p.start() p.join() assert not p.exitcode @@ -117,9 +127,14 @@ def test_tcp_only(): assert not p.exitcode -def _test_ucx_infiniband_nvlink(enable_infiniband, enable_nvlink, enable_rdmacm): +def _test_ucx_infiniband_nvlink( + protocol, enable_infiniband, enable_nvlink, enable_rdmacm +): cupy = pytest.importorskip("cupy") - ucp = pytest.importorskip("ucp") + if protocol == "ucx": + ucp = pytest.importorskip("ucp") + elif protocol == "ucxx": + ucp = pytest.importorskip("ucxx") if enable_infiniband is None and enable_nvlink is None and enable_rdmacm is None: enable_tcp_over_ucx = None @@ -135,6 +150,7 @@ def _test_ucx_infiniband_nvlink(enable_infiniband, enable_nvlink, enable_rdmacm) cm_tls_priority = ["tcp"] initialize( + protocol=protocol, enable_tcp_over_ucx=enable_tcp_over_ucx, enable_infiniband=enable_infiniband, enable_nvlink=enable_nvlink, @@ -142,6 +158,7 @@ def _test_ucx_infiniband_nvlink(enable_infiniband, enable_nvlink, enable_rdmacm) ) with LocalCUDACluster( + protocol=protocol, interface="ib0", enable_tcp_over_ucx=enable_tcp_over_ucx, enable_infiniband=enable_infiniband, @@ -171,6 +188,7 @@ def check_ucx_options(): assert all(client.run(check_ucx_options).values()) +@pytest.mark.parametrize("protocol", ["ucx", "ucxx"]) @pytest.mark.parametrize( "params", [ @@ -185,8 +203,11 @@ def check_ucx_options(): _get_dgx_version() == DGXVersion.DGX_A100, reason="Automatic InfiniBand device detection Unsupported for %s" % _get_dgx_name(), ) -def test_ucx_infiniband_nvlink(params): - ucp = pytest.importorskip("ucp") # NOQA: F841 +def test_ucx_infiniband_nvlink(protocol, params): + if protocol == "ucx": + ucp = pytest.importorskip("ucp") + elif protocol == "ucxx": + ucp = pytest.importorskip("ucxx") if params["enable_infiniband"]: if not any([at.startswith("rc") for at in ucp.get_active_transports()]): @@ -195,6 +216,7 @@ def test_ucx_infiniband_nvlink(params): p = mp.Process( target=_test_ucx_infiniband_nvlink, args=( + protocol, params["enable_infiniband"], params["enable_nvlink"], params["enable_rdmacm"], diff --git a/dask_cuda/tests/test_explicit_comms.py b/dask_cuda/tests/test_explicit_comms.py index bd6770225..21b35e481 100644 --- a/dask_cuda/tests/test_explicit_comms.py +++ b/dask_cuda/tests/test_explicit_comms.py @@ -44,7 +44,7 @@ def _test_local_cluster(protocol): assert sum(c.run(my_rank, 0)) == sum(range(4)) -@pytest.mark.parametrize("protocol", ["tcp", "ucx"]) +@pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucxx"]) def test_local_cluster(protocol): p = mp.Process(target=_test_local_cluster, args=(protocol,)) p.start() @@ -160,7 +160,7 @@ def _test_dataframe_shuffle(backend, protocol, n_workers, _partitions): @pytest.mark.parametrize("nworkers", [1, 2, 3]) @pytest.mark.parametrize("backend", ["pandas", "cudf"]) -@pytest.mark.parametrize("protocol", ["tcp", "ucx"]) +@pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucxx"]) @pytest.mark.parametrize("_partitions", [True, False]) def test_dataframe_shuffle(backend, protocol, nworkers, _partitions): if backend == "cudf": @@ -256,7 +256,7 @@ def _test_dataframe_shuffle_merge(backend, protocol, n_workers): @pytest.mark.parametrize("nworkers", [1, 2, 4]) @pytest.mark.parametrize("backend", ["pandas", "cudf"]) -@pytest.mark.parametrize("protocol", ["tcp", "ucx"]) +@pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucxx"]) def test_dataframe_shuffle_merge(backend, protocol, nworkers): if backend == "cudf": pytest.importorskip("cudf") @@ -293,7 +293,7 @@ def _test_jit_unspill(protocol): assert_eq(got, expected) -@pytest.mark.parametrize("protocol", ["tcp", "ucx"]) +@pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucxx"]) def test_jit_unspill(protocol): pytest.importorskip("cudf") diff --git a/dask_cuda/tests/test_from_array.py b/dask_cuda/tests/test_from_array.py index 33f27d6fe..e20afcf3e 100644 --- a/dask_cuda/tests/test_from_array.py +++ b/dask_cuda/tests/test_from_array.py @@ -5,12 +5,16 @@ from dask_cuda import LocalCUDACluster -pytest.importorskip("ucp") cupy = pytest.importorskip("cupy") -@pytest.mark.parametrize("protocol", ["ucx", "tcp"]) +@pytest.mark.parametrize("protocol", ["ucx", "ucxx", "tcp"]) def test_ucx_from_array(protocol): + if protocol == "ucx": + pytest.importorskip("ucp") + elif protocol == "ucxx": + pytest.importorskip("ucxx") + N = 10_000 with LocalCUDACluster(protocol=protocol) as cluster: with Client(cluster): diff --git a/dask_cuda/tests/test_initialize.py b/dask_cuda/tests/test_initialize.py index 05b72f996..a953a10c1 100644 --- a/dask_cuda/tests/test_initialize.py +++ b/dask_cuda/tests/test_initialize.py @@ -13,7 +13,6 @@ from dask_cuda.utils_test import IncreasedCloseTimeoutNanny mp = mp.get_context("spawn") # type: ignore -ucp = pytest.importorskip("ucp") # Notice, all of the following tests is executed in a new process such # that UCX options of the different tests doesn't conflict. @@ -21,11 +20,16 @@ # of UCX before retrieving the current config. -def _test_initialize_ucx_tcp(): +def _test_initialize_ucx_tcp(protocol): + if protocol == "ucx": + ucp = pytest.importorskip("ucp") + elif protocol == "ucxx": + ucp = pytest.importorskip("ucxx") + kwargs = {"enable_tcp_over_ucx": True} - initialize(**kwargs) + initialize(protocol=protocol, **kwargs) with LocalCluster( - protocol="ucx", + protocol=protocol, dashboard_address=None, n_workers=1, threads_per_worker=1, @@ -50,18 +54,29 @@ def check_ucx_options(): assert all(client.run(check_ucx_options).values()) -def test_initialize_ucx_tcp(): - p = mp.Process(target=_test_initialize_ucx_tcp) +@pytest.mark.parametrize("protocol", ["ucx", "ucxx"]) +def test_initialize_ucx_tcp(protocol): + if protocol == "ucx": + pytest.importorskip("ucp") + elif protocol == "ucxx": + pytest.importorskip("ucxx") + + p = mp.Process(target=_test_initialize_ucx_tcp, args=(protocol,)) p.start() p.join() assert not p.exitcode -def _test_initialize_ucx_nvlink(): +def _test_initialize_ucx_nvlink(protocol): + if protocol == "ucx": + ucp = pytest.importorskip("ucp") + elif protocol == "ucxx": + ucp = pytest.importorskip("ucxx") + kwargs = {"enable_nvlink": True} - initialize(**kwargs) + initialize(protocol=protocol, **kwargs) with LocalCluster( - protocol="ucx", + protocol=protocol, dashboard_address=None, n_workers=1, threads_per_worker=1, @@ -87,18 +102,29 @@ def check_ucx_options(): assert all(client.run(check_ucx_options).values()) -def test_initialize_ucx_nvlink(): - p = mp.Process(target=_test_initialize_ucx_nvlink) +@pytest.mark.parametrize("protocol", ["ucx", "ucxx"]) +def test_initialize_ucx_nvlink(protocol): + if protocol == "ucx": + pytest.importorskip("ucp") + elif protocol == "ucxx": + pytest.importorskip("ucxx") + + p = mp.Process(target=_test_initialize_ucx_nvlink, args=(protocol,)) p.start() p.join() assert not p.exitcode -def _test_initialize_ucx_infiniband(): +def _test_initialize_ucx_infiniband(protocol): + if protocol == "ucx": + ucp = pytest.importorskip("ucp") + elif protocol == "ucxx": + ucp = pytest.importorskip("ucxx") + kwargs = {"enable_infiniband": True} - initialize(**kwargs) + initialize(protocol=protocol, **kwargs) with LocalCluster( - protocol="ucx", + protocol=protocol, dashboard_address=None, n_workers=1, threads_per_worker=1, @@ -127,17 +153,28 @@ def check_ucx_options(): @pytest.mark.skipif( "ib0" not in psutil.net_if_addrs(), reason="Infiniband interface ib0 not found" ) -def test_initialize_ucx_infiniband(): - p = mp.Process(target=_test_initialize_ucx_infiniband) +@pytest.mark.parametrize("protocol", ["ucx", "ucxx"]) +def test_initialize_ucx_infiniband(protocol): + if protocol == "ucx": + pytest.importorskip("ucp") + elif protocol == "ucxx": + pytest.importorskip("ucxx") + + p = mp.Process(target=_test_initialize_ucx_infiniband, args=(protocol,)) p.start() p.join() assert not p.exitcode -def _test_initialize_ucx_all(): - initialize() +def _test_initialize_ucx_all(protocol): + if protocol == "ucx": + ucp = pytest.importorskip("ucp") + elif protocol == "ucxx": + ucp = pytest.importorskip("ucxx") + + initialize(protocol=protocol) with LocalCluster( - protocol="ucx", + protocol=protocol, dashboard_address=None, n_workers=1, threads_per_worker=1, @@ -166,8 +203,14 @@ def check_ucx_options(): assert all(client.run(check_ucx_options).values()) -def test_initialize_ucx_all(): - p = mp.Process(target=_test_initialize_ucx_all) +@pytest.mark.parametrize("protocol", ["ucx", "ucxx"]) +def test_initialize_ucx_all(protocol): + if protocol == "ucx": + pytest.importorskip("ucp") + elif protocol == "ucxx": + pytest.importorskip("ucxx") + + p = mp.Process(target=_test_initialize_ucx_all, args=(protocol,)) p.start() p.join() assert not p.exitcode diff --git a/dask_cuda/tests/test_local_cuda_cluster.py b/dask_cuda/tests/test_local_cuda_cluster.py index 3298cf219..b05389e4c 100644 --- a/dask_cuda/tests/test_local_cuda_cluster.py +++ b/dask_cuda/tests/test_local_cuda_cluster.py @@ -87,23 +87,38 @@ def get_visible_devices(): } +@pytest.mark.parametrize( + "protocol", + ["ucx", "ucxx"], +) @gen_test(timeout=20) -async def test_ucx_protocol(): - pytest.importorskip("ucp") +async def test_ucx_protocol(protocol): + if protocol == "ucx": + pytest.importorskip("ucp") + elif protocol == "ucxx": + pytest.importorskip("ucxx") async with LocalCUDACluster( - protocol="ucx", asynchronous=True, data=dict + protocol=protocol, asynchronous=True, data=dict ) as cluster: assert all( - ws.address.startswith("ucx://") for ws in cluster.scheduler.workers.values() + ws.address.startswith(f"{protocol}://") + for ws in cluster.scheduler.workers.values() ) +@pytest.mark.parametrize( + "protocol", + ["ucx", "ucxx"], +) @gen_test(timeout=20) -async def test_explicit_ucx_with_protocol_none(): - pytest.importorskip("ucp") +async def test_explicit_ucx_with_protocol_none(protocol): + if protocol == "ucx": + pytest.importorskip("ucp") + elif protocol == "ucxx": + pytest.importorskip("ucxx") - initialize(enable_tcp_over_ucx=True) + initialize(protocol=protocol, enable_tcp_over_ucx=True) async with LocalCUDACluster( protocol=None, enable_tcp_over_ucx=True, asynchronous=True, data=dict ) as cluster: @@ -113,11 +128,18 @@ async def test_explicit_ucx_with_protocol_none(): @pytest.mark.filterwarnings("ignore:Exception ignored in") +@pytest.mark.parametrize( + "protocol", + ["ucx", "ucxx"], +) @gen_test(timeout=20) -async def test_ucx_protocol_type_error(): - pytest.importorskip("ucp") +async def test_ucx_protocol_type_error(protocol): + if protocol == "ucx": + pytest.importorskip("ucp") + elif protocol == "ucxx": + pytest.importorskip("ucxx") - initialize(enable_tcp_over_ucx=True) + initialize(protocol=protocol, enable_tcp_over_ucx=True) with pytest.raises(TypeError): async with LocalCUDACluster( protocol="tcp", enable_tcp_over_ucx=True, asynchronous=True, data=dict @@ -478,16 +500,25 @@ async def test_worker_fraction_limits(): ) -def test_print_cluster_config(capsys): +@pytest.mark.parametrize( + "protocol", + ["ucx", "ucxx"], +) +def test_print_cluster_config(capsys, protocol): + if protocol == "ucx": + pytest.importorskip("ucp") + elif protocol == "ucxx": + pytest.importorskip("ucxx") + pytest.importorskip("rich") with LocalCUDACluster( - n_workers=1, device_memory_limit="1B", jit_unspill=True, protocol="ucx" + n_workers=1, device_memory_limit="1B", jit_unspill=True, protocol=protocol ) as cluster: with Client(cluster) as client: print_cluster_config(client) captured = capsys.readouterr() assert "Dask Cluster Configuration" in captured.out - assert "ucx" in captured.out + assert protocol in captured.out assert "1 B" in captured.out assert "[plugin]" in captured.out diff --git a/dask_cuda/tests/test_proxy.py b/dask_cuda/tests/test_proxy.py index 8de56a5c5..7614219bf 100644 --- a/dask_cuda/tests/test_proxy.py +++ b/dask_cuda/tests/test_proxy.py @@ -400,10 +400,14 @@ def _pxy_deserialize(self): @pytest.mark.parametrize("send_serializers", [None, ("dask", "pickle"), ("cuda",)]) -@pytest.mark.parametrize("protocol", ["tcp", "ucx"]) +@pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucxx"]) @gen_test(timeout=120) async def test_communicating_proxy_objects(protocol, send_serializers): """Testing serialization of cuDF dataframe when communicating""" + if protocol == "ucx": + pytest.importorskip("ucp") + elif protocol == "ucxx": + pytest.importorskip("ucxx") cudf = pytest.importorskip("cudf") def task(x): @@ -412,7 +416,7 @@ def task(x): serializers_used = x._pxy_get().serializer # Check that `x` is serialized with the expected serializers - if protocol == "ucx": + if protocol in ["ucx", "ucxx"]: if send_serializers is None: assert serializers_used == "cuda" else: @@ -443,11 +447,15 @@ def task(x): await client.submit(task, df) -@pytest.mark.parametrize("protocol", ["tcp", "ucx"]) +@pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucxx"]) @pytest.mark.parametrize("shared_fs", [True, False]) @gen_test(timeout=20) async def test_communicating_disk_objects(protocol, shared_fs): """Testing disk serialization of cuDF dataframe when communicating""" + if protocol == "ucx": + pytest.importorskip("ucp") + elif protocol == "ucxx": + pytest.importorskip("ucxx") cudf = pytest.importorskip("cudf") ProxifyHostFile._spill_to_disk.shared_filesystem = shared_fs diff --git a/dask_cuda/tests/test_utils.py b/dask_cuda/tests/test_utils.py index 34e63f1b4..a0a77677d 100644 --- a/dask_cuda/tests/test_utils.py +++ b/dask_cuda/tests/test_utils.py @@ -79,11 +79,18 @@ def test_get_device_total_memory(): assert total_mem > 0 -def test_get_preload_options_default(): - pytest.importorskip("ucp") +@pytest.mark.parametrize( + "protocol", + ["ucx", "ucxx"], +) +def test_get_preload_options_default(protocol): + if protocol == "ucx": + pytest.importorskip("ucp") + elif protocol == "ucxx": + pytest.importorskip("ucxx") opts = get_preload_options( - protocol="ucx", + protocol=protocol, create_cuda_context=True, ) @@ -93,14 +100,21 @@ def test_get_preload_options_default(): assert opts["preload_argv"] == ["--create-cuda-context"] +@pytest.mark.parametrize( + "protocol", + ["ucx", "ucxx"], +) @pytest.mark.parametrize("enable_tcp", [True, False]) @pytest.mark.parametrize("enable_infiniband", [True, False]) @pytest.mark.parametrize("enable_nvlink", [True, False]) -def test_get_preload_options(enable_tcp, enable_infiniband, enable_nvlink): - pytest.importorskip("ucp") +def test_get_preload_options(protocol, enable_tcp, enable_infiniband, enable_nvlink): + if protocol == "ucx": + pytest.importorskip("ucp") + elif protocol == "ucxx": + pytest.importorskip("ucxx") opts = get_preload_options( - protocol="ucx", + protocol=protocol, create_cuda_context=True, enable_tcp_over_ucx=enable_tcp, enable_infiniband=enable_infiniband, diff --git a/dask_cuda/utils.py b/dask_cuda/utils.py index f16ad18a2..ff4dbbae3 100644 --- a/dask_cuda/utils.py +++ b/dask_cuda/utils.py @@ -287,7 +287,7 @@ def get_preload_options( if create_cuda_context: preload_options["preload_argv"].append("--create-cuda-context") - if protocol == "ucx": + if protocol in ["ucx", "ucxx"]: initialize_ucx_argv = [] if enable_tcp_over_ucx: initialize_ucx_argv.append("--enable-tcp-over-ucx") @@ -625,6 +625,10 @@ def get_worker_config(dask_worker): import ucp ret["ucx-transports"] = ucp.get_active_transports() + elif scheme == "ucxx": + import ucxx + + ret["ucx-transports"] = ucxx.get_active_transports() # comm timeouts ret["distributed.comm.timeouts"] = dask.config.get("distributed.comm.timeouts") diff --git a/dependencies.yaml b/dependencies.yaml index 1022b3a38..02783dbff 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -122,6 +122,8 @@ dependencies: - pytest-cov - ucx-proc=*=gpu - ucx-py=0.35 + - ucxx=0.35 + - distributed-ucxx=0.35 specific: - output_types: conda matrices: