diff --git a/dask_ctl/discovery.py b/dask_ctl/discovery.py index 027eb79..c03624b 100644 --- a/dask_ctl/discovery.py +++ b/dask_ctl/discovery.py @@ -1,3 +1,4 @@ +import asyncio from typing import Callable, Dict, AsyncIterator, Tuple from contextlib import suppress import pkg_resources @@ -29,7 +30,7 @@ def list_discovery_methods() -> Dict[str, Callable]: 'version': '', 'path': ''} } - >>> list(list_discovery_methods()) + >>> list(list_discovery_methods()) # doctest: +SKIP ['proxycluster'] """ @@ -126,4 +127,6 @@ async def discover_clusters(discovery=None) -> AsyncIterator[SpecCluster]: """ async for cluster_name, cluster_class in discover_cluster_names(discovery): with suppress(Exception), suppress_output(): - yield cluster_class.from_name(cluster_name) + yield await asyncio.get_running_loop().run_in_executor( + None, cluster_class.from_name, cluster_name + ) diff --git a/dask_ctl/lifecycle.py b/dask_ctl/lifecycle.py index d931908..4c1c2d8 100644 --- a/dask_ctl/lifecycle.py +++ b/dask_ctl/lifecycle.py @@ -99,7 +99,7 @@ def get_cluster(name: str) -> Cluster: async def _get_cluster(): async for cluster_name, cluster_class in discover_cluster_names(): if cluster_name == name: - return cluster_class.from_name(name) + return await loop.run_in_executor(None, cluster_class.from_name, name) raise RuntimeError("No such cluster %s", name) return loop.run_sync(_get_cluster) @@ -127,7 +127,10 @@ def scale_cluster(name: str, n_workers: int) -> None: async def _scale_cluster(): async for cluster_name, cluster_class in discover_cluster_names(): if cluster_name == name: - return cluster_class.from_name(name).scale(n_workers) + cluster = await loop.run_in_executor( + None, cluster_class.from_name, name + ) + return await loop.run_in_executor(None, cluster.scale, n_workers) raise RuntimeError("No such cluster %s", name) return loop.run_sync(_scale_cluster) @@ -153,7 +156,10 @@ def delete_cluster(name: str) -> None: async def _delete_cluster(): async for cluster_name, cluster_class in discover_cluster_names(): if cluster_name == name: - return cluster_class.from_name(name).close() + cluster = await loop.run_in_executor( + None, cluster_class.from_name, name + ) + return await loop.run_in_executor(None, cluster.close) raise RuntimeError("No such cluster %s", name) return loop.run_sync(_delete_cluster) diff --git a/dask_ctl/proxy.py b/dask_ctl/proxy.py index 75be3dd..38307b7 100644 --- a/dask_ctl/proxy.py +++ b/dask_ctl/proxy.py @@ -1,17 +1,19 @@ from typing import Callable, AsyncIterator, Tuple import asyncio -import contextlib -import psutil +from zeroconf import ( + IPVersion, + ServiceInfo, + Zeroconf, +) +from zeroconf.asyncio import AsyncServiceBrowser, AsyncZeroconf from distributed.deploy.cluster import Cluster from distributed.core import rpc, Status -from distributed.client import Client from distributed.utils import LoopRunner -def gen_name(port): - return f"proxycluster-{port}" +_ZC_SERVICE = "_dask._tcp.local." class ProxyCluster(Cluster): @@ -52,41 +54,19 @@ def from_name( ProxyCluster(proxycluster-8786, 'tcp://localhost:8786', workers=4, threads=12, memory=17.18 GB) """ - port = name.split("-")[-1] - return cls.from_port(port, loop=loop, asynchronous=asynchronous) - @classmethod - def from_port( - cls, port: int, loop: asyncio.BaseEventLoop = None, asynchronous: bool = False - ): - """Get instance of ``ProxyCluster`` by port. - - Parameters - ---------- - port - Localhost port of cluster to get ``ProxyCluster`` for. - loop (optional) - Existing event loop to use. - asynchronous (optional) - Start asynchronously. Default ``False``. - - Returns - ------- - ProxyCluster - Instance of ProxyCluster. - - Examples - -------- - >>> from dask.distributed import LocalCluster # doctest: +SKIP - >>> cluster = LocalCluster(scheduler_port=81234) # doctest: +SKIP - >>> ProxyCluster.from_port(81234) # doctest: +SKIP - ProxyCluster(proxycluster-81234, 'tcp://localhost:81234', workers=4, threads=12, memory=17.18 GB) - - """ cluster = cls(asynchronous=asynchronous) - cluster.name = gen_name(port) - - cluster.scheduler_comm = rpc(f"tcp://localhost:{port}") + cluster.name = name + + # Get scheduler address via zeroconf + zeroconf = Zeroconf(ip_version=IPVersion.V4Only) + scheduler = ServiceInfo(_ZC_SERVICE, f"{name}._dask._tcp.local.") + if not scheduler.request(zeroconf, 3000): + raise RuntimeError("Unable to find cluster") + addr = scheduler.parsed_addresses()[0] + protocol = scheduler.properties[b"protocol"].decode("utf-8") + cluster.scheduler_comm = rpc(f"{protocol}://{addr}:{scheduler.port}") + zeroconf.close() cluster._loop_runner = LoopRunner(loop=loop, asynchronous=asynchronous) cluster.loop = cluster._loop_runner.loop @@ -97,12 +77,6 @@ def from_port( cluster.sync(cluster._start) return cluster - def scale(self, *args, **kwargs): - raise TypeError("Scaling of ProxyCluster objects is not supported.") - - def close(self, *args, **kwargs): - raise TypeError("Closing of ProxyCluster objects is not supported.") - async def discover() -> AsyncIterator[Tuple[str, Callable]]: """Discover proxy clusters. @@ -133,33 +107,25 @@ async def discover() -> AsyncIterator[Tuple[str, Callable]]: [('proxycluster-8786', dask_ctl.proxy.ProxyCluster)] """ - open_ports = {8786} - - with contextlib.suppress( - psutil.AccessDenied - ): # On macOS this needs to be run as root - connections = psutil.net_connections() - for connection in connections: - if ( - connection.status == "LISTEN" - and connection.family.name == "AF_INET" - and connection.laddr.port not in open_ports - ): - open_ports.add(connection.laddr.port) - - async def try_connect(port): - with contextlib.suppress(OSError, asyncio.TimeoutError): - async with Client( - f"tcp://localhost:{port}", - asynchronous=True, - timeout=1, # Minimum of 1 for Windows - ): - return port - return - - for port in await asyncio.gather(*[try_connect(port) for port in open_ports]): - if port: - yield ( - gen_name(port), - ProxyCluster, - ) + aiozc = AsyncZeroconf(ip_version=IPVersion.V4Only) + browser = AsyncServiceBrowser( + aiozc.zeroconf, [_ZC_SERVICE], handlers=[lambda *args, **kw: None] + ) + + # ServiceBrowser runs in a thread. Give it a chance to find some schedulers. + await asyncio.sleep(0.5) + + schedulers = [ + x.split(".")[0] + for x in aiozc.zeroconf.cache.names() + if x.endswith(_ZC_SERVICE) and x != _ZC_SERVICE + ] + + for scheduler in schedulers: + yield ( + scheduler, + ProxyCluster, + ) + + await browser.async_cancel() + await aiozc.async_close() diff --git a/dask_ctl/tests/test_cli.py b/dask_ctl/tests/test_cli.py index ce9998e..59cbec5 100644 --- a/dask_ctl/tests/test_cli.py +++ b/dask_ctl/tests/test_cli.py @@ -21,6 +21,7 @@ def test_create(simple_spec_path): def test_autocompletion(): with LocalCluster(scheduler_port=8786) as _: - assert len(autocomplete_cluster_names(None, None, "")) == 1 - assert len(autocomplete_cluster_names(None, None, "proxy")) == 1 + names = autocomplete_cluster_names(None, None, "") + assert len(names) == 1 + assert "_sched" in names[0] assert len(autocomplete_cluster_names(None, None, "local")) == 0 diff --git a/dask_ctl/tests/test_discovery.py b/dask_ctl/tests/test_discovery.py index eaa5e39..f23bc46 100644 --- a/dask_ctl/tests/test_discovery.py +++ b/dask_ctl/tests/test_discovery.py @@ -2,13 +2,21 @@ from typing import AsyncIterator -from distributed import LocalCluster +from dask.distributed import Client, LocalCluster from dask_ctl.discovery import ( discover_cluster_names, discover_clusters, list_discovery_methods, ) -from dask_ctl.proxy import ProxyCluster +from dask_ctl.proxy import discover + + +@pytest.mark.asyncio +async def test_discover_clusters(): + assert isinstance(discover_clusters(), AsyncIterator) + async with LocalCluster(scheduler_port=8786, asynchronous=True) as cluster: + [discovered_cluster] = [c async for c in discover_clusters()] + assert discovered_cluster.scheduler_info == cluster.scheduler_info def test_discovery_methods(): @@ -19,16 +27,12 @@ def test_discovery_methods(): async def test_discover_cluster_names(): assert isinstance(discover_cluster_names(), AsyncIterator) async with LocalCluster(scheduler_port=8786, asynchronous=True) as _: - count = 0 - async for _ in discover_cluster_names(): - count += 1 - assert count == 1 + names = [name async for name in discover_cluster_names()] + assert len(names) == 1 @pytest.mark.asyncio async def test_cluster_client(): - from dask.distributed import Client - port = 8786 async with LocalCluster(scheduler_port=port, asynchronous=True) as _: async with Client( @@ -39,17 +43,6 @@ async def test_cluster_client(): @pytest.mark.asyncio async def test_discovery_list(): - from dask_ctl.proxy import discover - - port = 8786 - async with LocalCluster(scheduler_port=port, asynchronous=True) as _: + async with LocalCluster(scheduler_port=8786, asynchronous=True) as _: async for name, _ in discover(): - assert str(port) in name - - -@pytest.mark.asyncio -async def test_discover_clusters(): - with LocalCluster() as cluster: - async for discovered_cluster in discover_clusters(): - if isinstance(discovered_cluster, ProxyCluster): - assert cluster.scheduler_info == discovered_cluster.scheduler_info + assert "_sched" in name diff --git a/dask_ctl/utils.py b/dask_ctl/utils.py index 62a39a0..a852a1a 100644 --- a/dask_ctl/utils.py +++ b/dask_ctl/utils.py @@ -1,5 +1,6 @@ from contextlib import contextmanager, redirect_stdout, redirect_stderr from io import StringIO +import os from tornado.ioloop import IOLoop from distributed.cli.utils import install_signal_handlers @@ -55,5 +56,8 @@ def justify(value, length): @contextmanager def suppress_output(): - with redirect_stdout(StringIO()), redirect_stderr(StringIO()): + if "DASK_CTL_DEBUG" in os.environ: yield + else: + with redirect_stdout(StringIO()), redirect_stderr(StringIO()): + yield diff --git a/requirements.txt b/requirements.txt index eda8803..63e3b26 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,4 +2,5 @@ click dask distributed tornado -pyyaml \ No newline at end of file +pyyaml +zeroconf>=0.32.0 \ No newline at end of file