[WIP] Pass non-HLG objects wout materialization #7942

Draft · wants to merge 2 commits into base: main

Conversation

@fjetter (Member) commented Jun 22, 2023

This is an exploration and by no means intended to be merged. It concerns dask/dask-expr#14,
which discusses allowing an Expr to be transferred to the scheduler using pickle without triggering materialization. I took a rather holistic approach with minimal knowledge about the state of dask-expr and wanted to explore the minimal API surface an object has to implement to be passable to client.compute(obj) and return a result.

In a couple of places HLGs are hardcoded. To work around this I had to rework the API a bit, and this is not in a nice state yet.

I encountered a bit of trouble in various utility functions because of how __dask_graph__ is implemented and interpreted. In dask/dask it is currently understood as "return the HLG object"; every collection is expected to have one, so calling it never triggers a materialization. IIUC, dask-expr currently uses it to return the materialized graph. I strongly recommend moving away from the "this is just a Mapping" interface.

To highlight the places where this happens I simply introduced __dask_graph_nomat__, which is a no-op in this case. This is clearly ugly, and with a bit of work on the dask/dask and dask-expr side I believe it can be avoided. However this is dealt with, it may impact downstream users (xarray reported a problem about this before, see dask/dask#9058).

The other ugly part is how postcompute is dealt with. I tried not to modify much, but the result is ugly and requires us to build very awkward mutating API calls for that object. I wonder how dask-expr solved this (if it has). My approach now would be to introduce an API called finalize_{compute|persist} and let the object deal with the mutation itself instead of offering very awkward API calls to do this externally.

Anyhow, the current state gives us the following example with a minimal API:

class MyCustomScalar(DaskGraph):
    dsk = {
        "x": (inc, 1),
        "y": (inc, "x"),
        "z": (mysum, "x", "y"),
    }
    out_keys = ["z"]

    def __dask_graph__(self):
        # This is just for ad-hoc testing
        if _ALLOW_MATERIALIZATION.get():
            return self.dsk
        else:
            raise RuntimeError("Not allowed to materialize!")

    def __dask_graph_nomat__(self):
        return self

    def __dask_annotations__(self):
        return {"x": {"foo": "bar"}, "z": {"retries": 1}}

    def __dask_keys__(self):
        return self.out_keys

    # Everything below this is only necessary for
    # postcompute/postpersist/finalize foo
    # I think we'd be better served with a `finalize_compute` method instead of doing this weird mutation externally

    def __dask_postcompute__(self):
        # NOTE: first is dask.single_key which is a sentinel for client.compute
        return single_key, ()

    def __dask_postpersist__(self):
        raise NotImplementedError()

    def set_outkeys(self, keys):
        self.out_keys = list(keys)

    def merge(self, *others):
        # FIXME: This should create a new object but it's just a test...
        for ot in others:
            self.dsk.update(dict(ot))
        return self
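
For reference, a hypothetical end-to-end usage sketch on this PR's branch; inc and mysum here stand in for the simple helpers the class above assumes (they would need to be defined before the class body):

from distributed import Client

def inc(x):
    return x + 1

def mysum(*args):
    return sum(args)

client = Client()
future = client.compute(MyCustomScalar())  # no call to __dask_graph__ on the client
print(future.result())                     # inc(1) = 2, inc(2) = 3, 2 + 3 = 5 -> prints 5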

Sibling dask/dask PR: dask/dask#10369

cc @mrocklin @rjzamora @phofl

@fjetter (Member Author) commented Jun 22, 2023

FWIW, as a next step I will try to throw away all this weird mutation in client.compute and just submit the entire thing to the scheduler. The current code assumes a difference between Collection and AbstractGraph, which may be fine, but this distinction should only matter once we're on the scheduler, if ever.


@fjetter (Member Author) commented Jun 22, 2023

We can actually remove a lot of boilerplate from the client side and just submit the collection. This would require us to subtly change how finalize/postpersist works, but only marginally. I haven't checked how this interacts with the dask.local schedulers, but this is roughly what the minimal interface looks like after dropping the unnecessary foo.

from typing import Hashable, Protocol


class DaskCollection(Protocol):
    def __dask_graph__(self) -> dict:
        """The materialized low level graph"""
        ...

    def __dask_annotations__(self) -> dict:
        """The materialized annotations per key"""
        ...

    def __dask_keys__(self) -> list[Hashable]:
        """The requested output keys"""
        ...

    def finalize_compute(self) -> "DaskCollection":
        """Modify the graph such that its output can be fetched by a client"""
        ...

    def postpersist(self, futures) -> "DaskCollection":
        """Rebuild the graph with the futures provided. There has to be one future for every output key."""
        ...

@fjetter (Member Author) left a comment:

A couple of notes from a conversation with @phofl:

(Five inline review comments on distributed/tests/test_scheduler.py, marked as outdated and resolved.)
@mrocklin (Member) commented:

Quick random thought:

Maybe we need to think bigger and reinvent the entire protocol. The old protocol is very much based on sending and receiving graphs (dict-like things). We're now sending and receiving things that can contain graphs. That's probably way different. I'm mildly suspicious of ways that try to maintain the old protocol.

@fjetter (Member Author) commented Sep 8, 2023

I have a PR over at dask-expr that shows what modifications would be necessary to adhere to this new protocol: dask/dask-expr#294

@fjetter changed the title from "[RFC] Pass non-HLG objects wout materialization" to "[WIP] Pass non-HLG objects wout materialization" on Sep 8, 2023.
(An inline review comment on distributed/client.py was marked as outdated and resolved; a comment from fjetter was marked as outdated.)

@fjetter force-pushed the dask_graph_protocol_wout_materialization branch from 4e615e6 to 5e87153 on December 20, 2023.
@fjetter (Member Author) commented Dec 20, 2023

I took a step back here and tried not to reuse any of the previously defined dunder methods, using new ones instead.

At a high level, the following entities exist (I'm not attached to the names and just picked explicit ones; we can iterate on that, but please let us figure out the semantics first):

class DaskCollection2(Protocol):

    @abc.abstractmethod
    def finalize_compute(self) -> DaskCollection2:
        ...

    @abc.abstractmethod
    def postpersist(self, futures: dict) -> DaskCollection2:
        ...

    def __dask_graph_factory__(self) -> TaskGraphFactory:
        ...

    # plus the user-facing methods: compute / persist / visualize / tokenize


class TaskGraphFactory(Protocol):
    @abc.abstractmethod
    def combine_factories(self) -> TaskGraphFactory:
        ...

    @abc.abstractmethod
    def materialize(self) -> dict:
        ...

    @abc.abstractmethod
    def optimize(self) -> TaskGraphFactory:
        ...

    @abc.abstractmethod
    def __dask_output_keys__(self) -> NestedKeys:
        ...

    @abc.abstractmethod
    def get_annotations(self) -> dict:
        # TODO: This is not working with expr
        ...


Graph = dict + output keys
  • As usual, a DaskCollection2 is something like a DataFrame or Array. It does not have any knowledge about keys but is an abstract class that exposes the user API. It also does not have a concept of a graph as such, i.e. there is no dask property or __dask_graph__ method, and it doesn't know anything about optimizations. This will never be submitted to the scheduler.
  • The DaskCollection2 has a __dask_graph_factory__ method that returns an object satisfying the TaskGraphFactory protocol, i.e. an object that can produce and manipulate task graphs (a minimal sketch of such a factory follows below). This is what we'll send to the scheduler. dask-expr expressions satisfy this protocol.
  • The task graph itself plus output keys. I'm spelling this out explicitly because I think this is also something that went wrong with HLGs, since an HLG alone is technically not even sufficient to compute anything; you also need the output keys. Both are produced by the factory.
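
For illustration only, a minimal dict-backed sketch of something satisfying the TaskGraphFactory protocol (the class name is made up; the PR's TaskFactoryHLGWrapper plays a similar role for HighLevelGraphs):

class DictGraphFactory:
    """A trivial factory wrapping an already materialized dict graph."""

    def __init__(self, dsk: dict, out_keys: list):
        self.dsk = dsk
        self.out_keys = out_keys

    def materialize(self) -> dict:
        return self.dsk

    def optimize(self) -> "DictGraphFactory":
        # nothing to optimize for a plain dict
        return self

    def __dask_output_keys__(self) -> list:
        return self.out_keys

    def get_annotations(self) -> dict:
        return {}

    def combine_factories(self, *others: "DictGraphFactory") -> "DictGraphFactory":
        dsk = dict(self.dsk)
        keys = list(self.out_keys)
        for other in others:
            dsk.update(other.materialize())
            keys.extend(other.__dask_output_keys__())
        return DictGraphFactory(dsk, keys)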

Comment on lines 3453 to 3471

Before:

dsk = collections_to_dsk(variables, optimize_graph, **kwargs)
names = ["finalize-%s" % tokenize(v) for v in variables]
dsk = dsk._hlg
dsk2 = {}
for i, (name, v) in enumerate(zip(names, variables)):
    func, extra_args = v.__dask_postcompute__()
    keys = v.__dask_keys__()
    if func is single_key and len(keys) == 1 and not extra_args:
        names[i] = keys[0]
    else:
        dsk2[name] = (func, keys) + extra_args

# Let's append the finalize graph to dsk
finalize_name = tokenize(names)
layers = {finalize_name: dsk2}
layers.update(dsk.layers)
dependencies = {finalize_name: set(dsk.layers.keys())}
dependencies.update(dsk.dependencies)
dsk = HighLevelGraph(layers, dependencies)

After:

dsk = self.collections_to_dsk(variables, optimize_graph, **kwargs)
names = ["finalize-%s" % tokenize(v) for v in variables]
dsk2 = {}
for i, (name, v) in enumerate(zip(names, variables)):
    func, extra_args = v.__dask_postcompute__()
    keys = v.__dask_keys__()
    if func is single_key and len(keys) == 1 and not extra_args:
        names[i] = keys[0]
    else:
        dsk2[name] = (func, keys) + extra_args

if not isinstance(dsk, HighLevelGraph):
    dsk = HighLevelGraph.from_collections(id(dsk), dsk, dependencies=())

# Let's append the finalize graph to dsk
finalize_name = tokenize(names)
layers = {finalize_name: dsk2}
layers.update(dsk.layers)
dependencies = {finalize_name: set(dsk.layers.keys())}
dependencies.update(dsk.dependencies)
dsk = TaskFactoryHLGWrapper(HighLevelGraph(layers, dependencies), out_keys=names)
@fjetter (Member Author):

The synchronous version is a little simpler and can be seen here https://github.com/dask/dask/pull/10369/files#r1432858069

Comment on lines +3460 to +3463
if func is single_key and len(keys) == 1 and not extra_args:
    names[i] = keys[0]
@fjetter (Member Author):

This is just a special case for Scalars that should not live here. With the new API, the task factory knows that it is a scalar and knows how to finalize itself.
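
A hedged sketch of what "knows how to finalize itself" could mean; the class and the placement of the method are hypothetical, not this PR's code:

class ScalarFactory:
    """Hypothetical factory whose graph produces exactly one output key."""

    def __init__(self, dsk: dict, key):
        self.dsk = dsk
        self.key = key

    def finalize_compute(self) -> "ScalarFactory":
        # a scalar already has exactly one output key, so no extra
        # "finalize-*" task needs to be appended to its graph
        return self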

Comment on lines +3581 to +3594
if newstyle_collections(collections):
    result = [var.postpersist(futures) for var in collections]
else:
    postpersists = [c.__dask_postpersist__() for c in collections]
    result = [
        func({k: futures[k] for k in flatten(c.__dask_keys__())}, *args)
        for (func, args), c in zip(postpersists, collections)
    ]
@fjetter (Member Author):

We could keep the old __dask_postpersist__ protocol. I think reversing the control for this API is not as powerful as for postcompute, but I would do it like this for symmetry reasons. The only caveat is that the new collection API technically doesn't know anything about keys, so we have to pass all futures to this API and trust the collection to do the right thing.
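
A rough sketch (names are hypothetical, not this PR's code) of what a new-style collection's postpersist could look like, with the collection filtering the full futures dict down to its own keys via its factory:

from dask.core import flatten


class MyCollection:
    """Hypothetical new-style collection backed by a graph factory."""

    def __init__(self, factory):
        self._factory = factory

    def postpersist(self, futures: dict) -> "MyCollection":
        # keep only the futures belonging to this collection's output keys
        keys = list(flatten(self._factory.__dask_output_keys__()))
        own = {k: futures[k] for k in keys}
        # rebuild around a factory whose graph is just the already computed futures;
        # DictGraphFactory is the illustrative class sketched earlier
        return MyCollection(DictGraphFactory(own, keys))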

@mrocklin (Member) commented:

Thanks for writing this up @fjetter. Comments inline.

class DaskCollection2(Protocol):

    @abc.abstractmethod
    def finalize_compute(self) -> DaskCollection2:
        ...

    @abc.abstractmethod
    def postpersist(self, futures: dict) -> DaskCollection2:
        ...

    # I'm mildly allergic to the word factory, probably from the days of
    # over-engineered OO systems.  I'd welcome a different name and am happy
    # to work on finding one if people like.
    def __dask_graph_factory__(self) -> TaskGraphFactory:
        ...

    # plus the user-facing methods: compute / persist / visualize / tokenize


class TaskGraphFactory(Protocol):
    # I assume that this takes in `*args` and returns something like a `Tuple`
    # object you've made before?
    @abc.abstractmethod
    def combine_factories(self) -> TaskGraphFactory:
        ...

    # This is maybe the current `__dask_graph__` protocol that we have?  If so,
    # my recommendation would be to keep that for continuity's sake
    @abc.abstractmethod
    def materialize(self) -> dict:
        ...

    @abc.abstractmethod
    def optimize(self) -> TaskGraphFactory:
        ...

    # Same as above, thoughts on keeping this `__dask_keys__`?
    @abc.abstractmethod
    def __dask_output_keys__(self) -> NestedKeys:
        ...

    @abc.abstractmethod
    def get_annotations(self) -> dict:
        # TODO: This is not working with expr
        ...


Graph = dict + output keys

@fjetter (Member Author) commented Dec 20, 2023

# I'm mildly allergic to the word factory, probably from the days of over-engineered OO systems.  I'd welcome a different name and am happy to work on finding one if people like.

Yeah... I agree. I chose the word since I figured it made it easy for us to talk about semantics. This is more or less a textbook definition of a factory so I figured it would help us take a step back from "graphs, mappings, expressions". Very open to iterate.

# I assume that this takes in `*args` and returns something like a `Tuple` object you've made before?

Yes. I'm also not certain whether this is necessary. It would be useful if we tried to move the old-style stuff to the new interface, but otherwise we can just be explicit with Tuple.

This is maybe the current __dask_graph__ protocol that we have? If so, my recommendation would be to keep that for continuity's sake

I actually tried to avoid this because I think it's been misused. I found it being called in many places with the expectation of receiving a HighLevelGraph.
If you insist, we can reuse it, but I wouldn't mind starting fresh.

# Same as above, thoughts on keeping this `__dask_keys__`?

I think that'd be fine

@mrocklin (Member) commented:

I'm curious to learn more about the __dask_graph__ thing. Where was this being misused?

@fjetter (Member Author) commented Dec 20, 2023

Well, for starters, I consider our current is_dask_collection check a misuse: https://github.com/dask/dask/blob/f6115511a09c309e02c2d40946b6ca74c6d148db/dask/base.py#L192-L216

That's perfectly fine for HLGs but not for an Expr if it generates the low level graph. The same is true for most of the code I am touching in dask/dask around collections_to_dsk. We will have to touch that code anyhow, so we can fix these places, but it is really hard to tell where this is happening elsewhere.
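
The problematic pattern is roughly the following (a paraphrase, not the verbatim dask.base code): the check answers "is this a collection?" by actually calling __dask_graph__, which is cheap for HLG-backed collections but forces an object that builds its graph eagerly in __dask_graph__ to materialize everything just to answer a yes/no question.

def is_dask_collection(x) -> bool:
    # paraphrased pattern, not the verbatim dask.base implementation
    try:
        # cheap for HLG-backed collections, but materializes the whole low
        # level graph if __dask_graph__ builds it eagerly (e.g. a dask-expr Expr)
        return x.__dask_graph__() is not None
    except (AttributeError, TypeError):
        return False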

I think what's bothering me the most is that the current semantics of __dask_graph__ are a little ill-defined because of this. Do I get a dictionary with runnable tasks? Do I get a HighLevelGraph? Judging from dask.typing I could even expect a Layer! (That thankfully never happens, but we have the Graph type defined as such, and I suggest throwing it out; see the deprecations in the dask/dask PR.)

Another thing is that __dask_graph__ is currently defined on the collection while in the future it will be defined on the TaskFactory (pending a new name). Reusing the same method on a different entity can cause all sorts of confusion (it did for me initially), especially if we continue to rely on duck typing.

Since the semantics change, I find a clean cut simpler for migration.

@fjetter (Member Author) commented Dec 20, 2023

FWIW, I also introduced a different method for __dask_keys__ because it is defined on a different entity: it's no longer the collection but the Expr / TaskBackend / Factory. I don't think collections should even know about keys. To avoid future misuse and confusion with duck typing, I chose a different method name.

@mrocklin (Member) commented:

Well, for starters, I consider our current is_dask_collection check a misuse

If most of the misuse is internal then I think it's also fine (and maybe better) for us to just correct the misuse.

My guess is that it's easier to implement things by making an entirely new method name, but that the end result might not be as good. The intention of __dask_graph__ is to return a classic Dask task graph. Implementation details like HLG, dict, etc. seem like crap that we've done to ourselves over the years. I think that the intention of __dask_graph__ is still what we want. We should just do a good job of it now.

I don't feel very strongly about this, and can be overruled. If we're going to add a new protocol for this, though, then let's certainly delete the old one. My guess is that once you've done all of the work and are in a place to delete the old thing, it might make more sense to rename back to __dask_graph__. My guess is that the desire to find a new name comes mostly from doing the implementation (it's easier to be sure about something when you write it fresh), but that after the implementation is done, reusing the old name for this concept will make more sense.

@fjetter (Member Author) commented Dec 21, 2023

As I said above, at this point I'm mostly concerned about semantics and not about names. We can rename everything in the end and I don't care strongly about the literal name of things. If there is no further feedback about the semantics, I will move forward with this and we can name it to whatever we want once the code settles.


One thing dask/dask#10676 reminded me of is that xarray requires some sort of dynamic way to set/unset whether an object is supposed to be considered a dask collection (@dcherian, is this statement correct?). If that is the case, I suggest introducing an additional method, __dask_is_collection__, that returns a bool and allows overriding the behavior of dask.is_dask_collection. This would also remove the need for the kind of inference in dask suggested in dask/dask#10676.
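
A hedged sketch of the suggested hook; the method does not exist yet, and the class here is a toy stand-in for whatever xarray wraps:

from dask.base import is_dask_collection


class Dataset:
    """Toy stand-in for xarray.Dataset; `variables` maps names to plain or dask-backed arrays."""

    def __init__(self, variables: dict):
        self.variables = variables

    def __dask_is_collection__(self) -> bool:
        # dask.is_dask_collection would consult this hook first instead of
        # inferring collection-ness from __dask_graph__ or similar attributes
        return any(is_dask_collection(v) for v in self.variables.values())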

@mrocklin (Member) commented:

As I said above, at this point I'm mostly concerned about semantics and not about names. We can rename everything in the end and I don't care strongly about the literal name of things. If there is no further feedback about the semantics, I will move forward with this and we can name it to whatever we want once the code settles.

Yup. Fine by me. It sounds like we're mostly aligned there.

@crusaderky (Collaborator) commented:

Some comments inline below

class DaskCollection2(Protocol):

    # I can't figure out what this method is supposed to do.
    # Could you make an example that isn't a trivial `return self`?
    # Also, since no end user will call this, I think it should be
    # renamed to `__dask_finalize_compute__`.
    @abc.abstractmethod
    def finalize_compute(self) -> DaskCollection2:
        ...

    # I feel strongly against this. I'm elaborating below.
    # Again, IMHO it should be renamed to `__dask_postpersist__`.
    @abc.abstractmethod
    def postpersist(self, futures: dict) -> DaskCollection2:
        ...

    # I suggest to change this to `__dask_graph__`, and allow it to return None
    # if this is not a dask collection. Explicitly stipulate that it must be trivially
    # cheap to call.
    # This means that this same method in legacy collections will return a Mapping[Key, Any],
    # while new-style collections will return a TaskGraphFactory, while xarray will return a
    # TaskGraphFactory | None depending on what it's wrapping around.
    def __dask_graph_factory__(self) -> TaskGraphFactory:
        ...

    # nit: compute/persist/visualize should not be part of the protocol. They should be part of a
    # convenience superclass that implements the protocol and should be trivial one-liners
    # `return dask.compute(self)`.
    # Again, IMHO tokenize should be renamed to `__dask_tokenize__`.
    # plus the user-facing methods: compute / persist / visualize / tokenize


# I'd prefer DaskGraphFactory, chiefly for branding reasons. Also we may not stick to the concept
# of tasks forever. Tasks as a founding design principle have already become quite shaky with p2p shuffle.
class TaskGraphFactory(Protocol):

About my problem with postpersist:
Let's take as an example an xarray.Dataset that wraps multiple dask_expr.array.Array objects.

class Dataset:
    def __dask_graph_factory__(self) -> TaskGraphFactory | None:
        factories = [v.__dask_graph_factory__() for v in self.variables.values() if is_dask_collection(v)]
        if not factories:
            return None
        elif len(factories) == 1:
            return factories[0]
        else:
            return factories[0].combine_factories(*factories[1:])

So far so good. However, xarray.Dataset no longer knows anything about dask keys.
So it won't know what to do when dask calls its postpersist(self, futures: dict) method with a bunch of dask keys, as it has no idea how to split them back and call its underlying dask_expr.array.Array.postpersist() methods.
It could call v.__dask_graph_factory__().__dask_output_keys__() but in dask_expr I suspect that may trigger materialization too?
For obvious sanity reasons, I would not want to stipulate that postpersist() should quietly ignore futures it knows nothing about, either.

@crusaderky (Collaborator) commented:

Figured out finalize_compute. We should get rid of it. My comment here: https://github.com/dask-contrib/dask-expr/pull/294/files#r1435000385

@fjetter force-pushed the dask_graph_protocol_wout_materialization branch from 7de6139 to 99d8c4e on January 16, 2024.