Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add zeroconf discovery support #10

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions dask_ctl/discovery.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from typing import Callable, Dict, AsyncIterator, Tuple
from contextlib import suppress
import pkg_resources
Expand Down Expand Up @@ -29,7 +30,7 @@ def list_discovery_methods() -> Dict[str, Callable]:
'version': '<package version>',
'path': '<path to package>'}
}
>>> list(list_discovery_methods())
>>> list(list_discovery_methods()) # doctest: +SKIP
['proxycluster']

"""
Expand Down Expand Up @@ -126,4 +127,6 @@ async def discover_clusters(discovery=None) -> AsyncIterator[SpecCluster]:
"""
async for cluster_name, cluster_class in discover_cluster_names(discovery):
with suppress(Exception), suppress_output():
yield cluster_class.from_name(cluster_name)
yield await asyncio.get_running_loop().run_in_executor(
None, cluster_class.from_name, cluster_name
)
12 changes: 9 additions & 3 deletions dask_ctl/lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def get_cluster(name: str) -> Cluster:
async def _get_cluster():
async for cluster_name, cluster_class in discover_cluster_names():
if cluster_name == name:
return cluster_class.from_name(name)
return await loop.run_in_executor(None, cluster_class.from_name, name)
raise RuntimeError("No such cluster %s", name)

return loop.run_sync(_get_cluster)
Expand Down Expand Up @@ -127,7 +127,10 @@ def scale_cluster(name: str, n_workers: int) -> None:
async def _scale_cluster():
async for cluster_name, cluster_class in discover_cluster_names():
if cluster_name == name:
return cluster_class.from_name(name).scale(n_workers)
cluster = await loop.run_in_executor(
None, cluster_class.from_name, name
)
return await loop.run_in_executor(None, cluster.scale, n_workers)
raise RuntimeError("No such cluster %s", name)

return loop.run_sync(_scale_cluster)
Expand All @@ -153,7 +156,10 @@ def delete_cluster(name: str) -> None:
async def _delete_cluster():
async for cluster_name, cluster_class in discover_cluster_names():
if cluster_name == name:
return cluster_class.from_name(name).close()
cluster = await loop.run_in_executor(
None, cluster_class.from_name, name
)
return await loop.run_in_executor(None, cluster.close)
raise RuntimeError("No such cluster %s", name)

return loop.run_sync(_delete_cluster)
114 changes: 40 additions & 74 deletions dask_ctl/proxy.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
from typing import Callable, AsyncIterator, Tuple
import asyncio
import contextlib

import psutil
from zeroconf import (
IPVersion,
ServiceInfo,
Zeroconf,
)
from zeroconf.asyncio import AsyncServiceBrowser, AsyncZeroconf

from distributed.deploy.cluster import Cluster
from distributed.core import rpc, Status
from distributed.client import Client
from distributed.utils import LoopRunner


def gen_name(port):
return f"proxycluster-{port}"
_ZC_SERVICE = "_dask._tcp.local."


class ProxyCluster(Cluster):
Expand Down Expand Up @@ -52,41 +54,19 @@ def from_name(
ProxyCluster(proxycluster-8786, 'tcp://localhost:8786', workers=4, threads=12, memory=17.18 GB)

"""
port = name.split("-")[-1]
return cls.from_port(port, loop=loop, asynchronous=asynchronous)

@classmethod
def from_port(
cls, port: int, loop: asyncio.BaseEventLoop = None, asynchronous: bool = False
):
"""Get instance of ``ProxyCluster`` by port.

Parameters
----------
port
Localhost port of cluster to get ``ProxyCluster`` for.
loop (optional)
Existing event loop to use.
asynchronous (optional)
Start asynchronously. Default ``False``.

Returns
-------
ProxyCluster
Instance of ProxyCluster.

Examples
--------
>>> from dask.distributed import LocalCluster # doctest: +SKIP
>>> cluster = LocalCluster(scheduler_port=81234) # doctest: +SKIP
>>> ProxyCluster.from_port(81234) # doctest: +SKIP
ProxyCluster(proxycluster-81234, 'tcp://localhost:81234', workers=4, threads=12, memory=17.18 GB)

"""
cluster = cls(asynchronous=asynchronous)
cluster.name = gen_name(port)

cluster.scheduler_comm = rpc(f"tcp://localhost:{port}")
cluster.name = name

# Get scheduler address via zeroconf
zeroconf = Zeroconf(ip_version=IPVersion.V4Only)
scheduler = ServiceInfo(_ZC_SERVICE, f"{name}._dask._tcp.local.")
if not scheduler.request(zeroconf, 3000):
raise RuntimeError("Unable to find cluster")
addr = scheduler.parsed_addresses()[0]
protocol = scheduler.properties[b"protocol"].decode("utf-8")
cluster.scheduler_comm = rpc(f"{protocol}://{addr}:{scheduler.port}")
zeroconf.close()

cluster._loop_runner = LoopRunner(loop=loop, asynchronous=asynchronous)
cluster.loop = cluster._loop_runner.loop
Expand All @@ -97,12 +77,6 @@ def from_port(
cluster.sync(cluster._start)
return cluster

def scale(self, *args, **kwargs):
raise TypeError("Scaling of ProxyCluster objects is not supported.")

def close(self, *args, **kwargs):
raise TypeError("Closing of ProxyCluster objects is not supported.")


async def discover() -> AsyncIterator[Tuple[str, Callable]]:
"""Discover proxy clusters.
Expand Down Expand Up @@ -133,33 +107,25 @@ async def discover() -> AsyncIterator[Tuple[str, Callable]]:
[('proxycluster-8786', dask_ctl.proxy.ProxyCluster)]

"""
open_ports = {8786}

with contextlib.suppress(
psutil.AccessDenied
): # On macOS this needs to be run as root
connections = psutil.net_connections()
for connection in connections:
if (
connection.status == "LISTEN"
and connection.family.name == "AF_INET"
and connection.laddr.port not in open_ports
):
open_ports.add(connection.laddr.port)

async def try_connect(port):
with contextlib.suppress(OSError, asyncio.TimeoutError):
async with Client(
f"tcp://localhost:{port}",
asynchronous=True,
timeout=1, # Minimum of 1 for Windows
):
return port
return

for port in await asyncio.gather(*[try_connect(port) for port in open_ports]):
if port:
yield (
gen_name(port),
ProxyCluster,
)
aiozc = AsyncZeroconf(ip_version=IPVersion.V4Only)
browser = AsyncServiceBrowser(
aiozc.zeroconf, [_ZC_SERVICE], handlers=[lambda *args, **kw: None]
)

# ServiceBrowser runs in a thread. Give it a chance to find some schedulers.
await asyncio.sleep(0.5)

schedulers = [
x.split(".")[0]
for x in aiozc.zeroconf.cache.names()
if x.endswith(_ZC_SERVICE) and x != _ZC_SERVICE
]

for scheduler in schedulers:
yield (
scheduler,
ProxyCluster,
)

await browser.async_cancel()
await aiozc.async_close()
5 changes: 3 additions & 2 deletions dask_ctl/tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def test_create(simple_spec_path):

def test_autocompletion():
with LocalCluster(scheduler_port=8786) as _:
assert len(autocomplete_cluster_names(None, None, "")) == 1
assert len(autocomplete_cluster_names(None, None, "proxy")) == 1
names = autocomplete_cluster_names(None, None, "")
assert len(names) == 1
assert "_sched" in names[0]
assert len(autocomplete_cluster_names(None, None, "local")) == 0
35 changes: 14 additions & 21 deletions dask_ctl/tests/test_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,21 @@

from typing import AsyncIterator

from distributed import LocalCluster
from dask.distributed import Client, LocalCluster
from dask_ctl.discovery import (
discover_cluster_names,
discover_clusters,
list_discovery_methods,
)
from dask_ctl.proxy import ProxyCluster
from dask_ctl.proxy import discover


@pytest.mark.asyncio
async def test_discover_clusters():
assert isinstance(discover_clusters(), AsyncIterator)
async with LocalCluster(scheduler_port=8786, asynchronous=True) as cluster:
[discovered_cluster] = [c async for c in discover_clusters()]
assert discovered_cluster.scheduler_info == cluster.scheduler_info


def test_discovery_methods():
Expand All @@ -19,16 +27,12 @@ def test_discovery_methods():
async def test_discover_cluster_names():
assert isinstance(discover_cluster_names(), AsyncIterator)
async with LocalCluster(scheduler_port=8786, asynchronous=True) as _:
count = 0
async for _ in discover_cluster_names():
count += 1
assert count == 1
names = [name async for name in discover_cluster_names()]
assert len(names) == 1


@pytest.mark.asyncio
async def test_cluster_client():
from dask.distributed import Client

port = 8786
async with LocalCluster(scheduler_port=port, asynchronous=True) as _:
async with Client(
Expand All @@ -39,17 +43,6 @@ async def test_cluster_client():

@pytest.mark.asyncio
async def test_discovery_list():
from dask_ctl.proxy import discover

port = 8786
async with LocalCluster(scheduler_port=port, asynchronous=True) as _:
async with LocalCluster(scheduler_port=8786, asynchronous=True) as _:
async for name, _ in discover():
assert str(port) in name


@pytest.mark.asyncio
async def test_discover_clusters():
with LocalCluster() as cluster:
async for discovered_cluster in discover_clusters():
if isinstance(discovered_cluster, ProxyCluster):
assert cluster.scheduler_info == discovered_cluster.scheduler_info
assert "_sched" in name
6 changes: 5 additions & 1 deletion dask_ctl/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from contextlib import contextmanager, redirect_stdout, redirect_stderr
from io import StringIO
import os

from tornado.ioloop import IOLoop
from distributed.cli.utils import install_signal_handlers
Expand Down Expand Up @@ -55,5 +56,8 @@ def justify(value, length):

@contextmanager
def suppress_output():
with redirect_stdout(StringIO()), redirect_stderr(StringIO()):
if "DASK_CTL_DEBUG" in os.environ:
yield
else:
with redirect_stdout(StringIO()), redirect_stderr(StringIO()):
yield
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@ click
dask
distributed
tornado
pyyaml
pyyaml
zeroconf>=0.32.0