Skip to content

Commit

Permalink
foo
Browse files Browse the repository at this point in the history
  • Loading branch information
fjetter committed Dec 20, 2023
1 parent 4d41d32 commit 5e87153
Show file tree
Hide file tree
Showing 5 changed files with 165 additions and 78 deletions.
106 changes: 51 additions & 55 deletions distributed/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,16 @@
from tlz import first, groupby, merge, partition_all, valmap

import dask
from dask.base import collections_to_dsk, normalize_token, tokenize
from dask.base import (
collections_to_dsk,
normalize_token,
tokenize,
newstyle_collections,
)
from dask.core import flatten, validate_key
from dask.highlevelgraph import HighLevelGraph
from dask.highlevelgraph import HighLevelGraph, TaskFactoryHLGWrapper
from dask.optimization import SubgraphCallable
from dask.typing import no_default
from dask.typing import no_default, DaskCollection2
from dask.utils import (
apply,
ensure_dict,
Expand Down Expand Up @@ -1957,10 +1962,10 @@ def submit(
dsk = {key: (apply, func, list(args), kwargs)}
else:
dsk = {key: (func,) + tuple(args)}
dsk = TaskFactoryHLGWrapper.from_low_level(dsk, [key])

futures = self._graph_to_futures(
dsk,
[key],
workers=workers,
allow_other_workers=allow_other_workers,
internal_priority={key: 0},
Expand Down Expand Up @@ -2166,7 +2171,6 @@ def map(

futures = self._graph_to_futures(
dsk,
keys,
workers=workers,
allow_other_workers=allow_other_workers,
internal_priority=internal_priority,
Expand Down Expand Up @@ -3105,7 +3109,6 @@ def _get_computation_code(
def _graph_to_futures(
self,
dsk,
keys,
workers=None,
allow_other_workers=None,
internal_priority=None,
Expand All @@ -3119,10 +3122,6 @@ def _graph_to_futures(
if actors is not None and actors is not True and actors is not False:
actors = list(self._expand_key(actors))

# Make sure `dsk` is a high level graph
if not isinstance(dsk, HighLevelGraph):
dsk = HighLevelGraph.from_collections(id(dsk), dsk, dependencies=())

annotations = {}
if user_priority:
annotations["priority"] = user_priority
Expand All @@ -3143,14 +3142,14 @@ def _graph_to_futures(
annotations = merge(dask.get_annotations(), annotations)

# Pack the high level graph before sending it to the scheduler
keyset = set(keys)
keys = dsk.__dask_output_keys__()

# Validate keys
for key in keyset:
for key in keys:
validate_key(key)

# Create futures before sending graph (helps avoid contention)
futures = {key: Future(key, self, inform=False) for key in keyset}
futures = {key: Future(key, self, inform=False) for key in keys}
# Circular import
from distributed.protocol import serialize
from distributed.protocol.serialize import ToPickle
Expand All @@ -3173,7 +3172,7 @@ def _graph_to_futures(
"op": "update-graph",
"graph_header": header,
"graph_frames": frames,
"keys": list(keys),
"keys": keys,
"internal_priority": internal_priority,
"submitting_task": getattr(thread_state, "key", None),
"fifo_timeout": fifo_timeout,
Expand Down Expand Up @@ -3257,9 +3256,9 @@ def get(
--------
Client.compute : Compute asynchronous collections
"""
dsk = TaskFactoryHLGWrapper.from_low_level(dsk, list(flatten(keys)))
futures = self._graph_to_futures(
dsk,
keys=set(flatten([keys])),
workers=workers,
allow_other_workers=allow_other_workers,
resources=resources,
Expand Down Expand Up @@ -3447,32 +3446,32 @@ def compute(
)

variables = [a for a in collections if dask.is_dask_collection(a)]
if newstyle_collections(variables):
variables = [var.finalize_compute() for var in variables]
dsk = collections_to_dsk(variables, optimize_graph, **kwargs)
else:
dsk = collections_to_dsk(variables, optimize_graph, **kwargs)
names = ["finalize-%s" % tokenize(v) for v in variables]
dsk = dsk._hlg
dsk2 = {}
for i, (name, v) in enumerate(zip(names, variables)):
func, extra_args = v.__dask_postcompute__()
keys = v.__dask_keys__()
if func is single_key and len(keys) == 1 and not extra_args:
names[i] = keys[0]
else:
dsk2[name] = (func, keys) + extra_args

dsk = self.collections_to_dsk(variables, optimize_graph, **kwargs)
names = ["finalize-%s" % tokenize(v) for v in variables]
dsk2 = {}
for i, (name, v) in enumerate(zip(names, variables)):
func, extra_args = v.__dask_postcompute__()
keys = v.__dask_keys__()
if func is single_key and len(keys) == 1 and not extra_args:
names[i] = keys[0]
else:
dsk2[name] = (func, keys) + extra_args

if not isinstance(dsk, HighLevelGraph):
dsk = HighLevelGraph.from_collections(id(dsk), dsk, dependencies=())

# Let's append the finalize graph to dsk
finalize_name = tokenize(names)
layers = {finalize_name: dsk2}
layers.update(dsk.layers)
dependencies = {finalize_name: set(dsk.layers.keys())}
dependencies.update(dsk.dependencies)
dsk = HighLevelGraph(layers, dependencies)
# Let's append the finalize graph to dsk
finalize_name = tokenize(names)
layers = {finalize_name: dsk2}
layers.update(dsk.layers)
dependencies = {finalize_name: set(dsk.layers.keys())}
dependencies.update(dsk.dependencies)
dsk = TaskFactoryHLGWrapper(HighLevelGraph(layers, dependencies), out_keys=names)

futures_dict = self._graph_to_futures(
dsk,
names,
workers=workers,
allow_other_workers=allow_other_workers,
resources=resources,
Expand All @@ -3482,23 +3481,18 @@ def compute(
actors=actors,
)

i = 0
futures = []
for arg in collections:
if dask.is_dask_collection(arg):
futures.append(futures_dict[names[i]])
i += 1
else:
futures.append(arg)
futures = list(futures_dict.values())

if sync:
result = self.gather(futures)
else:
result = futures

if singleton:
assert len(result) == 1
return first(result)
else:
assert len(result) > 1
return result

def persist(
Expand Down Expand Up @@ -3572,13 +3566,10 @@ def persist(

assert all(map(dask.is_dask_collection, collections))

dsk = self.collections_to_dsk(collections, optimize_graph, **kwargs)

names = {k for c in collections for k in flatten(c.__dask_keys__())}
dsk = collections_to_dsk(collections, optimize_graph, **kwargs)

futures = self._graph_to_futures(
dsk,
names,
workers=workers,
allow_other_workers=allow_other_workers,
resources=resources,
Expand All @@ -3587,12 +3578,14 @@ def persist(
fifo_timeout=fifo_timeout,
actors=actors,
)

postpersists = [c.__dask_postpersist__() for c in collections]
result = [
func({k: futures[k] for k in flatten(c.__dask_keys__())}, *args)
for (func, args), c in zip(postpersists, collections)
]
if newstyle_collections(collections):
result = [var.postpersist(futures) for var in collections]
else:
postpersists = [c.__dask_postpersist__() for c in collections]
result = [
func({k: futures[k] for k in flatten(c.__dask_keys__())}, *args)
for (func, args), c in zip(postpersists, collections)
]

if singleton:
return first(result)
Expand Down Expand Up @@ -4700,6 +4693,7 @@ def _expand_key(cls, k):
@staticmethod
def collections_to_dsk(collections, *args, **kwargs):
"""Convert many collections into a single dask graph, after optimization"""
warnings.warn(DeprecationWarning, "Why are you using this??")
return collections_to_dsk(collections, *args, **kwargs)

async def _story(self, *keys_or_stimuli: str, on_error="raise"):
Expand Down Expand Up @@ -5865,6 +5859,8 @@ def futures_of(o, client=None):
if x not in seen:
seen.add(x)
futures.append(x)
elif isinstance(x, DaskCollection2):
stack.extend(x.__dask_graph_factory__().materialize().values())
elif dask.is_dask_collection(x):
stack.extend(x.__dask_graph__().values())

Expand Down
23 changes: 10 additions & 13 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@
# TODO import from typing (requires Python >=3.10)
from typing_extensions import TypeAlias

from dask.typing import TaskGraphFactory
from dask.highlevelgraph import HighLevelGraph

# Not to be confused with distributed.worker_state_machine.TaskStateState
Expand Down Expand Up @@ -4640,7 +4641,7 @@ async def update_graph(
client: str,
graph_header: dict,
graph_frames: list[bytes],
keys: set[Key],
keys: list[Key],
internal_priority: dict[Key, int] | None,
submitting_task: Key | None,
user_priority: int | dict[Key, int] = 0,
Expand All @@ -4656,7 +4657,7 @@ async def update_graph(
start = time()
try:
try:
graph = deserialize(graph_header, graph_frames).data
graph_factory = deserialize(graph_header, graph_frames).data
del graph_header, graph_frames
except Exception as e:
msg = """\
Expand All @@ -4672,10 +4673,10 @@ async def update_graph(
annotations_by_type,
) = await offload(
_materialize_graph,
graph=graph,
graph_factory=graph_factory,
global_annotations=annotations or {},
)
del graph
del graph_factory
if not internal_priority:
# Removing all non-local keys before calling order()
dsk_keys = set(
Expand Down Expand Up @@ -8653,24 +8654,20 @@ def transition(


def _materialize_graph(
graph: HighLevelGraph, global_annotations: dict[str, Any]
graph_factory: TaskGraphFactory, global_annotations: dict[str, Any]
) -> tuple[dict[Key, T_runspec], dict[Key, set[Key]], dict[str, dict[Key, Any]]]:
dsk = ensure_dict(graph)
dsk = graph_factory.materialize()
for k in dsk:
validate_key(k)
annotations_by_type: defaultdict[str, dict[Key, Any]] = defaultdict(dict)
for annotations_type, value in global_annotations.items():
annotations_by_type[annotations_type].update(
{k: (value(k) if callable(value) else value) for k in dsk}
)
graph_annotations = graph_factory.get_annotations()

for layer in graph.layers.values():
if layer.annotations:
annot = layer.annotations
for annot_type, value in annot.items():
annotations_by_type[annot_type].update(
{k: (value(k) if callable(value) else value) for k in layer}
)
for k, v in graph_annotations.items():
annotations_by_type[k].update(v)
dependencies, _ = get_deps(dsk)

# Remove `Future` objects from graph and note any future dependencies
Expand Down
18 changes: 9 additions & 9 deletions distributed/tests/test_dask_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,13 @@ async def test_dataframes(c, s, a, b):
)
ldf = dd.from_pandas(df, npartitions=10)

rdf = await c.persist(ldf)
assert rdf.divisions == ldf.divisions
# rdf = await c.persist(ldf)
# assert rdf.divisions == ldf.divisions

remote = c.compute(rdf)
result = await remote
# remote = c.compute(rdf)
# result = await remote

assert_frame_equal(result, ldf.compute(scheduler="sync"))
# assert_frame_equal(result, ldf.compute(scheduler="sync"))

exprs = [
lambda df: df.x.mean(),
Expand All @@ -68,10 +68,10 @@ async def test_dataframes(c, s, a, b):
lambda df: df.loc[50:75],
]
for f in exprs:
local = f(ldf).compute(scheduler="sync")
remote = c.compute(f(rdf))
remote = await remote
assert_equal(local, remote)
local = await c.gather(c.compute(f(ldf)))
# remote = c.compute(f(rdf))
# remote = await remote
# assert_equal(local, remote)


@ignore_single_machine_warning
Expand Down
Loading

0 comments on commit 5e87153

Please sign in to comment.