diff --git a/distributed/client.py b/distributed/client.py index ddd0359bd49..0b4214f2304 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -33,10 +33,9 @@ import dask from dask.base import collections_to_dsk, normalize_token, tokenize from dask.core import flatten, validate_key -from dask.delayed import Delayed from dask.highlevelgraph import HighLevelGraph from dask.optimization import SubgraphCallable -from dask.typing import DaskGraph, Key, NewDaskCollection +from dask.typing import DaskGraph, Key, all_collections_newstyle from dask.utils import ( apply, ensure_dict, @@ -3437,23 +3436,7 @@ def compute( else a for a in collections ) - - new_style = False - filtered_collections = [ - a - for a in collections - if dask.is_dask_collection(a) - # Delayed always returns True when checked if instance of a Protocol - and not isinstance(a, Delayed) - ] - if filtered_collections and all( - isinstance(a, NewDaskCollection) for a in filtered_collections - ): - new_style = True - elif any(isinstance(a, NewDaskCollection) for a in filtered_collections): - raise RuntimeError( - "Provided multiple collections to dask.compute but mixed old and new style collections." - ) + new_style = all_collections_newstyle(collections) if new_style: futures = [] for collection in collections: @@ -3606,21 +3589,7 @@ def persist( assert all(map(dask.is_dask_collection, collections)) - new_style = False - if all( - isinstance(a, NewDaskCollection) - for a in collections - if dask.is_dask_collection(a) and not isinstance(a, Delayed) - ): - new_style = True - elif any( - isinstance(a, NewDaskCollection) - for a in collections - if dask.is_dask_collection(a) - ): - raise RuntimeError( - "Provided multiple collections to dask.compute but mixed old and new style collections." - ) + new_style = all_collections_newstyle(collections) if new_style: result = []