Use Expr instead of HLG #9008

Draft · wants to merge 5 commits into base: main
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -64,7 +64,7 @@ repos:
- tornado
- pyarrow
- urllib3
- git+https://github.com/dask/dask
- git+https://github.com/fjetter/dask@wrap_hlg_expr
- git+https://github.com/dask/zict

# Increase this value to clear the cache on GitHub actions if nothing else in this file
2 changes: 1 addition & 1 deletion continuous_integration/environment-3.10.yaml
@@ -56,7 +56,7 @@ dependencies:
# Temporary fix for https://github.com/jupyterlab/jupyterlab/issues/17012
- httpx<0.28.0
- pip:
- git+https://github.com/dask/dask
- git+https://github.com/fjetter/dask@wrap_hlg_expr
- git+https://github.com/dask/zict
- git+https://github.com/dask/crick # Only tested here
# Revert after https://github.com/dask/distributed/issues/8614 is fixed
2 changes: 1 addition & 1 deletion continuous_integration/environment-3.11.yaml
@@ -50,7 +50,7 @@ dependencies:
# Temporary fix for https://github.com/jupyterlab/jupyterlab/issues/17012
- httpx<0.28.0
- pip:
- git+https://github.com/dask/dask
- git+https://github.com/fjetter/dask@wrap_hlg_expr
- git+https://github.com/dask/zict
# Revert after https://github.com/dask/distributed/issues/8614 is fixed
# - git+https://github.com/dask/s3fs
2 changes: 1 addition & 1 deletion continuous_integration/environment-3.12.yaml
@@ -50,7 +50,7 @@ dependencies:
# Temporary fix for https://github.com/jupyterlab/jupyterlab/issues/17012
- httpx<0.28.0
- pip:
- git+https://github.com/dask/dask
- git+https://github.com/fjetter/dask@wrap_hlg_expr
- git+https://github.com/dask/zict
# Revert after https://github.com/dask/distributed/issues/8614 is fixed
# - git+https://github.com/dask/s3fs
2 changes: 1 addition & 1 deletion continuous_integration/environment-3.13.yaml
@@ -49,7 +49,7 @@ dependencies:
# Temporary fix for https://github.com/jupyterlab/jupyterlab/issues/17012
- httpx<0.28.0
- pip:
- git+https://github.com/dask/dask
- git+https://github.com/fjetter/dask@wrap_hlg_expr
- git+https://github.com/dask/zict
# Revert after https://github.com/dask/distributed/issues/8614 is fixed
# - git+https://github.com/dask/s3fs
2 changes: 1 addition & 1 deletion continuous_integration/environment-mindeps.yaml
@@ -23,7 +23,7 @@ dependencies:
# Distributed depends on the latest version of Dask
- pip
- pip:
- git+https://github.com/dask/dask
- git+https://github.com/fjetter/dask@wrap_hlg_expr
# test dependencies
- pytest
- pytest-cov
40 changes: 11 additions & 29 deletions distributed/client.py
@@ -43,7 +43,6 @@
import dask
from dask.base import collections_to_dsk
from dask.core import flatten, validate_key
from dask.highlevelgraph import HighLevelGraph
from dask.layers import Layer
from dask.tokenize import tokenize
from dask.typing import Key, NestedKeys, NoDefault, no_default
@@ -3310,7 +3309,7 @@ def _inform_scheduler_of_futures(self):

def _graph_to_futures(
self,
dsk,
expr,
keys,
span_metadata,
workers=None,
@@ -3325,9 +3324,6 @@ def _graph_to_futures(
with self._refcount_lock:
if actors is not None and actors is not True and actors is not False:
actors = list(self._expand_key(actors))
# Make sure `dsk` is a high level graph
if not isinstance(dsk, HighLevelGraph):
dsk = HighLevelGraph.from_collections(id(dsk), dsk, dependencies=())

annotations = {}
if user_priority:
@@ -3361,7 +3357,7 @@ def _graph_to_futures(
from distributed.protocol import serialize
from distributed.protocol.serialize import ToPickle

header, frames = serialize(ToPickle(dsk), on_error="raise")
header, frames = serialize(ToPickle(expr), on_error="raise")

pickled_size = sum(map(nbytes, [header] + frames))
if pickled_size > parse_bytes(
@@ -3381,8 +3377,8 @@
self._send_to_scheduler(
{
"op": "update-graph",
"graph_header": header,
"graph_frames": frames,
"expr_header": header,
"expr_frames": frames,
"keys": set(keys),
"internal_priority": internal_priority,
"submitting_task": getattr(thread_state, "key", None),
@@ -3665,31 +3661,17 @@ def compute(
collections=[get_collections_metadata(v) for v in variables]
)

dsk = self.collections_to_dsk(variables, optimize_graph, **kwargs)
names = ["finalize-%s" % tokenize(v) for v in variables]
dsk2 = {}
for i, (name, v) in enumerate(zip(names, variables)):
func, extra_args = v.__dask_postcompute__()
keys = v.__dask_keys__()
if func is single_key and len(keys) == 1 and not extra_args:
names[i] = keys[0]
else:
t = Task(name, func, _convert_dask_keys(keys), *extra_args)
dsk2[t.key] = t
expr = self.collections_to_dsk(variables, optimize_graph, **kwargs)
from dask._expr import FinalizeCompute

if not isinstance(dsk, HighLevelGraph):
dsk = HighLevelGraph.from_collections(id(dsk), dsk, dependencies=())
expr = FinalizeCompute(expr)

# Let's append the finalize graph to dsk
finalize_name = tokenize(names)
layers = {finalize_name: dsk2}
layers.update(dsk.layers)
dependencies = {finalize_name: set(dsk.layers.keys())}
dependencies.update(dsk.dependencies)
dsk = HighLevelGraph(layers, dependencies)
expr = expr.optimize()
# FIXME: Is this actually required?
names = list(flatten(expr.__dask_keys__()))

futures_dict = self._graph_to_futures(
dsk,
expr,
names,
workers=workers,
allow_other_workers=allow_other_workers,
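Note for reviewers (not part of the diff): the compute() hunk above replaces the hand-built "finalize-*" HighLevelGraph layer with an expression wrapper. Below is a rough sketch of the new client-side flow, assuming the dask._expr interfaces used in this branch (FinalizeCompute, Expr.optimize) are installed and behave as shown above; expr_for_compute is an illustrative stand-in, not a function in the PR.

# Illustrative sketch only: mirrors the new Client.compute() path on this branch.
# Requires the wrap_hlg_expr dask branch; `collections` is whatever the user
# passed to compute().
from dask._expr import FinalizeCompute
from dask.base import collections_to_dsk
from dask.core import flatten

def expr_for_compute(collections):
    # On this branch collections_to_dsk yields an expression, not a HighLevelGraph.
    expr = collections_to_dsk(collections, optimize_graph=True)
    # FinalizeCompute replaces the old hand-built "finalize-<token>" layer.
    expr = FinalizeCompute(expr)
    expr = expr.optimize()
    # The output keys are read off the optimized expression; the expression and
    # keys are then handed to Client._graph_to_futures(), which pickles the
    # expression and sends it to the scheduler as expr_header/expr_frames.
    keys = list(flatten(expr.__dask_keys__()))
    return expr, keys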
30 changes: 10 additions & 20 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@
from dask.utils import (
_deprecated,
_deprecated_kwarg,
ensure_dict,
format_bytes,
format_time,
key_split,
@@ -144,7 +143,7 @@
# TODO import from typing (requires Python >=3.11)
from typing_extensions import Self, TypeAlias

from dask.highlevelgraph import HighLevelGraph
from dask._expr import Expr

# Not to be confused with distributed.worker_state_machine.TaskStateState
TaskStateState: TypeAlias = Literal[
@@ -4810,8 +4809,8 @@ def _remove_done_tasks_from_dsk(
async def update_graph(
self,
client: str,
graph_header: dict,
graph_frames: list[bytes],
expr_header: dict,
expr_frames: list[bytes],
keys: set[Key],
span_metadata: SpanMetadata,
internal_priority: dict[Key, int] | None,
@@ -4829,8 +4828,8 @@ async def update_graph(
try:
logger.debug("Received new graph. Deserializing...")
try:
graph = deserialize(graph_header, graph_frames).data
del graph_header, graph_frames
expr = deserialize(expr_header, expr_frames).data
del expr_header, expr_frames
except Exception as e:
msg = """\
Error during deserialization of the task graph. This frequently
@@ -4845,15 +4844,14 @@
annotations_by_type,
) = await offload(
_materialize_graph,
graph=graph,
expr=expr,
global_annotations=annotations or {},
keys=keys,
validate=self.validate,
)

materialization_done = time()
logger.debug("Materialization done. Got %i tasks.", len(dsk))
del graph
del expr

lost_keys = self._find_lost_dependencies(dsk, dependencies, keys)

@@ -9363,12 +9361,11 @@ def transition(


def _materialize_graph(
graph: HighLevelGraph,
expr: Expr,
global_annotations: dict[str, Any],
validate: bool,
keys: set[Key],
) -> tuple[dict[Key, T_runspec], dict[Key, set[Key]], dict[str, dict[Key, Any]]]:
dsk: dict = ensure_dict(graph)
dsk: dict = expr.__dask_graph__()
if validate:
for k in dsk:
validate_key(k)
@@ -9377,14 +9374,7 @@
annotations_by_type[annotations_type].update(
{k: (value(k) if callable(value) else value) for k in dsk}
)

for layer in graph.layers.values():
if layer.annotations:
annot = layer.annotations
for annot_type, value in annot.items():
annotations_by_type[annot_type].update(
{k: (value(k) if callable(value) else value) for k in layer}
)
annotations_by_type.update(expr.__dask_annotations__())

dsk2 = convert_legacy_graph(dsk)
# FIXME: There should be no need to fully materialize and copy this but some
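Similarly (not part of the diff): on the scheduler side, materialization now reduces to the dask collection protocol on the deserialized expression. A minimal sketch follows, assuming deserialize is importable from distributed.protocol and that __dask_graph__/__dask_annotations__ behave as used above; materialize_expr is an illustrative stand-in.

# Illustrative sketch of the scheduler-side handling after this PR. The
# ToPickle/deserialize round trip and the Expr methods are taken from the diff
# above; error handling, culling and offloading are omitted.
from distributed.protocol import deserialize

def materialize_expr(expr_header, expr_frames):
    # The client ships a pickled expression instead of a HighLevelGraph.
    expr = deserialize(expr_header, expr_frames).data
    # Low-level graph and per-key annotations come straight from the expression
    # protocol rather than from HighLevelGraph layers.
    dsk = expr.__dask_graph__()
    annotations_by_type = expr.__dask_annotations__()
    return dsk, annotations_by_type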
44 changes: 0 additions & 44 deletions distributed/tests/test_client.py
@@ -4571,50 +4571,6 @@ def test_normalize_collection_with_released_futures(c):
assert res == sol


@pytest.mark.xfail(reason="https://github.com/dask/distributed/issues/4404")
@gen_cluster(client=True)
async def test_auto_normalize_collection(c, s, a, b):
da = pytest.importorskip("dask.array")

x = da.ones(10, chunks=5)
assert len(x.dask) == 2

with dask.config.set(optimizations=[c._optimize_insert_futures]):
y = x.map_blocks(inc, dtype=x.dtype)
yy = c.persist(y)

await wait(yy)

start = time()
future = c.compute(y.sum())
await future
end = time()
assert end - start < 1

start = time()
z = c.persist(y + 1)
await wait(z)
end = time()
assert end - start < 1


@pytest.mark.xfail(reason="https://github.com/dask/distributed/issues/4404")
def test_auto_normalize_collection_sync(c):
da = pytest.importorskip("dask.array")
x = da.ones(10, chunks=5)

y = x.map_blocks(inc, dtype=x.dtype)
yy = c.persist(y)

wait(yy)

with dask.config.set(optimizations=[c._optimize_insert_futures]):
start = time()
y.sum().compute()
end = time()
assert end - start < 1


def assert_no_data_loss(scheduler):
for key, start, finish, recommendations, _, _ in scheduler.transition_log:
if start == "memory" and finish == "released":
4 changes: 2 additions & 2 deletions distributed/tests/test_scheduler.py
@@ -1422,8 +1422,8 @@ async def test_update_graph_culls(s, a, b):

header, frames = serialize(ToPickle(dsk), on_error="raise")
await s.update_graph(
graph_header=header,
graph_frames=frames,
expr_header=header,
expr_frames=frames,
keys=["y"],
client="client",
internal_priority={k: 0 for k in "xyz"},
2 changes: 1 addition & 1 deletion docs/requirements.txt
@@ -2,7 +2,7 @@ numpydoc
tornado
toolz
cloudpickle
git+https://github.com/dask/dask
git+https://github.com/fjetter/dask@wrap_hlg_expr
sphinx
dask-sphinx-theme>=3.0.0
# FIXME: `sphinxcontrib-*` pins are a workaround until we have sphinx>=5.