diff --git a/distributed/client.py b/distributed/client.py index 46982203530..29aa264f32d 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -4,6 +4,7 @@ import atexit import copy import inspect +import itertools import json import logging import os @@ -31,7 +32,7 @@ from tlz import first, groupby, merge, partition_all, valmap import dask -from dask.base import collections_to_dsk, normalize_token, tokenize +from dask.base import collections_to_dsk, tokenize from dask.core import flatten, validate_key from dask.highlevelgraph import HighLevelGraph from dask.optimization import SubgraphCallable @@ -210,11 +211,15 @@ class Future(WrappedKey): _cb_executor = None _cb_executor_pid = None + _counter = itertools.count() + # Make sure this stays unique even across multiple processes or hosts + _uid = uuid.uuid4().hex - def __init__(self, key, client=None, inform=True, state=None): + def __init__(self, key, client=None, inform=True, state=None, _id=None): self.key = key self._cleared = False self._client = client + self._id = _id or (Future._uid, next(Future._counter)) self._input_state = state self._inform = inform self._state = None @@ -499,8 +504,16 @@ def release(self): except TypeError: # pragma: no cover pass # Shutting down, add_callback may be None + @staticmethod + def make_future(key, id): + # Can't use kwargs in pickle __reduce__ methods + return Future(key=key, _id=id) + def __reduce__(self) -> str | tuple[Any, ...]: - return Future, (self.key,) + return Future.make_future, (self.key, self._id) + + def __dask_tokenize__(self): + return (type(self).__name__, self.key, self._id) def __del__(self): try: @@ -643,18 +656,6 @@ async def done_callback(future, callback): callback(future) -@partial(normalize_token.register, Future) -def normalize_future(f): - """Returns the key and the type as a list - - Parameters - ---------- - list - The key and the type - """ - return [f.key, type(f)] - - class AllExit(Exception): """Custom exception class to exit All(...) early.""" @@ -3434,9 +3435,11 @@ def compute( if traverse: collections = tuple( - dask.delayed(a) - if isinstance(a, (list, set, tuple, dict, Iterator)) - else a + ( + dask.delayed(a) + if isinstance(a, (list, set, tuple, dict, Iterator)) + else a + ) for a in collections ) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index a214bfae840..b30b681d38a 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -863,11 +863,13 @@ async def test_tokenize_on_futures(c, s, a, b): y = c.submit(inc, 1) tok = tokenize(x) assert tokenize(x) == tokenize(x) - assert tokenize(x) == tokenize(y) + # Tokens must be unique per instance + # See https://github.com/dask/distributed/issues/8561 + assert tokenize(x) != tokenize(y) c.futures[x.key].finish() - assert tok == tokenize(y) + assert tok != tokenize(y) @pytest.mark.skipif(not LINUX, reason="Need 127.0.0.2 to mean localhost")