From 99d8c4e7ecccd63ec538d9d65a2885a204f2257a Mon Sep 17 00:00:00 2001
From: fjetter
Date: Tue, 16 Jan 2024 13:05:33 +0100
Subject: [PATCH] More dask-expr fixes

---
 distributed/client.py    | 20 +++++++++++++-------
 distributed/scheduler.py | 11 +++++++++--
 2 files changed, 22 insertions(+), 9 deletions(-)

diff --git a/distributed/client.py b/distributed/client.py
index 422b3efbf9..882b4cb905 100644
--- a/distributed/client.py
+++ b/distributed/client.py
@@ -33,14 +33,14 @@
 import dask
 from dask.base import (
     collections_to_dsk,
+    newstyle_collections,
     normalize_token,
     tokenize,
-    newstyle_collections,
 )
 from dask.core import flatten, validate_key
 from dask.highlevelgraph import HighLevelGraph, TaskFactoryHLGWrapper
 from dask.optimization import SubgraphCallable
-from dask.typing import no_default, DaskCollection2
+from dask.typing import DaskCollection2, no_default
 from dask.utils import (
     apply,
     ensure_dict,
@@ -1966,6 +1966,7 @@ def submit(
 
         futures = self._graph_to_futures(
             dsk,
+            [key],
             workers=workers,
             allow_other_workers=allow_other_workers,
             internal_priority={key: 0},
@@ -2171,6 +2172,7 @@ def map(
 
         futures = self._graph_to_futures(
             dsk,
+            keys,
             workers=workers,
             allow_other_workers=allow_other_workers,
             internal_priority=internal_priority,
@@ -3109,6 +3111,7 @@ def _get_computation_code(
     def _graph_to_futures(
         self,
         dsk,
+        keys,
         workers=None,
         allow_other_workers=None,
         internal_priority=None,
@@ -3141,9 +3144,6 @@ def _graph_to_futures(
             # Merge global and local annotations
             annotations = merge(dask.get_annotations(), annotations)
 
-            # Pack the high level graph before sending it to the scheduler
-            keys = dsk.__dask_output_keys__()
-
             # Validate keys
             for key in keys:
                 validate_key(key)
@@ -3259,6 +3259,7 @@ def get(
             dsk = TaskFactoryHLGWrapper.from_low_level(dsk, list(flatten(keys)))
         futures = self._graph_to_futures(
             dsk,
+            keys,
             workers=workers,
             allow_other_workers=allow_other_workers,
             resources=resources,
@@ -3449,6 +3450,7 @@ def compute(
         if newstyle_collections(variables):
             variables = [var.finalize_compute() for var in variables]
             dsk = collections_to_dsk(variables, optimize_graph, **kwargs)
+            names = dsk.__dask_output_keys__()
         else:
             dsk = collections_to_dsk(variables, optimize_graph, **kwargs)
             names = ["finalize-%s" % tokenize(v) for v in variables]
@@ -3468,10 +3470,13 @@ def compute(
                 layers.update(dsk.layers)
                 dependencies = {finalize_name: set(dsk.layers.keys())}
                 dependencies.update(dsk.dependencies)
-            dsk = TaskFactoryHLGWrapper(HighLevelGraph(layers, dependencies), out_keys=names)
+            dsk = TaskFactoryHLGWrapper(
+                HighLevelGraph(layers, dependencies), out_keys=names
+            )
 
         futures_dict = self._graph_to_futures(
             dsk,
+            keys=names,
             workers=workers,
             allow_other_workers=allow_other_workers,
             resources=resources,
@@ -3481,7 +3486,7 @@ def compute(
             actors=actors,
         )
 
-        futures = list(futures_dict.values())
+        futures = [futures_dict[name] for name in names]
 
         if sync:
             result = self.gather(futures)
@@ -3570,6 +3575,7 @@ def persist(
 
         futures = self._graph_to_futures(
             dsk,
+            dsk.__dask_output_keys__(),
             workers=workers,
             allow_other_workers=allow_other_workers,
             resources=resources,
diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index 5ec40959fd..8983db44ba 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -55,7 +55,6 @@
 from dask.core import get_deps, validate_key
 from dask.typing import Key, no_default
 from dask.utils import (
-    ensure_dict,
     format_bytes,
     format_time,
     key_split,
@@ -140,7 +139,6 @@
     from typing_extensions import TypeAlias
 
     from dask.typing import TaskGraphFactory
-    from dask.highlevelgraph import HighLevelGraph
 
 # Not to be confused with distributed.worker_state_machine.TaskStateState
 TaskStateState: TypeAlias = Literal[
@@ -4670,6 +4668,15 @@ async def update_graph(
         try:
             try:
                 graph_factory = deserialize(graph_header, graph_frames).data
+                if list(graph_factory.__dask_output_keys__()) != list(keys):
+                    # Note: If we no longer want to rely on tokenization to be
+                    # consistent beyond an interpreter session, we could
+                    # implement this in a way that the scheduler sends a signal
+                    # back to the client with the final key names
+                    # This would be easiest with a version of https://github.com/dask/distributed/issues/7480
+                    raise RuntimeError(
+                        "Catastrophic failure. Tokenization of keys unstable!"
+                    )
                 del graph_header, graph_frames
             except Exception as e:
                 msg = """\