Skip to content

Commit

Permalink
More dask-expr fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
fjetter committed Jan 16, 2024
1 parent 4c3d346 commit 99d8c4e
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 9 deletions.
20 changes: 13 additions & 7 deletions distributed/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,14 @@
import dask
from dask.base import (
collections_to_dsk,
newstyle_collections,
normalize_token,
tokenize,
newstyle_collections,
)
from dask.core import flatten, validate_key
from dask.highlevelgraph import HighLevelGraph, TaskFactoryHLGWrapper
from dask.optimization import SubgraphCallable
from dask.typing import no_default, DaskCollection2
from dask.typing import DaskCollection2, no_default
from dask.utils import (
apply,
ensure_dict,
Expand Down Expand Up @@ -1966,6 +1966,7 @@ def submit(

futures = self._graph_to_futures(
dsk,
[key],
workers=workers,
allow_other_workers=allow_other_workers,
internal_priority={key: 0},
Expand Down Expand Up @@ -2171,6 +2172,7 @@ def map(

futures = self._graph_to_futures(
dsk,
keys,
workers=workers,
allow_other_workers=allow_other_workers,
internal_priority=internal_priority,
Expand Down Expand Up @@ -3109,6 +3111,7 @@ def _get_computation_code(
def _graph_to_futures(
self,
dsk,
keys,
workers=None,
allow_other_workers=None,
internal_priority=None,
Expand Down Expand Up @@ -3141,9 +3144,6 @@ def _graph_to_futures(
# Merge global and local annotations
annotations = merge(dask.get_annotations(), annotations)

# Pack the high level graph before sending it to the scheduler
keys = dsk.__dask_output_keys__()

# Validate keys
for key in keys:
validate_key(key)
Expand Down Expand Up @@ -3259,6 +3259,7 @@ def get(
dsk = TaskFactoryHLGWrapper.from_low_level(dsk, list(flatten(keys)))
futures = self._graph_to_futures(
dsk,
keys,
workers=workers,
allow_other_workers=allow_other_workers,
resources=resources,
Expand Down Expand Up @@ -3449,6 +3450,7 @@ def compute(
if newstyle_collections(variables):
variables = [var.finalize_compute() for var in variables]
dsk = collections_to_dsk(variables, optimize_graph, **kwargs)
names = dsk.__dask_output_keys__()
else:
dsk = collections_to_dsk(variables, optimize_graph, **kwargs)
names = ["finalize-%s" % tokenize(v) for v in variables]
Expand All @@ -3468,10 +3470,13 @@ def compute(
layers.update(dsk.layers)
dependencies = {finalize_name: set(dsk.layers.keys())}
dependencies.update(dsk.dependencies)
dsk = TaskFactoryHLGWrapper(HighLevelGraph(layers, dependencies), out_keys=names)
dsk = TaskFactoryHLGWrapper(
HighLevelGraph(layers, dependencies), out_keys=names
)

futures_dict = self._graph_to_futures(
dsk,
keys=names,
workers=workers,
allow_other_workers=allow_other_workers,
resources=resources,
Expand All @@ -3481,7 +3486,7 @@ def compute(
actors=actors,
)

futures = list(futures_dict.values())
futures = [futures_dict[name] for name in names]

if sync:
result = self.gather(futures)
Expand Down Expand Up @@ -3570,6 +3575,7 @@ def persist(

futures = self._graph_to_futures(
dsk,
dsk.__dask_output_keys__(),
workers=workers,
allow_other_workers=allow_other_workers,
resources=resources,
Expand Down
11 changes: 9 additions & 2 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@
from dask.core import get_deps, validate_key
from dask.typing import Key, no_default
from dask.utils import (
ensure_dict,
format_bytes,
format_time,
key_split,
Expand Down Expand Up @@ -140,7 +139,6 @@
from typing_extensions import TypeAlias

from dask.typing import TaskGraphFactory
from dask.highlevelgraph import HighLevelGraph

# Not to be confused with distributed.worker_state_machine.TaskStateState
TaskStateState: TypeAlias = Literal[
Expand Down Expand Up @@ -4670,6 +4668,15 @@ async def update_graph(
try:
try:
graph_factory = deserialize(graph_header, graph_frames).data
if list(graph_factory.__dask_output_keys__()) != list(keys):
# Note: If we no longer want to rely on tokenization being
# consistent beyond an interpreter session, we could
# implement this so that the scheduler sends a signal
# back to the client with the final key names.
# This would be easiest with a version of https://github.com/dask/distributed/issues/7480
raise RuntimeError(
"Catastrophic failure. Tokenization of keys unstable!"
)
del graph_header, graph_frames
except Exception as e:
msg = """\
Expand Down

0 comments on commit 99d8c4e

Please sign in to comment.