Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow users to use lifecycle methods both sync and async #40

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions dask_ctl/asyncio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from .lifecycle import _create_cluster as create_cluster # noqa
from .lifecycle import _list_clusters as list_clusters # noqa
from .lifecycle import _get_cluster as get_cluster # noqa
from .lifecycle import _get_snippet as get_snippet # noqa
from .lifecycle import _scale_cluster as scale_cluster # noqa
from .lifecycle import _delete_cluster as delete_cluster # noqa
42 changes: 19 additions & 23 deletions dask_ctl/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from distributed.core import Status

from . import __version__
from .utils import loop
from .utils import run_sync
from .discovery import (
discover_clusters,
discover_cluster_names,
Expand All @@ -38,7 +38,7 @@ async def _autocomplete_cluster_names():
if incomplete in cluster
]

return loop.run_sync(_autocomplete_cluster_names)
return run_sync(_autocomplete_cluster_names)


@click.group()
Expand Down Expand Up @@ -144,7 +144,7 @@ async def _list():

console.print(table)

loop.run_sync(_list)
run_sync(_list)


@cluster.command()
Expand Down Expand Up @@ -273,26 +273,22 @@ def list_discovery():
methods registered on your system.

"""

async def _list_discovery():
table = Table(box=box.SIMPLE)
table.add_column("Name", style="cyan", no_wrap=True)
table.add_column("Package", justify="right", style="magenta")
table.add_column("Version", style="green")
table.add_column("Path", style="yellow")
table.add_column("Enabled", justify="right", style="green")

for method_name, method in list_discovery_methods().items():
table.add_row(
method_name,
method["package"],
method["version"],
method["path"],
":heavy_check_mark:" if method["enabled"] else ":cross_mark:",
)
console.print(table)

loop.run_sync(_list_discovery)
table = Table(box=box.SIMPLE)
table.add_column("Name", style="cyan", no_wrap=True)
table.add_column("Package", justify="right", style="magenta")
table.add_column("Version", style="green")
table.add_column("Path", style="yellow")
table.add_column("Enabled", justify="right", style="green")

for method_name, method in list_discovery_methods().items():
table.add_row(
method_name,
method["package"],
method["version"],
method["path"],
":heavy_check_mark:" if method["enabled"] else ":cross_mark:",
)
console.print(table)


@discovery.command(name="enable")
Expand Down
82 changes: 57 additions & 25 deletions dask_ctl/lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from distributed.deploy.cluster import Cluster
from .discovery import discover_cluster_names, discover_clusters
from .spec import load_spec
from .utils import loop
from .utils import run_sync


def create_cluster(spec_path: str) -> Cluster:
Expand Down Expand Up @@ -38,18 +38,22 @@ def create_cluster(spec_path: str) -> Cluster:

"""

async def _create_cluster():
cm_module, cm_class, args, kwargs = load_spec(spec_path)
module = importlib.import_module(cm_module)
cluster_manager = getattr(module, cm_class)
return run_sync(_create_cluster, spec_path)

kwargs = {key.replace("-", "_"): entry for key, entry in kwargs.items()}

cluster = await cluster_manager(*args, **kwargs, asynchronous=True)
cluster.shutdown_on_close = False
return cluster
async def _create_cluster(spec_path: str) -> Cluster:
    """Async implementation: instantiate a cluster manager from a spec file.

    Loads the spec at *spec_path*, resolves the cluster-manager class it
    names, and starts it asynchronously. The resulting cluster is left
    running when the manager object is closed (``shutdown_on_close = False``).
    """
    manager_module, manager_name, args, kwargs = load_spec(spec_path)
    manager_cls = getattr(importlib.import_module(manager_module), manager_name)

    # Spec files may use dashes in option names; Python kwargs need underscores.
    normalized = {key.replace("-", "_"): value for key, value in kwargs.items()}

    cluster = await manager_cls(*args, **normalized, asynchronous=True)
    cluster.shutdown_on_close = False
    return cluster


_create_cluster.__doc__ = create_cluster.__doc__


def list_clusters() -> List[Cluster]:
Expand All @@ -71,13 +75,17 @@ def list_clusters() -> List[Cluster]:

"""

async def _list_clusters():
clusters = []
async for cluster in discover_clusters():
clusters.append(cluster)
return clusters
return run_sync(_list_clusters)


async def _list_clusters() -> List[Cluster]:
    """Async implementation: collect every cluster found by discovery."""
    return [cluster async for cluster in discover_clusters()]

return loop.run_sync(_list_clusters)

_list_clusters.__doc__ = list_clusters.__doc__


def get_cluster(name: str) -> Cluster:
Expand All @@ -102,13 +110,17 @@ def get_cluster(name: str) -> Cluster:

"""

async def _get_cluster():
async for cluster_name, cluster_class in discover_cluster_names():
if cluster_name == name:
return cluster_class.from_name(name)
raise RuntimeError("No such cluster %s", name)
return run_sync(_get_cluster, name)


async def _get_cluster(name: str) -> Cluster:
    """Async implementation: look up the cluster called *name*.

    Iterates the discovered cluster names and reconstructs the matching
    cluster via its manager class's ``from_name``.

    Raises
    ------
    RuntimeError
        If discovery yields no cluster named *name*.
    """
    async for cluster_name, cluster_class in discover_cluster_names():
        if cluster_name == name:
            return cluster_class.from_name(name)
    # f-string, not ("No such cluster %s", name): the old form stored the
    # name as a second exception argument instead of interpolating it, so
    # the error displayed as a tuple rather than a readable message.
    raise RuntimeError(f"No such cluster {name}")


return loop.run_sync(_get_cluster)
_get_cluster.__doc__ = get_cluster.__doc__


def get_snippet(name: str) -> str:
Expand Down Expand Up @@ -136,8 +148,11 @@ def get_snippet(name: str) -> str:
client = Client(cluster)

"""
return run_sync(_get_snippet, name)

cluster = get_cluster(name)

async def _get_snippet(name: str) -> str:
cluster = await _get_cluster(name)
try:
return cluster.get_snippet()
except AttributeError:
Expand All @@ -148,6 +163,9 @@ def get_snippet(name: str) -> str:
)


_get_snippet.__doc__ = get_snippet.__doc__


def scale_cluster(name: str, n_workers: int) -> None:
"""Scale a cluster by name.

Expand All @@ -166,8 +184,15 @@ def scale_cluster(name: str, n_workers: int) -> None:
>>> scale_cluster("mycluster", 10) # doctest: +SKIP

"""
return run_sync(_scale_cluster, name, n_workers)


async def _scale_cluster(name: str, n_workers: int) -> None:
    """Async implementation: scale the cluster called *name* to *n_workers* workers."""
    target = await _get_cluster(name)
    return await target.scale(n_workers)

return get_cluster(name).scale(n_workers)

_scale_cluster.__doc__ = scale_cluster.__doc__


def delete_cluster(name: str) -> None:
Expand All @@ -186,5 +211,12 @@ def delete_cluster(name: str) -> None:
>>> delete_cluster("mycluster") # doctest: +SKIP

"""
return run_sync(_delete_cluster, name)


async def _delete_cluster(name: str) -> None:
    """Async implementation: look up the cluster called *name* and close it."""
    cluster = await _get_cluster(name)
    return await cluster.close()


# Copy the public wrapper's docstring onto the async variant so both halves
# of the API are documented identically. (Previously this was assigned from
# _delete_cluster itself — a no-op typo.)
_delete_cluster.__doc__ = delete_cluster.__doc__
19 changes: 15 additions & 4 deletions dask_ctl/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,22 @@
import asyncio
import concurrent.futures

from tornado.ioloop import IOLoop
from distributed.cli.utils import install_signal_handlers

def run_sync(f, *args, **kwargs):
    """Run coroutine function *f* to completion and return its result.

    Usable both from ordinary synchronous code and from code that is already
    executing inside an event loop; in the latter case the coroutine runs on
    a fresh loop in a one-off worker thread.
    """

    async def _probe():
        """Empty coroutine used to detect whether a loop is already running."""

    try:
        asyncio.run(_probe())
    except RuntimeError:
        # asyncio.run refused to start: a loop is already running in this
        # thread (and asyncio.run has not been patched by e.g. nest-asyncio).
        # Fall through to the worker-thread path below.
        pass
    else:
        # No running loop here — safe to drive the coroutine directly.
        return asyncio.run(f(*args, **kwargs))

    with concurrent.futures.ThreadPoolExecutor(1) as tpe:
        return tpe.submit(asyncio.run, f(*args, **kwargs)).result()


class _AsyncTimedIterator:
Expand Down
19 changes: 19 additions & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,25 @@ Python API
Lifecycle
---------

Dask Control has a selection of lifecycle functions that can be used within Python to manage
your Dask clusters. You can list clusters, get instances of an existing cluster, create new ones, scale and delete them.

You can either use these in a regular synchronous way by importing them from ``dask_ctl``.

.. code-block:: python

from dask_ctl import list_clusters

clusters = list_clusters()

Or alternatively you can use them in async code by importing from the ``dask_ctl.asyncio`` submodule.

.. code-block:: python

from dask_ctl.asyncio import list_clusters

clusters = await list_clusters()

.. autosummary::
get_cluster
create_cluster
Expand Down