Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for UCXX #1268

Merged
merged 3 commits into from
Oct 31, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dask_cuda/benchmarks/local_cudf_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def pretty_print_results(args, address_to_index, p2p_bw, results):
key="Device memory limit", value=f"{format_bytes(args.device_memory_limit)}"
)
print_key_value(key="RMM Pool", value=f"{not args.disable_rmm_pool}")
if args.protocol == "ucx":
if args.protocol in ["ucx", "ucxx"]:
print_key_value(key="TCP", value=f"{args.enable_tcp_over_ucx}")
print_key_value(key="InfiniBand", value=f"{args.enable_infiniband}")
print_key_value(key="NVLink", value=f"{args.enable_nvlink}")
Expand Down
2 changes: 1 addition & 1 deletion dask_cuda/benchmarks/local_cudf_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def pretty_print_results(args, address_to_index, p2p_bw, results):
)
print_key_value(key="RMM Pool", value=f"{not args.disable_rmm_pool}")
print_key_value(key="Frac-match", value=f"{args.frac_match}")
if args.protocol == "ucx":
if args.protocol in ["ucx", "ucxx"]:
print_key_value(key="TCP", value=f"{args.enable_tcp_over_ucx}")
print_key_value(key="InfiniBand", value=f"{args.enable_infiniband}")
print_key_value(key="NVLink", value=f"{args.enable_nvlink}")
Expand Down
2 changes: 1 addition & 1 deletion dask_cuda/benchmarks/local_cudf_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def pretty_print_results(args, address_to_index, p2p_bw, results):
key="Device memory limit", value=f"{format_bytes(args.device_memory_limit)}"
)
print_key_value(key="RMM Pool", value=f"{not args.disable_rmm_pool}")
if args.protocol == "ucx":
if args.protocol in ["ucx", "ucxx"]:
print_key_value(key="TCP", value=f"{args.enable_tcp_over_ucx}")
print_key_value(key="InfiniBand", value=f"{args.enable_infiniband}")
print_key_value(key="NVLink", value=f"{args.enable_nvlink}")
Expand Down
2 changes: 1 addition & 1 deletion dask_cuda/benchmarks/local_cupy.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def pretty_print_results(args, address_to_index, p2p_bw, results):
)
print_key_value(key="RMM Pool", value=f"{not args.disable_rmm_pool}")
print_key_value(key="Protocol", value=f"{args.protocol}")
if args.protocol == "ucx":
if args.protocol in ["ucx", "ucxx"]:
print_key_value(key="TCP", value=f"{args.enable_tcp_over_ucx}")
print_key_value(key="InfiniBand", value=f"{args.enable_infiniband}")
print_key_value(key="NVLink", value=f"{args.enable_nvlink}")
Expand Down
2 changes: 1 addition & 1 deletion dask_cuda/benchmarks/local_cupy_map_overlap.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def pretty_print_results(args, address_to_index, p2p_bw, results):
)
print_key_value(key="RMM Pool", value=f"{not args.disable_rmm_pool}")
print_key_value(key="Protocol", value=f"{args.protocol}")
if args.protocol == "ucx":
if args.protocol in ["ucx", "ucxx"]:
print_key_value(key="TCP", value=f"{args.enable_tcp_over_ucx}")
print_key_value(key="InfiniBand", value=f"{args.enable_infiniband}")
print_key_value(key="NVLink", value=f"{args.enable_nvlink}")
Expand Down
2 changes: 1 addition & 1 deletion dask_cuda/benchmarks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def parse_benchmark_args(description="Generic dask-cuda Benchmark", args_list=[]
cluster_args.add_argument(
"-p",
"--protocol",
choices=["tcp", "ucx"],
choices=["tcp", "ucx", "ucxx"],
default="tcp",
type=str,
help="The communication protocol to use.",
Expand Down
63 changes: 47 additions & 16 deletions dask_cuda/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import numba.cuda

import dask
import distributed.comm.ucx
from distributed.diagnostics.nvml import get_device_index_and_uuid, has_cuda_context

from .utils import get_ucx_config
Expand All @@ -23,12 +22,21 @@ def _create_cuda_context_handler():
numba.cuda.current_context()


def _create_cuda_context():
def _create_cuda_context(protocol="ucx"):
if protocol not in ["ucx", "ucxx"]:
return
try:
# Added here to ensure the parent `LocalCUDACluster` process creates the CUDA
# context directly from the UCX module, thus avoiding a similar warning there.
try:
distributed.comm.ucx.init_once()
if protocol == "ucx":
import distributed.comm.ucx

distributed.comm.ucx.init_once()
elif protocol == "ucxx":
import distributed_ucxx.ucxx

distributed_ucxx.ucxx.init_once()
except ModuleNotFoundError:
# UCX initialization has to be delegated to Distributed, it will take care
# of setting correct environment variables and importing `ucp` after that.
Expand All @@ -39,20 +47,35 @@ def _create_cuda_context():
os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")[0]
)
ctx = has_cuda_context()
if (
ctx.has_context
and not distributed.comm.ucx.cuda_context_created.has_context
):
distributed.comm.ucx._warn_existing_cuda_context(ctx, os.getpid())
if protocol == "ucx":
if (
ctx.has_context
and not distributed.comm.ucx.cuda_context_created.has_context
):
distributed.comm.ucx._warn_existing_cuda_context(ctx, os.getpid())
elif protocol == "ucxx":
if (
ctx.has_context
and not distributed_ucxx.ucxx.cuda_context_created.has_context
):
distributed_ucxx.ucxx._warn_existing_cuda_context(ctx, os.getpid())

_create_cuda_context_handler()

if not distributed.comm.ucx.cuda_context_created.has_context:
ctx = has_cuda_context()
if ctx.has_context and ctx.device_info != cuda_visible_device:
distributed.comm.ucx._warn_cuda_context_wrong_device(
cuda_visible_device, ctx.device_info, os.getpid()
)
if protocol == "ucx":
if not distributed.comm.ucx.cuda_context_created.has_context:
ctx = has_cuda_context()
if ctx.has_context and ctx.device_info != cuda_visible_device:
distributed.comm.ucx._warn_cuda_context_wrong_device(
cuda_visible_device, ctx.device_info, os.getpid()
)
elif protocol == "ucxx":
if not distributed_ucxx.ucxx.cuda_context_created.has_context:
ctx = has_cuda_context()
if ctx.has_context and ctx.device_info != cuda_visible_device:
distributed_ucxx.ucxx._warn_cuda_context_wrong_device(
cuda_visible_device, ctx.device_info, os.getpid()
)

except Exception:
logger.error("Unable to start CUDA Context", exc_info=True)
Expand All @@ -64,6 +87,7 @@ def initialize(
enable_infiniband=None,
enable_nvlink=None,
enable_rdmacm=None,
protocol="ucx",
):
"""Create CUDA context and initialize UCX-Py, depending on user parameters.

Expand Down Expand Up @@ -118,7 +142,7 @@ def initialize(
dask.config.set({"distributed.comm.ucx": ucx_config})

if create_cuda_context:
_create_cuda_context()
_create_cuda_context(protocol=protocol)


@click.command()
Expand All @@ -127,6 +151,12 @@ def initialize(
default=False,
help="Create CUDA context",
)
@click.option(
"--protocol",
default=None,
type=str,
help="Communication protocol, such as: 'tcp', 'tls', 'ucx' or 'ucxx'.",
)
@click.option(
"--enable-tcp-over-ucx/--disable-tcp-over-ucx",
default=False,
Expand All @@ -150,10 +180,11 @@ def initialize(
def dask_setup(
service,
create_cuda_context,
protocol,
enable_tcp_over_ucx,
enable_infiniband,
enable_nvlink,
enable_rdmacm,
):
if create_cuda_context:
_create_cuda_context()
_create_cuda_context(protocol=protocol)
9 changes: 6 additions & 3 deletions dask_cuda/local_cuda_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,8 +319,11 @@ def __init__(
if enable_tcp_over_ucx or enable_infiniband or enable_nvlink:
if protocol is None:
protocol = "ucx"
elif protocol != "ucx":
raise TypeError("Enabling InfiniBand or NVLink requires protocol='ucx'")
elif protocol not in ["ucx", "ucxx"]:
raise TypeError(
"Enabling InfiniBand or NVLink requires protocol='ucx' or "
"protocol='ucxx'"
)

self.host = kwargs.get("host", None)

Expand Down Expand Up @@ -371,7 +374,7 @@ def __init__(
) + ["dask_cuda.initialize"]
self.new_spec["options"]["preload_argv"] = self.new_spec["options"].get(
"preload_argv", []
) + ["--create-cuda-context"]
) + ["--create-cuda-context", "--protocol", protocol]

self.cuda_visible_devices = CUDA_VISIBLE_DEVICES
self.scale(n_workers)
Expand Down
42 changes: 32 additions & 10 deletions dask_cuda/tests/test_dgx.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,13 @@ def test_default():
assert not p.exitcode


def _test_tcp_over_ucx():
ucp = pytest.importorskip("ucp")
def _test_tcp_over_ucx(protocol):
if protocol == "ucx":
ucp = pytest.importorskip("ucp")
elif protocol == "ucxx":
ucp = pytest.importorskip("ucxx")

with LocalCUDACluster(enable_tcp_over_ucx=True) as cluster:
with LocalCUDACluster(protocol=protocol, enable_tcp_over_ucx=True) as cluster:
with Client(cluster) as client:
res = da.from_array(numpy.arange(10000), chunks=(1000,))
res = res.sum().compute()
Expand All @@ -93,10 +96,17 @@ def check_ucx_options():
assert all(client.run(check_ucx_options).values())


def test_tcp_over_ucx():
ucp = pytest.importorskip("ucp") # NOQA: F841
@pytest.mark.parametrize(
"protocol",
["ucx", "ucxx"],
)
def test_tcp_over_ucx(protocol):
if protocol == "ucx":
pytest.importorskip("ucp")
elif protocol == "ucxx":
pytest.importorskip("ucxx")

p = mp.Process(target=_test_tcp_over_ucx)
p = mp.Process(target=_test_tcp_over_ucx, args=(protocol,))
p.start()
p.join()
assert not p.exitcode
Expand All @@ -117,9 +127,14 @@ def test_tcp_only():
assert not p.exitcode


def _test_ucx_infiniband_nvlink(enable_infiniband, enable_nvlink, enable_rdmacm):
def _test_ucx_infiniband_nvlink(
protocol, enable_infiniband, enable_nvlink, enable_rdmacm
):
cupy = pytest.importorskip("cupy")
ucp = pytest.importorskip("ucp")
if protocol == "ucx":
ucp = pytest.importorskip("ucp")
elif protocol == "ucxx":
ucp = pytest.importorskip("ucxx")

if enable_infiniband is None and enable_nvlink is None and enable_rdmacm is None:
enable_tcp_over_ucx = None
Expand All @@ -135,13 +150,15 @@ def _test_ucx_infiniband_nvlink(enable_infiniband, enable_nvlink, enable_rdmacm)
cm_tls_priority = ["tcp"]

initialize(
protocol=protocol,
enable_tcp_over_ucx=enable_tcp_over_ucx,
enable_infiniband=enable_infiniband,
enable_nvlink=enable_nvlink,
enable_rdmacm=enable_rdmacm,
)

with LocalCUDACluster(
protocol=protocol,
interface="ib0",
enable_tcp_over_ucx=enable_tcp_over_ucx,
enable_infiniband=enable_infiniband,
Expand Down Expand Up @@ -171,6 +188,7 @@ def check_ucx_options():
assert all(client.run(check_ucx_options).values())


@pytest.mark.parametrize("protocol", ["ucx", "ucxx"])
@pytest.mark.parametrize(
"params",
[
Expand All @@ -185,8 +203,11 @@ def check_ucx_options():
_get_dgx_version() == DGXVersion.DGX_A100,
reason="Automatic InfiniBand device detection Unsupported for %s" % _get_dgx_name(),
)
def test_ucx_infiniband_nvlink(params):
ucp = pytest.importorskip("ucp") # NOQA: F841
def test_ucx_infiniband_nvlink(protocol, params):
if protocol == "ucx":
ucp = pytest.importorskip("ucp")
elif protocol == "ucxx":
ucp = pytest.importorskip("ucxx")

if params["enable_infiniband"]:
if not any([at.startswith("rc") for at in ucp.get_active_transports()]):
Expand All @@ -195,6 +216,7 @@ def test_ucx_infiniband_nvlink(params):
p = mp.Process(
target=_test_ucx_infiniband_nvlink,
args=(
protocol,
params["enable_infiniband"],
params["enable_nvlink"],
params["enable_rdmacm"],
Expand Down
8 changes: 4 additions & 4 deletions dask_cuda/tests/test_explicit_comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def _test_local_cluster(protocol):
assert sum(c.run(my_rank, 0)) == sum(range(4))


@pytest.mark.parametrize("protocol", ["tcp", "ucx"])
@pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucxx"])
def test_local_cluster(protocol):
p = mp.Process(target=_test_local_cluster, args=(protocol,))
p.start()
Expand Down Expand Up @@ -160,7 +160,7 @@ def _test_dataframe_shuffle(backend, protocol, n_workers, _partitions):

@pytest.mark.parametrize("nworkers", [1, 2, 3])
@pytest.mark.parametrize("backend", ["pandas", "cudf"])
@pytest.mark.parametrize("protocol", ["tcp", "ucx"])
@pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucxx"])
@pytest.mark.parametrize("_partitions", [True, False])
def test_dataframe_shuffle(backend, protocol, nworkers, _partitions):
if backend == "cudf":
Expand Down Expand Up @@ -256,7 +256,7 @@ def _test_dataframe_shuffle_merge(backend, protocol, n_workers):

@pytest.mark.parametrize("nworkers", [1, 2, 4])
@pytest.mark.parametrize("backend", ["pandas", "cudf"])
@pytest.mark.parametrize("protocol", ["tcp", "ucx"])
@pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucxx"])
def test_dataframe_shuffle_merge(backend, protocol, nworkers):
if backend == "cudf":
pytest.importorskip("cudf")
Expand Down Expand Up @@ -293,7 +293,7 @@ def _test_jit_unspill(protocol):
assert_eq(got, expected)


@pytest.mark.parametrize("protocol", ["tcp", "ucx"])
@pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucxx"])
def test_jit_unspill(protocol):
pytest.importorskip("cudf")

Expand Down
8 changes: 6 additions & 2 deletions dask_cuda/tests/test_from_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,16 @@

from dask_cuda import LocalCUDACluster

pytest.importorskip("ucp")
cupy = pytest.importorskip("cupy")


@pytest.mark.parametrize("protocol", ["ucx", "tcp"])
@pytest.mark.parametrize("protocol", ["ucx", "ucxx", "tcp"])
def test_ucx_from_array(protocol):
if protocol == "ucx":
pytest.importorskip("ucp")
elif protocol == "ucxx":
pytest.importorskip("ucxx")

N = 10_000
with LocalCUDACluster(protocol=protocol) as cluster:
with Client(cluster):
Expand Down
Loading
Loading