Skip to content

Remove recursion in task spec (#8920) #2504

Remove recursion in task spec (#8920)

Remove recursion in task spec (#8920) #2504

GitHub Actions / Unit Test Results failed Nov 20, 2024 in 0s

46 fail, 110 skipped, 3 974 pass in 10h 12m 37s

    25 files      25 suites   10h 12m 37s ⏱️
 4 130 tests  3 974 ✅   110 💤  46 ❌
47 692 runs  45 130 ✅ 2 121 💤 441 ❌

Results for commit 750cb91.

Annotations

Check warning on line 0 in distributed.tests.test_client

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

11 out of 12 runs failed: test_persist_async (distributed.tests.test_client)

artifacts/macos-latest-3.12-default-ci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-ci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-ci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-ci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-ci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-ci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-ci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-ci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-ci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-ci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-ci1/pytest.xml [took 0s]
Raw output
IndexError: tuple index out of range
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:39653', workers: 0, cores: 0, tasks: 0>
a = <Worker 'tcp://127.0.0.1:37647', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>
b = <Worker 'tcp://127.0.0.1:44433', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>

    @gen_cluster(client=True)
    async def test_persist_async(c, s, a, b):
        pytest.importorskip("numpy")
        da = pytest.importorskip("dask.array")
    
        x = da.ones((10, 10), chunks=(5, 10))
        y = 2 * (x + 1)
        assert len(y.dask) == 6
        yy = c.persist(y)
    
        assert len(y.dask) == 6
        assert len(yy.dask) == 2
        assert all(isinstance(v, Future) for v in yy.dask.values())
        assert yy.__dask_keys__() == y.__dask_keys__()
    
        g, h = c.compute([y, yy])
    
>       gg, hh = await c.gather([g, h])

distributed/tests/test_client.py:2565: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
distributed/client.py:2427: in _gather
    raise exception.with_traceback(traceback)
../../../miniconda3/envs/dask-distributed/lib/python3.10/site-packages/dask/array/core.py:1328: in finalize
    return concatenate3(results)
../../../miniconda3/envs/dask-distributed/lib/python3.10/site-packages/dask/array/core.py:5496: in concatenate3
    chunks = chunks_from_arrays(arrays)
../../../miniconda3/envs/dask-distributed/lib/python3.10/site-packages/dask/array/core.py:5281: in chunks_from_arrays
    result.append(tuple(shape(deepfirst(a))[dim] for a in arrays))
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import contextlib
    import math
    import operator
    import os
    import pickle
    import re
    import sys
    import traceback
    import uuid
    import warnings
    from bisect import bisect
    from collections import defaultdict
    from collections.abc import Collection, Iterable, Iterator, Mapping, Sequence
    from functools import lru_cache, partial, reduce, wraps
    from itertools import product, zip_longest
    from numbers import Integral, Number
    from operator import add, mul
    from threading import Lock
    from typing import Any, Literal, TypeVar, Union, cast
    
    import numpy as np
    from numpy.typing import ArrayLike
    from packaging.version import Version
    from tlz import accumulate, concat, first, groupby, partition
    from tlz.curried import pluck
    from toolz import frequencies
    
    from dask import compute, config, core
    from dask.array import chunk
    from dask.array.chunk import getitem
    from dask.array.chunk_types import is_valid_array_chunk, is_valid_chunk_type
    
    # Keep einsum_lookup and tensordot_lookup here for backwards compatibility
    from dask.array.dispatch import (  # noqa: F401
        concatenate_lookup,
        einsum_lookup,
        tensordot_lookup,
    )
    from dask.array.numpy_compat import NUMPY_GE_200, _Recurser
    from dask.array.slicing import replace_ellipsis, setitem_array, slice_array
    from dask.array.utils import compute_meta, meta_from_array
    from dask.base import (
        DaskMethodsMixin,
        compute_as_if_collection,
        dont_optimize,
        is_dask_collection,
        named_schedulers,
        persist,
        tokenize,
    )
    from dask.blockwise import blockwise as core_blockwise
    from dask.blockwise import broadcast_dimensions
    from dask.context import globalmethod
    from dask.core import quote
    from dask.delayed import Delayed, delayed
    from dask.highlevelgraph import HighLevelGraph, MaterializedLayer
    from dask.layers import ArraySliceDep, reshapelist
    from dask.sizeof import sizeof
    from dask.typing import Graph, Key, NestedKeys
    from dask.utils import (
        IndexCallable,
        SerializableLock,
        cached_cumsum,
        cached_property,
        concrete,
        derived_from,
        format_bytes,
        funcname,
        has_keyword,
        is_arraylike,
        is_dataframe_like,
        is_index_like,
        is_integer,
        is_series_like,
        maybe_pluralize,
        ndeepmap,
        ndimlist,
        parse_bytes,
        typename,
    )
    from dask.widgets import get_template
    
    # Type of a single chunk length: an ``int`` when known, otherwise a float
    # (per the note on this line, intended to be NaN for unknown sizes).
    T_IntOrNaN = Union[int, float]  # Should be Union[int, Literal[np.nan]]
    
    # Default scheduler: the "threads" entry of ``named_schedulers`` if present,
    # otherwise the "sync" entry. Note the "sync" lookup is evaluated eagerly,
    # so that key must exist even when "threads" is available.
    DEFAULT_GET = named_schedulers.get("threads", named_schedulers["sync"])
    
    # Help text pointing users at docs and at ways to compute chunk sizes;
    # presumably appended to errors about unknown chunk sizes -- confirm at
    # the call sites (not visible here).
    unknown_chunk_message = (
        "\n\n"
        "A possible solution: "
        "https://docs.dask.org/en/latest/array-chunks.html#unknown-chunks\n"
        "Summary: to compute chunks sizes, use\n\n"
        "   x.compute_chunk_sizes()  # for Dask Array `x`\n"
        "   ddf.to_dask_array(lengths=True)  # for Dask DataFrame `ddf`"
    )
    
    
    class PerformanceWarning(Warning):
        """Warning category used to flag chunking that is likely to be slow."""
    
    
    def getter(a, b, asarray=True, lock=None):
        """Return ``a[b]``, optionally holding ``lock`` and coercing to ndarray.

        ``None`` entries of a tuple index ``b`` are removed before ``a`` is
        indexed. They are re-applied afterwards as new axes on the result.
        Integer entries are left out of that second index because they have
        already removed their dimension.

        If ``lock`` is truthy it is acquired before indexing and released in a
        ``finally`` clause. When ``asarray`` is true, results that are not
        array-like are converted with ``np.asarray``. So are ``np.matrix``
        results, which would otherwise pass as array-like.
        """
        if isinstance(b, tuple) and any(item is None for item in b):
            inner_index = tuple(item for item in b if item is not None)
            outer_index = tuple(
                slice(None, None) if item is not None else None
                for item in b
                if not isinstance(item, Integral)
            )
            sliced = getter(a, inner_index, asarray=asarray, lock=lock)
            return sliced[outer_index]

        if lock:
            lock.acquire()
        try:
            result = a[b]
            # ``np.matrix`` counts as array-like, but historically ``getter``
            # returns a plain ``np.ndarray`` for it, so convert it explicitly.
            if asarray and (
                not is_arraylike(result) or isinstance(result, np.matrix)
            ):
                result = np.asarray(result)
        finally:
            if lock:
                lock.release()
        return result
    
    
    def getter_nofancy(a, b, asarray=True, lock=None):
        """Slice ``a[b]`` exactly like ``getter``.

        It exists as a distinct function so optimization passes can tell that
        the underlying backend does not support fancy indexing.
        """
        return getter(a, b, asarray, lock)
    
    
    def getter_inline(a, b, asarray=True, lock=None):
        """Slice ``a[b]`` exactly like ``getter``; optimizations may inline it.

        Slicing tasks that use this function are ones optimization passes are
        willing to inline into the tasks that consume them, as in this rewrite:

        **Before**

        >>> a = x[:10]  # doctest: +SKIP
        >>> b = a + 1  # doctest: +SKIP
        >>> c = a * 2  # doctest: +SKIP

        **After**

        >>> b = x[:10] + 1  # doctest: +SKIP
        >>> c = x[:10] * 2  # doctest: +SKIP

        Such inlining can matter for operations that read their data from disk.
        """
        return getter(a, b, asarray, lock)
    
    
    from dask.array.optimization import fuse_slice, optimize
    
    # __array_function__ dict for mapping aliases and mismatching names:
    # maps a NumPy function to the dask.array implementation registered for it
    # through the ``implements`` decorator.
    _HANDLED_FUNCTIONS = {}
    
    
    def implements(*numpy_functions):
        """Build a decorator that registers a dask function for NumPy functions.

        The decorated dask function is recorded in ``_HANDLED_FUNCTIONS`` as
        the ``__array_function__`` implementation of every given NumPy function
        (several of them when NumPy exposes aliases). The dask function itself
        is returned unchanged.

        Parameters
        ----------
        \\*numpy_functions : callables
            One or more NumPy functions that are handled by ``__array_function__``
            and will be mapped by `implements` to a `dask.array` function.
        """

        def register(dask_func):
            _HANDLED_FUNCTIONS.update(
                (numpy_function, dask_func) for numpy_function in numpy_functions
            )
            return dask_func

        return register
    
    
    def _should_delegate(self, other) -> bool:
        """Check whether Dask should delegate to the other.
        This implementation follows NEP-13:
        https://numpy.org/neps/nep-0013-ufunc-overrides.html#behavior-in-combination-with-python-s-binary-operations
        """
        if hasattr(other, "__array_ufunc__") and other.__array_ufunc__ is None:
            return True
        elif (
            hasattr(other, "__array_ufunc__")
            and not is_valid_array_chunk(other)
            # don't delegate to our own parent classes
            and not isinstance(self, type(other))
            and type(self) is not type(other)
        ):
            return True
        return False
    
    
    def check_if_handled_given_other(f):
        """Wrap a binary dunder method so it can defer to the other operand.

        The returned method gives back ``NotImplemented`` whenever
        ``_should_delegate(self, other)`` is true, so Python can try the other
        operand's implementation. Otherwise it simply calls ``f(self, other)``.
        This defers properly to upcast types without assuming that unknown
        types are downcast types.
        """

        @wraps(f)
        def wrapper(self, other):
            delegate = _should_delegate(self, other)
            return NotImplemented if delegate else f(self, other)

        return wrapper
    
    
    def slices_from_chunks(chunks):
        """Return one tuple of slices per block, in ``itertools.product`` order

        Each axis' block sizes are turned into consecutive slices (starting at
        0), and the per-axis slice lists are combined with a Cartesian product.

        >>> slices_from_chunks(((2, 2), (3, 3, 3)))  # doctest: +NORMALIZE_WHITESPACE
         [(slice(0, 2, None), slice(0, 3, None)),
          (slice(0, 2, None), slice(3, 6, None)),
          (slice(0, 2, None), slice(6, 9, None)),
          (slice(2, 4, None), slice(0, 3, None)),
          (slice(2, 4, None), slice(3, 6, None)),
          (slice(2, 4, None), slice(6, 9, None))]
        """
        offsets = [cached_cumsum(sizes, initial_zero=True) for sizes in chunks]
        axis_slices = []
        for starts, sizes in zip(offsets, chunks):
            axis_slices.append(
                [slice(begin, begin + width) for begin, width in zip(starts, sizes)]
            )
        return list(product(*axis_slices))
    
    
    def graph_from_arraylike(
        arr,  # Any array-like which supports slicing
        chunks,
        shape,
        name,
        getitem=getter,
        lock=False,
        asarray=True,
        dtype=None,
        inline_array=False,
    ) -> HighLevelGraph:
        """
        Build a HighLevelGraph whose tasks slice chunks out of an array-like.

        ``chunks`` is normalized against ``shape`` (and ``dtype``) first.

        With ``inline_array=True`` the graph is a single Blockwise layer in
        which every slicing task embeds ``arr`` itself.

        With ``inline_array=False`` ``arr`` is stored once, under the key
        ``"original-" + name`` in a MaterializedLayer. The Blockwise layer of
        slicing tasks then refers to that key and depends on that layer.

        >>> dict(graph_from_arraylike(arr, chunks=(2, 3), shape=(4, 6), name="X", inline_array=True))  # doctest: +SKIP
        {('X', 0, 0): (getter, arr, (slice(0, 2), slice(0, 3))),
         ('X', 1, 0): (getter, arr, (slice(2, 4), slice(0, 3))),
         ('X', 1, 1): (getter, arr, (slice(2, 4), slice(3, 6))),
         ('X', 0, 1): (getter, arr, (slice(0, 2), slice(3, 6)))}

        >>> dict(  # doctest: +SKIP
        ...     graph_from_arraylike(arr, chunks=((2, 2), (3, 3)), shape=(4, 6), name="X", inline_array=False)
        ... )
        {"original-X": arr,
         ('X', 0, 0): (getter, 'original-X', (slice(0, 2), slice(0, 3))),
         ('X', 1, 0): (getter, 'original-X', (slice(2, 4), slice(0, 3))),
         ('X', 1, 1): (getter, 'original-X', (slice(2, 4), slice(3, 6))),
         ('X', 0, 1): (getter, 'original-X', (slice(0, 2), slice(3, 6)))}
        """
        chunks = normalize_chunks(chunks, shape, dtype=dtype)
        out_ind = tuple(range(len(shape)))

        # Forward ``asarray``/``lock`` only if ``getitem`` accepts both and at
        # least one is non-default (asarray falsy or lock truthy); the common
        # case keeps the tasks free of extra parameters.
        forward_extra = (
            has_keyword(getitem, "asarray")
            and has_keyword(getitem, "lock")
            and (not asarray or lock)
        )
        kwargs = {"asarray": asarray, "lock": lock} if forward_extra else {}

        def _slicing_layer(source):
            # Blockwise layer of ``getitem(source, <block slice>)`` tasks.
            return core_blockwise(
                getitem,
                name,
                out_ind,
                source,
                None,
                ArraySliceDep(chunks),
                out_ind,
                numblocks={},
                **kwargs,
            )

        if inline_array:
            return HighLevelGraph.from_collections(name, _slicing_layer(arr))

        original_name = "original-" + name
        layers = {
            original_name: MaterializedLayer({original_name: arr}),
            name: _slicing_layer(original_name),
        }
        dependencies = {
            original_name: set(),
            name: {original_name},
        }
        return HighLevelGraph(layers, dependencies)
    
    
    def dotmany(A, B, leftfunc=None, rightfunc=None, **kwargs):
        """Sum of ``np.dot`` over pairs of aligned chunks from ``A`` and ``B``

        >>> x = np.array([[1, 2], [1, 2]])
        >>> y = np.array([[10, 20], [10, 20]])
        >>> dotmany([x, x, x], [y, y, y])
        array([[ 90, 180],
               [ 90, 180]])

        ``leftfunc``/``rightfunc`` (if given) are applied to each left/right
        chunk first; extra keyword arguments are forwarded to ``np.dot``.

        >>> dotmany([x, x, x], [y, y, y], rightfunc=np.transpose)
        array([[150, 150],
               [150, 150]])
        """
        lhs = map(leftfunc, A) if leftfunc else A
        rhs = map(rightfunc, B) if rightfunc else B
        dot = partial(np.dot, **kwargs)
        return sum(dot(left, right) for left, right in zip(lhs, rhs))
    
    
    def _concatenate2(arrays, axes=None):
        """Recursively concatenate nested lists of arrays along axes
    
        Each entry in axes corresponds to each level of the nested list.  The
        length of axes should correspond to the level of nesting of arrays.
        If axes is an empty list or tuple, return arrays, or arrays[0] if
        arrays is a list.
    
        >>> x = np.array([[1, 2], [3, 4]])
        >>> _concatenate2([x, x], axes=[0])
        array([[1, 2],
               [3, 4],
               [1, 2],
               [3, 4]])
    
        >>> _concatenate2([x, x], axes=[1])
        array([[1, 2, 1, 2],
               [3, 4, 3, 4]])
    
        >>> _concatenate2([[x, x], [x, x]], axes=[0, 1])
        array([[1, 2, 1, 2],
               [3, 4, 3, 4],
               [1, 2, 1, 2],
               [3, 4, 3, 4]])
    
        Supports Iterators
        >>> _concatenate2(iter([x, x]), axes=[1])
        array([[1, 2, 1, 2],
               [3, 4, 3, 4]])
    
        Special Case
        >>> _concatenate2([x, x], axes=())
        array([[1, 2],
               [3, 4]])
        """
        if axes is None:
            axes = []
    
        if axes == ():
            if isinstance(arrays, list):
                return arrays[0]
            else:
                return arrays
    
        if isinstance(arrays, Iterator):
            arrays = list(arrays)
        if not isinstance(arrays, (list, tuple)):
            return arrays
        if len(axes) > 1:
            arrays = [_concatenate2(a, axes=axes[1:]) for a in arrays]
        concatenate = concatenate_lookup.dispatch(
            type(max(arrays, key=lambda x: getattr(x, "__array_priority__", 0)))
        )
        if isinstance(arrays[0], dict):
            # Handle concatenation of `dict`s, used as a replacement for structured
            # arrays when that's not supported by the array library (e.g., CuPy).
            keys = list(arrays[0].keys())
            assert all(list(a.keys()) == keys for a in arrays)
            ret = dict()
            for k in keys:
                ret[k] = concatenate(list(a[k] for a in arrays), axis=axes[0])
            return ret
        else:
            return concatenate(arrays, axis=axes[0])
    
    
    def apply_infer_dtype(func, args, kwargs, funcname, suggest_dtype="dtype", nout=None):
        """
        Infer the output dtype(s) of ``func`` by calling it on tiny stand-in inputs.

        Every array-like argument is replaced by an all-ones array of the same
        array type and dtype with length 1 along each of its dimensions. Other
        arguments are passed through as-is.

        Parameters
        ----------
        func: Callable
            Function whose output dtype is to be determined

        args: List of array like
            Arguments that would normally be passed to ``func``. For array-likes
            only ``ndim`` and ``dtype`` (and the array type) matter.

        kwargs: dict
            Additional keyword arguments forwarded to ``func``

        funcname: String
            Name of the calling function, used in the error message

        suggest_dtype: None/False or String
            If truthy, the error message suggests passing the dtype explicitly
            through the keyword argument with this name. Defaults to ``'dtype'``.

        nout: None or Int
            ``None`` when ``func`` returns a single output, an integer when it
            returns several. Defaults to ``None``.

        Returns
        -------
        : dtype or tuple of dtype
            For a single output its ``dtype`` attribute, or its ``type`` when it
            has none. For several outputs a tuple of their ``dtype`` attributes.

        Raises
        ------
        ValueError
            If calling ``func`` raises; the message includes the original error
            and its traceback.
        """
        from dask.array.utils import meta_from_array

        sample_args = []
        for arg in args:
            if is_arraylike(arg):
                arg = np.ones_like(
                    meta_from_array(arg), shape=((1,) * arg.ndim), dtype=arg.dtype
                )
            sample_args.append(arg)

        failure_message = None
        try:
            with np.errstate(all="ignore"):
                result = func(*sample_args, **kwargs)
        except Exception as exc:
            tb = "".join(traceback.format_tb(sys.exc_info()[2]))
            if suggest_dtype:
                suggest = (
                    "Please specify the dtype explicitly using the "
                    f"`{suggest_dtype}` kwarg.\n\n"
                )
            else:
                suggest = ""
            failure_message = (
                f"`dtype` inference failed in `{funcname}`.\n\n"
                f"{suggest}"
                "Original error is below:\n"
                "------------------------\n"
                f"{exc!r}\n\n"
                "Traceback:\n"
                "---------\n"
                f"{tb}"
            )
        # Raised outside the ``except`` block so the original exception is not
        # attached as context; its details are already embedded in the message.
        if failure_message is not None:
            raise ValueError(failure_message)

        if nout is None:
            return getattr(result, "dtype", type(result))
        return tuple(output.dtype for output in result)
    
    
    def normalize_arg(x):
        """Normalize user provided arguments to blockwise or map_blocks

        Dask collections are returned unchanged. The following are wrapped
        with ``dask.delayed`` so that they enter the graph as separate objects:

        1.  Strings matching ``_\\d+`` at their start, which might otherwise
            collide with blockwise tokens.
        2.  Lists with 10 or more elements.
        3.  Objects larger than 1e6 bytes according to ``sizeof``.

        Anything else is returned unchanged.
        """
        if is_dask_collection(x):
            return x
        needs_delay = (
            (isinstance(x, str) and re.match(r"_\d+", x))
            or (isinstance(x, list) and len(x) >= 10)
            or sizeof(x) > 1e6
        )
        return delayed(x) if needs_delay else x
    
    
    def _pass_extra_kwargs(func, keys, *args, **kwargs):
        """Helper for :func:`dask.array.map_blocks` to pass `block_info` or `block_id`.
    
        For each element of `keys`, a corresponding element of args is changed
        to a keyword argument with that key, before all arguments re passed on
        to `func`.
        """
        kwargs.update(zip(keys, args))
        return func(*args[len(keys) :], **kwargs)
    
    
    def map_blocks(
        func,
        *args,
        name=None,
        token=None,
        dtype=None,
        chunks=None,
        drop_axis=None,
        new_axis=None,
        enforce_ndim=False,
        meta=None,
        **kwargs,
    ):
        """Map a function across all blocks of a dask array.
    
        Note that ``map_blocks`` will attempt to automatically determine the output
        array type by calling ``func`` on 0-d versions of the inputs. Please refer to
        the ``meta`` keyword argument below if you expect that the function will not
        succeed when operating on 0-d arrays.
    
        Parameters
        ----------
        func : callable
            Function to apply to every block in the array.
            If ``func`` accepts ``block_info=`` or ``block_id=``
            as keyword arguments, these will be passed dictionaries
            containing information about input and output chunks/arrays
            during computation. See examples for details.
        args : dask arrays or other objects
        dtype : np.dtype, optional
            The ``dtype`` of the output array. It is recommended to provide this.
            If not provided, will be inferred by applying the function to a small
            set of fake data.
        chunks : tuple, optional
            Chunk shape of resulting blocks if the function does not preserve
            shape. If not provided, the resulting array is assumed to have the same
            block structure as the first input array.
        drop_axis : number or iterable, optional
            Dimensions lost by the function.
        new_axis : number or iterable, optional
            New dimensions created by the function. Note that these are applied
            after ``drop_axis`` (if present). The size of each chunk along this
            dimension will be set to 1. Please specify ``chunks`` if the individual
            chunks have a different size.
        enforce_ndim : bool, default False
            Whether to enforce at runtime that the dimensionality of the array
            produced by ``func`` actually matches that of the array returned by
            ``map_blocks``.
            If True, this will raise an error when there is a mismatch.
        token : string, optional
            The key prefix to use for the output array. If not provided, will be
            determined from the function name.
        name : string, optional
            The key name to use for the output array. Note that this fully
            specifies the output key name, and must be unique. If not provided,
            will be determined by a hash of the arguments.
        meta : array-like, optional
            The ``meta`` of the output array, when specified is expected to be an
            array of the same type and dtype of that returned when calling ``.compute()``
            on the array returned by this function. When not provided, ``meta`` will be
            inferred by applying the function to a small set of fake data, usually a
            0-d array. It's important to ensure that ``func`` can successfully complete
            computation without raising exceptions when 0-d is passed to it, providing
            ``meta`` will be required otherwise. If the output type is known beforehand
            (e.g., ``np.ndarray``, ``cupy.ndarray``), an empty array of such type dtype
            can be passed, for example: ``meta=np.array((), dtype=np.int32)``.
        **kwargs :
            Other keyword arguments to pass to function. Values must be constants
            (not dask.arrays)
    
        See Also
        --------
        dask.array.map_overlap : Generalized operation with overlap between neighbors.
        dask.array.blockwise : Generalized operation with control over block alignment.
    
        Examples
        --------
        >>> import dask.array as da
        >>> x = da.arange(6, chunks=3)
    
        >>> x.map_blocks(lambda x: x * 2).compute()
        array([ 0,  2,  4,  6,  8, 10])
    
        The ``da.map_blocks`` function can also accept multiple arrays.
    
        >>> d = da.arange(5, chunks=2)
        >>> e = da.arange(5, chunks=2)
    
        >>> f = da.map_blocks(lambda a, b: a + b**2, d, e)
        >>> f.compute()
        array([ 0,  2,  6, 12, 20])
    
        If the function changes shape of the blocks then you must provide chunks
        explicitly.
    
        >>> y = x.map_blocks(lambda x: x[::2], chunks=((2, 2),))
    
        You have a bit of freedom in specifying chunks.  If all of the output chunk
        sizes are the same, you can provide just that chunk size as a single tuple.
    
        >>> a = da.arange(18, chunks=(6,))
        >>> b = a.map_blocks(lambda x: x[:3], chunks=(3,))
    
        If the function changes the dimension of the blocks you must specify the
        created or destroyed dimensions.
    
        >>> b = a.map_blocks(lambda x: x[None, :, None], chunks=(1, 6, 1),
        ...                  new_axis=[0, 2])
    
        If ``chunks`` is specified but ``new_axis`` is not, then it is inferred to
        add the necessary number of axes on the left.
    
        Note that ``map_blocks()`` will concatenate chunks along axes specified by
        the keyword parameter ``drop_axis`` prior to applying the function.
        This is illustrated in the figure below:
    
        .. image:: /images/map_blocks_drop_axis.png
    
        Due to memory-size-constraints, it is often not advisable to use ``drop_axis``
        on an axis that is chunked.  In that case, it is better not to use
        ``map_blocks`` but rather
        ``dask.array.reduction(..., axis=dropped_axes, concatenate=False)`` which
        maintains a leaner memory footprint while it drops any axis.
    
        Map_blocks aligns blocks by block positions without regard to shape. In the
        following example we have two arrays with the same number of blocks but
        with different shape and chunk sizes.
    
        >>> x = da.arange(1000, chunks=(100,))
        >>> y = da.arange(100, chunks=(10,))
    
        The relevant attribute to match is numblocks.
    
        >>> x.numblocks
        (10,)
        >>> y.numblocks
        (10,)
    
        If these match (up to broadcasting rules) then we can map arbitrary
        functions across blocks
    
        >>> def func(a, b):
        ...     return np.array([a.max(), b.max()])
    
        >>> da.map_blocks(func, x, y, chunks=(2,), dtype='i8')
        dask.array<func, shape=(20,), dtype=int64, chunksize=(2,), chunktype=numpy.ndarray>
    
        >>> _.compute()
        array([ 99,   9, 199,  19, 299,  29, 399,  39, 499,  49, 599,  59, 699,
                69, 799,  79, 899,  89, 999,  99])
    
        Your block function can get information about where it is in the array by
        accepting a special ``block_info`` or ``block_id`` keyword argument.
        During computation, they will contain information about each of the input
        and output chunks (and dask arrays) relevant to each call of ``func``.
    
        >>> def func(block_info=None):
        ...     pass
    
        This will receive the following information:
    
        >>> block_info  # doctest: +SKIP
        {0: {'shape': (1000,),
             'num-chunks': (10,),
             'chunk-location': (4,),
             'array-location': [(400, 500)]},
         None: {'shape': (1000,),
                'num-chunks': (10,),
                'chunk-location': (4,),
                'array-location': [(400, 500)],
                'chunk-shape': (100,),
                'dtype': dtype('float64')}}
    
        The keys to the ``block_info`` dictionary indicate which is the input and
        output Dask array:
    
        - **Input Dask array(s):** ``block_info[0]`` refers to the first input Dask array.
          The dictionary key is ``0`` because that is the argument index corresponding
          to the first input Dask array.
          In cases where multiple Dask arrays have been passed as input to the function,
          you can access them with the number corresponding to the input argument,
          eg: ``block_info[1]``, ``block_info[2]``, etc.
          (Note that if you pass multiple Dask arrays as input to map_blocks,
          the arrays must match each other by having matching numbers of chunks,
          along corresponding dimensions up to broadcasting rules.)
        - **Output Dask array:** ``block_info[None]`` refers to the output Dask array,
          and contains information about the output chunks.
          The output chunk shape and dtype may may be different than the input chunks.
    
        For each dask array, ``block_info`` describes:
    
        - ``shape``: the shape of the full Dask array,
        - ``num-chunks``: the number of chunks of the full array in each dimension,
        - ``chunk-location``: the chunk location (for example the fourth chunk over
          in the first dimension), and
        - ``array-location``: the array location within the full Dask array
          (for example the slice corresponding to ``40:50``).
    
        In addition to these, there are two extra parameters described by
        ``block_info`` for the output array (in ``block_info[None]``):
    
        - ``chunk-shape``: the output chunk shape, and
        - ``dtype``: the output dtype.
    
        These features can be combined to synthesize an array from scratch, for
        example:
    
        >>> def func(block_info=None):
        ...     loc = block_info[None]['array-location'][0]
        ...     return np.arange(loc[0], loc[1])
    
        >>> da.map_blocks(func, chunks=((4, 4),), dtype=np.float64)
        dask.array<func, shape=(8,), dtype=float64, chunksize=(4,), chunktype=numpy.ndarray>
    
        >>> _.compute()
        array([0, 1, 2, 3, 4, 5, 6, 7])
    
        ``block_id`` is similar to ``block_info`` but contains only the ``chunk_location``:
    
        >>> def func(block_id=None):
        ...     pass
    
        This will receive the following information:
    
        >>> block_id  # doctest: +SKIP
        (4, 3)
    
        You may specify the key name prefix of the resulting task in the graph with
        the optional ``token`` keyword argument.
    
        >>> x.map_blocks(lambda x: x + 1, name='increment')
        dask.array<increment, shape=(1000,), dtype=int64, chunksize=(100,), chunktype=numpy.ndarray>
    
        For functions that may not handle 0-d arrays, it's also possible to specify
        ``meta`` with an empty array matching the type of the expected result. In
        the example below, ``func`` will result in an ``IndexError`` when computing
        ``meta``:
    
        >>> rng = da.random.default_rng()
        >>> da.map_blocks(lambda x: x[2], rng.random(5), meta=np.array(()))
        dask.array<lambda, shape=(5,), dtype=float64, chunksize=(5,), chunktype=numpy.ndarray>
    
        Similarly, it's possible to specify a non-NumPy array to ``meta``, and provide
        a ``dtype``:
    
        >>> import cupy  # doctest: +SKIP
        >>> rng = da.random.default_rng(cupy.random.default_rng())  # doctest: +SKIP
        >>> dt = np.float32
        >>> da.map_blocks(lambda x: x[2], rng.random(5, dtype=dt), meta=cupy.array((), dtype=dt))  # doctest: +SKIP
        dask.array<lambda, shape=(5,), dtype=float32, chunksize=(5,), chunktype=cupy.ndarray>
        """
        if drop_axis is None:
            drop_axis = []
    
        if not callable(func):
            msg = (
                "First argument must be callable function, not %s\n"
                "Usage:   da.map_blocks(function, x)\n"
                "   or:   da.map_blocks(function, x, y, z)"
            )
            raise TypeError(msg % type(func).__name__)
        if token:
            warnings.warn(
                "The `token=` keyword to `map_blocks` has been moved to `name=`. "
                "Please use `name=` instead as the `token=` keyword will be removed "
                "in a future release.",
                category=FutureWarning,
            )
            name = token
    
        name = f"{name or funcname(func)}-{tokenize(func, dtype, chunks, drop_axis, new_axis, *args, **kwargs)}"
        new_axes = {}
    
        if isinstance(drop_axis, Number):
            drop_axis = [drop_axis]
        if isinstance(new_axis, Number):
            …tack
        """
        from dask.array import wrap
    
        seq = [asarray(a, allow_unknown_chunksizes=allow_unknown_chunksizes) for a in seq]
    
        if not seq:
            raise ValueError("Need array(s) to concatenate")
    
        if axis is None:
            seq = [a.flatten() for a in seq]
            axis = 0
    
        seq_metas = [meta_from_array(s) for s in seq]
        _concatenate = concatenate_lookup.dispatch(
            type(max(seq_metas, key=lambda x: getattr(x, "__array_priority__", 0)))
        )
        meta = _concatenate(seq_metas, axis=axis)
    
        # Promote types to match meta
        seq = [a.astype(meta.dtype) for a in seq]
    
        # Find output array shape
        ndim = len(seq[0].shape)
        shape = tuple(
            sum(a.shape[i] for a in seq) if i == axis else seq[0].shape[i]
            for i in range(ndim)
        )
    
        # Drop empty arrays
        seq2 = [a for a in seq if a.size]
        if not seq2:
            seq2 = seq
    
        if axis < 0:
            axis = ndim + axis
        if axis >= ndim:
            msg = (
                "Axis must be less than than number of dimensions"
                "\nData has %d dimensions, but got axis=%d"
            )
            raise ValueError(msg % (ndim, axis))
    
        n = len(seq2)
        if n == 0:
            try:
                return wrap.empty_like(meta, shape=shape, chunks=shape, dtype=meta.dtype)
            except TypeError:
                return wrap.empty(shape, chunks=shape, dtype=meta.dtype)
        elif n == 1:
            return seq2[0]
    
        if not allow_unknown_chunksizes and not all(
            i == axis or all(x.shape[i] == seq2[0].shape[i] for x in seq2)
            for i in range(ndim)
        ):
            if any(map(np.isnan, seq2[0].shape)):
                raise ValueError(
                    "Tried to concatenate arrays with unknown"
                    " shape %s.\n\nTwo solutions:\n"
                    "  1. Force concatenation pass"
                    " allow_unknown_chunksizes=True.\n"
                    "  2. Compute shapes with "
                    "[x.compute_chunk_sizes() for x in seq]" % str(seq2[0].shape)
                )
            raise ValueError("Shapes do not align: %s", [x.shape for x in seq2])
    
        inds = [list(range(ndim)) for i in range(n)]
        for i, ind in enumerate(inds):
            ind[axis] = -(i + 1)
    
        uc_args = list(concat(zip(seq2, inds)))
        _, seq2 = unify_chunks(*uc_args, warn=False)
    
        bds = [a.chunks for a in seq2]
    
        chunks = (
            seq2[0].chunks[:axis]
            + (sum((bd[axis] for bd in bds), ()),)
            + seq2[0].chunks[axis + 1 :]
        )
    
        cum_dims = [0] + list(accumulate(add, [len(a.chunks[axis]) for a in seq2]))
    
        names = [a.name for a in seq2]
    
        name = "concatenate-" + tokenize(names, axis)
        keys = list(product([name], *[range(len(bd)) for bd in chunks]))
    
        values = [
            (names[bisect(cum_dims, key[axis + 1]) - 1],)
            + key[1 : axis + 1]
            + (key[axis + 1] - cum_dims[bisect(cum_dims, key[axis + 1]) - 1],)
            + key[axis + 2 :]
            for key in keys
        ]
    
        dsk = dict(zip(keys, values))
        graph = HighLevelGraph.from_collections(name, dsk, dependencies=seq2)
    
        return Array(graph, name, chunks, meta=meta)
    
    
    def load_store_chunk(
        x: Any,
        out: Any,
        index: slice,
        lock: Any,
        return_stored: bool,
        load_stored: bool,
    ):
        """
        A function inserted in a Dask graph for storing a chunk.

        Writes ``x`` into ``out[index]`` (unless ``x`` is ``None`` or empty)
        and optionally returns ``out`` or the freshly stored region, all while
        holding ``lock`` if one is given.

        Parameters
        ----------
        x: array-like or None
            An array (potentially a NumPy one). ``None`` or an empty array
            skips the write, which is how ``load_chunk`` reuses this function
            for reading only.
        out: array-like
            Where to store results.
        index: slice-like
            Where to store result from ``x`` in ``out`` (e.g. a tuple of
            slices, as in the example below).
        lock: Lock-like or False
            Lock held while writing to and reading from ``out``; a falsy
            value disables locking.
        return_stored: bool
            Whether to return ``out``.
        load_stored: bool
            Whether to return the array stored in ``out``.
            Ignored if ``return_stored`` is not ``True``.

        Returns
        -------
        If return_stored=True and load_stored=False
            out
        If return_stored=True and load_stored=True
            out[index]
        If return_stored=False
            None

        Examples
        --------

        >>> a = np.ones((5, 6))
        >>> b = np.empty(a.shape)
        >>> load_store_chunk(a, b, (slice(None), slice(None)), False, False, False)
        """
        if lock:
            lock.acquire()
        try:
            if x is not None and x.size != 0:
                if is_arraylike(x):
                    out[index] = x
                else:
                    # Objects not recognised as array-like are converted first.
                    out[index] = np.asanyarray(x)

            if return_stored and load_stored:
                return out[index]
            elif return_stored and not load_stored:
                return out
            else:
                return None
        finally:
            # Released even if the write or the read above raises.
            if lock:
                lock.release()
    
    
    def store_chunk(
        x: ArrayLike, out: ArrayLike, index: slice, lock: Any, return_stored: bool
    ):
        """Write chunk ``x`` into ``out[index]`` without reading it back.

        Thin wrapper around ``load_store_chunk`` with ``load_stored`` fixed to
        ``False``: the result is ``out`` when ``return_stored`` is true and
        ``None`` otherwise.
        """
        load_stored = False
        return load_store_chunk(x, out, index, lock, return_stored, load_stored)
    
    
    # Type variable bound to ArrayLike so that ``load_chunk`` can declare that it
    # returns the same array type it receives as ``out``.
    A = TypeVar("A", bound=ArrayLike)
    
    
    def load_chunk(out: A, index: slice, lock: Any) -> A:
        """Read ``out[index]`` (holding ``lock`` if given) without writing anything.

        Delegates to ``load_store_chunk`` with no input chunk and with both
        ``return_stored`` and ``load_stored`` enabled.
        """
        return load_store_chunk(
            None, out, index, lock, return_stored=True, load_stored=True
        )
    
    
    def insert_to_ooc(
        keys: list,
        chunks: tuple[tuple[int, ...], ...],
        out: ArrayLike,
        name: str,
        *,
        lock: Lock | bool = True,
        region: tuple[slice, ...] | slice | None = None,
        return_stored: bool = False,
        load_stored: bool = False,
    ) -> dict:
        """
        Creates a Dask graph for storing the chunks referenced by ``keys`` in ``out``.

        Parameters
        ----------
        keys: list
            Dask keys of the input array (may be nested; they are flattened)
        chunks: tuple
            Dask chunks of the input array
        out: array-like
            Where to store results to
        name: str
            First element of the dask keys of the generated store tasks
        lock: Lock-like or bool, optional
            Whether to lock or with what (default is ``True``,
            which means a :class:`threading.Lock` instance).
        region: slice-like, optional
            Where in ``out`` to store the input array's chunks
            (default is ``None``, meaning all of ``out``).
        return_stored: bool, optional
            Whether to return ``out``
            (default is ``False``, meaning ``None`` is returned).
        load_stored: bool, optional
            Whether to handle loading from ``out`` at the same time.
            Ignored if ``return_stored`` is not ``True``.
            (default is ``False``, meaning defer to ``return_stored``).

        Returns
        -------
        dask graph of store operation

        Examples
        --------
        >>> import dask.array as da
        >>> d = da.ones((5, 6), chunks=(2, 3))
        >>> a = np.empty(d.shape)
        >>> insert_to_ooc(d.__dask_keys__(), d.chunks, a, "store-123")  # doctest: +SKIP
        """

        if lock is True:
            # One lock instance shared by every store task in this graph.
            lock = Lock()

        slices = slices_from_chunks(chunks)
        if region:
            # Offset each chunk's slices into ``region`` of ``out``.
            slices = [fuse_slice(region, slc) for slc in slices]

        if return_stored and load_stored:
            func = load_store_chunk
            args = (load_stored,)
        else:
            func = store_chunk  # type: ignore
            args = ()  # type: ignore

        # Task layout: (func, input_key, out, index, lock, return_stored[, load_stored]).
        # retrieve_from_ooc slices these positions, so keep the layout stable.
        dsk = {
            (name,) + t[1:]: (func, t, out, slc, lock, return_stored) + args
            for t, slc in zip(core.flatten(keys), slices)
        }
        return dsk
    
    
    def retrieve_from_ooc(
        keys: Collection[Key], dsk_pre: Graph, dsk_post: Graph
    ) -> dict[tuple, Any]:
        """
        Creates a Dask graph that loads the stored ``keys`` back out.

        For every key ``k`` a task keyed ``("load-" + k[0],) + k[1:]`` is
        emitted. It calls ``load_chunk`` on ``dsk_post[k]`` and passes along
        the trailing arguments of the original store task ``dsk_pre[k]``
        (everything from position 3 up to, but excluding, the last entry).

        Parameters
        ----------
        keys: Collection
            A sequence containing Dask graph keys to load
        dsk_pre: Mapping
            A Dask graph corresponding to a Dask Array before computation
        dsk_post: Mapping
            A Dask graph corresponding to a Dask Array after computation

        Examples
        --------
        >>> import dask.array as da
        >>> d = da.ones((5, 6), chunks=(2, 3))
        >>> a = np.empty(d.shape)
        >>> g = insert_to_ooc(d.__dask_keys__(), d.chunks, a, "store-123")
        >>> retrieve_from_ooc(g.keys(), g, {k: k for k in g.keys()})  # doctest: +SKIP
        """
        load_dsk: dict[tuple, Any] = {}
        for key in keys:
            load_key = ("load-" + key[0],) + key[1:]
            source = dsk_post[key]
            # For a store task built by insert_to_ooc without a trailing
            # ``load_stored`` flag, this is its (index, lock) pair.
            index_and_lock = dsk_pre[key][3:-1]  # type: ignore
            load_dsk[load_key] = (load_chunk, source) + index_and_lock
        return load_dsk
    
    
    def _as_dtype(a, dtype):
        if dtype is None:
            return a
        else:
            return a.astype(dtype)
    
    
    def asarray(
        a, allow_unknown_chunksizes=False, dtype=None, order=None, *, like=None, **kwargs
    ):
        """Convert the input to a dask array.

        Parameters
        ----------
        a : array-like
            Input data, in any form that can be converted to a dask array. This
            includes lists, lists of tuples, tuples, tuples of tuples, tuples of
            lists and ndarrays.
        allow_unknown_chunksizes: bool
            Allow unknown chunksizes, such as come from converting from dask
            dataframes.  Dask.array is unable to verify that chunks line up.  If
            data comes from differently aligned sources then this can cause
            unexpected results.
        dtype : data-type, optional
            By default, the data-type is inferred from the input data. If ``a``
            already exposes a ``shape`` (e.g. a NumPy array), it is wrapped as-is
            and the resulting dask array is cast lazily with ``astype``.
        order : {‘C’, ‘F’, ‘A’, ‘K’}, optional
            Memory layout. ‘A’ and ‘K’ depend on the order of input array a.
            ‘C’ row-major (C-style), ‘F’ column-major (Fortran-style) memory
            representation. ‘A’ (any) means ‘F’ if a is Fortran contiguous, ‘C’
            otherwise ‘K’ (keep) preserve input order. Defaults to ‘C’.
        like: array-like
            Reference object to allow the creation of Dask arrays with chunks
            that are not NumPy arrays. If an array-like passed in as ``like``
            supports the ``__array_function__`` protocol, the chunk type of the
            resulting array will be defined by it. In this case, it ensures the
            creation of a Dask array compatible with that passed in via this
            argument. If ``like`` is a Dask array, the chunk type of the
            resulting array will be defined by the chunk type of ``like``.
            Requires NumPy 1.20.0 or higher.

        Returns
        -------
        out : dask array
            Dask array interpretation of a.

        Examples
        --------
        >>> import dask.array as da
        >>> import numpy as np
        >>> x = np.arange(3)
        >>> da.asarray(x)
        dask.array<array, shape=(3,), dtype=int64, chunksize=(3,), chunktype=numpy.ndarray>

        >>> y = [[1, 2, 3], [4, 5, 6]]
        >>> da.asarray(y)
        dask.array<array, shape=(2, 3), dtype=int64, chunksize=(2, 3), chunktype=numpy.ndarray>

        .. warning::
            `order` is ignored if `a` is an `Array`, has the attribute ``to_dask_array``,
            or is a list or tuple of `Array`'s.
        """
        # Set when ``a`` is wrapped without an in-memory conversion, so the
        # requested ``dtype`` still has to be applied to the dask result.
        cast_result = False
        if like is None:
            if isinstance(a, Array):
                return _as_dtype(a, dtype)
            elif hasattr(a, "to_dask_array"):
                return _as_dtype(a.to_dask_array(), dtype)
            elif type(a).__module__.split(".")[0] == "xarray" and hasattr(a, "data"):
                return _as_dtype(asarray(a.data, order=order), dtype)
            elif isinstance(a, (list, tuple)) and any(isinstance(i, Array) for i in a):
                return _as_dtype(
                    stack(a, allow_unknown_chunksizes=allow_unknown_chunksizes), dtype
                )
            elif not isinstance(getattr(a, "shape", None), Iterable):
                a = np.asarray(a, dtype=dtype, order=order)
            else:
                # ``a`` already has an iterable shape and is passed to from_array
                # unchanged; previously ``dtype`` was silently dropped here.
                cast_result = True
        else:
            like_meta = meta_from_array(like)
            if isinstance(a, Array):
                return a.map_blocks(np.asarray, like=like_meta, dtype=dtype, order=order)
            else:
                a = np.asarray(a, like=like_meta, dtype=dtype, order=order)
        result = from_array(a, getitem=getter_inline, **kwargs)
        return _as_dtype(result, dtype) if cast_result else result
    
    
    def asanyarray(a, dtype=None, order=None, *, like=None, inline_array=False):
        """Convert the input to a dask array.

        Subclasses of ``np.ndarray`` will be passed through as chunks unchanged.

        Parameters
        ----------
        a : array-like
            Input data, in any form that can be converted to a dask array. This
            includes lists, lists of tuples, tuples, tuples of tuples, tuples of
            lists and ndarrays.
        dtype : data-type, optional
            By default, the data-type is inferred from the input data. If ``a``
            already exposes a ``shape`` (e.g. a NumPy array), it is wrapped as-is
            and the resulting dask array is cast lazily with ``astype``.
        order : {‘C’, ‘F’, ‘A’, ‘K’}, optional
            Memory layout. ‘A’ and ‘K’ depend on the order of input array a.
            ‘C’ row-major (C-style), ‘F’ column-major (Fortran-style) memory
            representation. ‘A’ (any) means ‘F’ if a is Fortran contiguous, ‘C’
            otherwise ‘K’ (keep) preserve input order. Defaults to ‘C’.
        like: array-like
            Reference object to allow the creation of Dask arrays with chunks
            that are not NumPy arrays. If an array-like passed in as ``like``
            supports the ``__array_function__`` protocol, the chunk type of the
            resulting array will be defined by it. In this case, it ensures the
            creation of a Dask array compatible with that passed in via this
            argument. If ``like`` is a Dask array, the chunk type of the
            resulting array will be defined by the chunk type of ``like``.
            Requires NumPy 1.20.0 or higher.
        inline_array:
            Whether to inline the array in the resulting dask graph. For more information,
            see the documentation for ``dask.array.from_array()``.

        Returns
        -------
        out : dask array
            Dask array interpretation of a.

        Examples
        --------
        >>> import dask.array as da
        >>> import numpy as np
        >>> x = np.arange(3)
        >>> da.asanyarray(x)
        dask.array<array, shape=(3,), dtype=int64, chunksize=(3,), chunktype=numpy.ndarray>

        >>> y = [[1, 2, 3], [4, 5, 6]]
        >>> da.asanyarray(y)
        dask.array<array, shape=(2, 3), dtype=int64, chunksize=(2, 3), chunktype=numpy.ndarray>

        .. warning::
            `order` is ignored if `a` is an `Array`, has the attribute ``to_dask_array``,
            or is a list or tuple of `Array`'s.
        """
        # Set when ``a`` is wrapped without an in-memory conversion, so the
        # requested ``dtype`` still has to be applied to the dask result.
        cast_result = False
        if like is None:
            if isinstance(a, Array):
                return _as_dtype(a, dtype)
            elif hasattr(a, "to_dask_array"):
                return _as_dtype(a.to_dask_array(), dtype)
            elif type(a).__module__.split(".")[0] == "xarray" and hasattr(a, "data"):
                return _as_dtype(asarray(a.data, order=order), dtype)
            elif isinstance(a, (list, tuple)) and any(isinstance(i, Array) for i in a):
                return _as_dtype(stack(a), dtype)
            elif not isinstance(getattr(a, "shape", None), Iterable):
                a = np.asanyarray(a, dtype=dtype, order=order)
            else:
                # ``a`` already has an iterable shape and is passed to from_array
                # unchanged; previously ``dtype`` was silently dropped here.
                cast_result = True
        else:
            like_meta = meta_from_array(like)
            if isinstance(a, Array):
                return a.map_blocks(np.asanyarray, like=like_meta, dtype=dtype, order=order)
            else:
                a = np.asanyarray(a, like=like_meta, dtype=dtype, order=order)
        result = from_array(
            a,
            chunks=a.shape,
            getitem=getter_inline,
            asarray=False,
            inline_array=inline_array,
        )
        return _as_dtype(result, dtype) if cast_result else result
    
    
    def is_scalar_for_elemwise(arg):
        """
        Decide whether ``elemwise`` should treat ``arg`` as a scalar operand.

        >>> is_scalar_for_elemwise(42)
        True
        >>> is_scalar_for_elemwise('foo')
        True
        >>> is_scalar_for_elemwise(True)
        True
        >>> is_scalar_for_elemwise(np.array(42))
        True
        >>> is_scalar_for_elemwise([1, 2, 3])
        True
        >>> is_scalar_for_elemwise(np.array([1, 2, 3]))
        False
        >>> is_scalar_for_elemwise(from_array(np.array(0), chunks=()))
        False
        >>> is_scalar_for_elemwise(np.dtype('i4'))
        True
        """
        arg_shape = getattr(arg, "shape", None)
        if isinstance(arg_shape, Iterable):
            # A shape holding dask collections (as with dask Series/DataFrame)
            # still marks the object as a scalar for elemwise purposes.
            scalar_by_shape = any(is_dask_collection(dim) for dim in arg_shape)
        else:
            # No usable shape at all: treat as a scalar.
            scalar_by_shape = True

        if np.isscalar(arg) or scalar_by_shape or isinstance(arg, np.dtype):
            return True
        return isinstance(arg, np.ndarray) and arg.ndim == 0
    
    
    def broadcast_shapes(*shapes):
        """
        Compute the shape that results from broadcasting the given shapes.

        Shapes are aligned on their trailing dimensions. A NaN (unknown) size in
        any position makes that output dimension NaN. A zero size makes it 0.
        Otherwise the output dimension is the largest size.

        Parameters
        ----------
        shapes : tuples
            The shapes of the arguments.

        Returns
        -------
        output_shape : tuple

        Raises
        ------
        ValueError
            If the input shapes cannot be successfully broadcast together.
        """
        if len(shapes) == 1:
            return shapes[0]

        dims_from_right = []
        # -1 pads shorter shapes on the left so every column has one entry per shape.
        for column in zip_longest(*(reversed(s) for s in shapes), fillvalue=-1):
            if np.isnan(column).any():
                target = np.nan
            elif 0 in column:
                target = 0
            else:
                target = np.max(column).item()

            allowed = [-1, 0, 1, target]
            for size in column:
                if size not in allowed and not np.isnan(size):
                    raise ValueError(
                        "operands could not be broadcast together with "
                        "shapes {}".format(" ".join(map(str, shapes)))
                    )
            dims_from_right.append(target)

        return tuple(dims_from_right[::-1])
    
    
    def elemwise(op, *args, out=None, where=True, dtype=None, name=None, **kwargs):
        """Apply an elementwise ufunc-like function blockwise across arguments.

        Like numpy ufuncs, broadcasting rules are respected.

        Parameters
        ----------
        op : callable
            The function to apply. Should be numpy ufunc-like in the parameters
            that it accepts.
        *args : Any
            Arguments to pass to `op`. Non-dask array-like objects are first
            converted to dask arrays, then all arrays are broadcast together before
            applying the function blockwise across all arguments. Any scalar
            arguments are passed as-is following normal numpy ufunc behavior.
        out : dask array, optional
            If out is a dask.array then this overwrites the contents of that array
            with the result.
        where : array_like, optional
            An optional boolean mask marking locations where the ufunc should be
            applied. Can be a scalar, dask array, or any other array-like object.
            Mirrors the ``where`` argument to numpy ufuncs, see e.g. ``numpy.add``
            for more information.
        dtype : dtype, optional
            If provided, overrides the output array dtype.
        name : str, optional
            A unique key name to use when building the backing dask graph. If not
            provided, one will be automatically generated based on the input
            arguments.

        Returns
        -------
        dask array
            The result (or ``out``, rebound to the result, when ``out`` is
            given). ``NotImplemented`` is returned instead when ``dtype`` is not
            given and inferring it from sample values via ``apply_infer_dtype``
            raises.

        Raises
        ------
        TypeError
            If any extra keyword arguments are passed.
        ValueError
            If the argument shapes cannot be broadcast together.

        Examples
        --------
        >>> elemwise(add, x, y)  # doctest: +SKIP
        >>> elemwise(sin, x)  # doctest: +SKIP
        >>> elemwise(sin, x, out=dask_array)  # doctest: +SKIP

        See Also
        --------
        blockwise
        """
        if kwargs:
            raise TypeError(
                f"{op.__name__} does not take the following keyword arguments "
                f"{sorted(kwargs)}"
            )

        # ``out`` becomes None or a dask Array; ``where`` becomes True, False or a
        # dask Array.
        out = _elemwise_normalize_out(out)
        where = _elemwise_normalize_where(where)
        # Lists/tuples are converted eagerly so they act as array operands.
        args = [np.asarray(a) if isinstance(a, (list, tuple)) else a for a in args]

        shapes = []
        for arg in args:
            shape = getattr(arg, "shape", ())
            if any(is_dask_collection(x) for x in shape):
                # Want to exclude Delayed shapes and dd.Scalar
                shape = ()
            shapes.append(shape)
        if isinstance(where, Array):
            shapes.append(where.shape)
        if isinstance(out, Array):
            shapes.append(out.shape)

        # Anything whose ``shape`` is not iterable counts as 0-d.
        shapes = [s if isinstance(s, Iterable) else () for s in shapes]
        out_ndim = len(
            broadcast_shapes(*shapes)
        )  # Raises ValueError if dimensions mismatch
        # Indices are reversed (here and per-argument below) so operands align
        # on their trailing dimensions, as in NumPy broadcasting.
        expr_inds = tuple(range(out_ndim))[::-1]

        if dtype is not None:
            need_enforce_dtype = True
        else:
            # We follow NumPy's rules for dtype promotion, which special cases
            # scalars and 0d ndarrays (which it considers equivalent) by using
            # their values to compute the result dtype:
            # https://github.com/numpy/numpy/issues/6240
            # We don't inspect the values of 0d dask arrays, because these could
            # hold potentially very expensive calculations. Instead, we treat
            # them just like other arrays, and if necessary cast the result of op
            # to match.
            vals = [
                (
                    np.empty((1,) * max(1, a.ndim), dtype=a.dtype)
                    if not is_scalar_for_elemwise(a)
                    else a
                )
                for a in args
            ]
            try:
                dtype = apply_infer_dtype(op, vals, {}, "elemwise", suggest_dtype=False)
            except Exception:
                # NOTE(review): presumably lets operator dunders fall back to the
                # other operand's implementation — confirm against callers.
                return NotImplemented
            need_enforce_dtype = any(
                not is_scalar_for_elemwise(a) and a.ndim == 0 for a in args
            )

        if not name:
            name = f"{funcname(op)}-{tokenize(op, dtype, *args, where)}"

        blockwise_kwargs = dict(dtype=dtype, name=name, token=funcname(op).strip("_"))

        if where is not True:
            # _elemwise_handle_where peels ``where`` and ``out`` off the end of
            # the positional arguments and forwards them to the original op.
            blockwise_kwargs["elemwise_where_function"] = op
            op = _elemwise_handle_where
            args.extend([where, out])

        if need_enforce_dtype:
            # Wrap the (possibly already wrapped) op so each block is cast to dtype.
            blockwise_kwargs["enforce_dtype"] = dtype
            blockwise_kwargs["enforce_dtype_function"] = op
            op = _enforce_dtype

        # Scalars get index None so blockwise passes them through unchanged.
        result = blockwise(
            op,
            expr_inds,
            *concat(
                (a, tuple(range(a.ndim)[::-1]) if not is_scalar_for_elemwise(a) else None)
                for a in args
            ),
            **blockwise_kwargs,
        )

        return handle_out(out, result)
    
    
    def _elemwise_normalize_where(where):
        if where is True:
            return True
        elif where is False or where is None:
            return False
        return asarray(where)
    
    
    def _elemwise_handle_where(*args, **kwargs):
        function = kwargs.pop("elemwise_where_function")
        *args, where, out = args
        if hasattr(out, "copy"):
            out = out.copy()
        return function(*args, where=where, out=out, **kwargs)
    
    
    def _elemwise_normalize_out(out):
        if isinstance(out, tuple):
            if len(out) == 1:
                out = out[0]
            elif len(out) > 1:
                raise NotImplementedError("The out parameter is not fully supported")
            else:
                out = None
        if not (out is None or isinstance(out, Array)):
            raise NotImplementedError(
                f"The out parameter is not fully supported."
                f" Received type {type(out).__name__}, expected Dask Array"
            )
        return out
    
    
    def handle_out(out, result):
        """Handle out parameters

        If ``out`` (after normalization) is a dask Array, its chunks, graph,
        meta and name are replaced with those of ``result`` and ``out`` is
        returned. Otherwise ``result`` is returned unchanged.

        Raises
        ------
        ValueError
            If ``out`` is a dask Array whose shape differs from ``result``.
        NotImplementedError
            Propagated from ``_elemwise_normalize_out`` for unsupported ``out``.
        """
        target = _elemwise_normalize_out(out)
        if not isinstance(target, Array):
            return result
        if target.shape != result.shape:
            raise ValueError(
                "Mismatched shapes between result and out parameter. "
                "out=%s, result=%s" % (str(target.shape), str(result.shape))
            )
        # Mutate the caller's object so existing references observe the result.
        target._chunks = result.chunks
        target.dask = result.dask
        target._meta = result._meta
        target._name = result.name
        return target
    
    
    def _enforce_dtype(*args, **kwargs):
        """Calls a function and converts its result to the given dtype.
    
        The parameters have deliberately been given unwieldy names to avoid
        clashes with keyword arguments consumed by blockwise
    
        A dtype of `object` is treated as a special case and not enforced,
        because it is used as a dummy value in some places when the result will
        not be a block in an Array.
    
        Parameters
        ----------
        enforce_dtype : dtype
            Result dtype
        enforce_dtype_function : callable
            The wrapped function, which will be passed the remaining arguments
        """
        dtype = kwargs.pop("enforce_dtype")
        function = kwargs.pop("enforce_dtype_function")
    
        result = function(*args, **kwargs)
        if hasattr(result, "dtype") and dtype != result.dtype and dtype != object:
            if not np.can_cast(result, dtype, casting="same_kind"):
                raise ValueError(
                    "Inferred dtype from function %r was %r "
                    "but got %r, which can't be cast using "
                    "casting='same_kind'"
                    % (funcname(function), str(dtype), str(result.dtype))
                )
            if np.isscalar(result):
                # scalar astype method doesn't take the keyword arguments, so
                # have to convert via 0-dimensional array and back.
                result = result.astype(dtype)
            else:
                try:
                    result = result.astype(dtype, copy=False)
                except TypeError:
                    # Missing copy kwarg
                    result = result.astype(dtype)
        return result
    
    
    def broadcast_to(x, shape, chunks=None, meta=None):
        """Broadcast an array to a new shape.

        Parameters
        ----------
        x : array_like
            The array to broadcast.
        shape : tuple
            The shape of the desired array.
        chunks : tuple, optional
            If provided, then the result will use these chunks instead of the same
            chunks as the source array. Setting chunks explicitly as part of
            broadcast_to is more efficient than rechunking afterwards. Chunks are
            only allowed to differ from the original shape along dimensions that
            are new on the result or have size 1 in the input array.
        meta : empty ndarray
            empty ndarray created with same NumPy backend, ndim and dtype as the
            Dask Array being created (overrides dtype)

        Returns
        -------
        broadcast : dask array

        Raises
        ------
        ValueError
            If ``x`` cannot be broadcast to ``shape``, or if ``chunks`` differ
            from the input's chunks along a dimension that is neither new nor of
            size 1.

        See Also
        --------
        :func:`numpy.broadcast_to`
        """
        x = asarray(x)
        shape = tuple(shape)

        if meta is None:
            meta = meta_from_array(x)

        # Nothing to do when both the shape and the chunking already match.
        if x.shape == shape and (chunks is None or chunks == x.chunks):
            return x

        # Number of leading dimensions added by the broadcast.
        ndim_new = len(shape) - x.ndim
        if ndim_new < 0 or any(
            new != old for new, old in zip(shape[ndim_new:], x.shape) if old != 1
        ):
            raise ValueError(f"cannot broadcast shape {x.shape} to shape {shape}")

        if chunks is None:
            # New leading dims get one chunk each; input dims larger than 1 keep
            # their chunks; the rest get one chunk spanning the new size.
            chunks = tuple((s,) for s in shape[:ndim_new]) + tuple(
                bd if old > 1 else (new,)
                for bd, old, new in zip(x.chunks, x.shape, shape[ndim_new:])
            )
        else:
            chunks = normalize_chunks(
                chunks, shape, dtype=x.dtype, previous_chunks=x.chunks
            )
            for old_bd, new_bd in zip(x.chunks, chunks[ndim_new:]):
                if old_bd != new_bd and old_bd != (1,):
                    raise ValueError(
                        "cannot broadcast chunks %s to chunks %s: "
                        "new chunks must either be along a new "
                        "dimension or a dimension of size 1" % (x.chunks, chunks)
                    )

        name = "broadcast_to-" + tokenize(x, shape, chunks)
        dsk = {}

        enumerated_chunks = product(*(enumerate(bds) for bds in chunks))
        for new_index, chunk_shape in (zip(*ec) for ec in enumerated_chunks):
            # Input axes chunked as (1,) always read input block 0 on that axis.
            old_index = tuple(
                0 if bd == (1,) else i for bd, i in zip(x.chunks, new_index[ndim_new:])
            )
            old_key = (x.name,) + old_index
            new_key = (name,) + new_index
            # NOTE(review): quote() presumably stops the shape tuple from being
            # interpreted as a task/key by the scheduler — confirm.
            dsk[new_key] = (np.broadcast_to, old_key, quote(chunk_shape))

        graph = HighLevelGraph.from_collections(name, dsk, dependencies=[x])
        return Array(graph, name, chunks, dtype=x.dtype, meta=meta)
    
    
    @derived_from(np)
    def broadcast_arrays(*args, subok=False):
        # No docstring here on purpose: derived_from builds __doc__ from NumPy's
        # and appends this function's own docstring to it.
        convert = asanyarray if bool(subok) else asarray
        arrays = tuple(convert(item) for item in args)

        # Unify uneven chunking; indices run high-to-low so that trailing
        # dimensions of all inputs share index labels.
        index_lists = [list(range(arr.ndim))[::-1] for arr in arrays]
        _, arrays = unify_chunks(*concat(zip(arrays, index_lists)), warn=False)

        target_shape = broadcast_shapes(*(arr.shape for arr in arrays))
        target_chunks = broadcast_chunks(*(arr.chunks for arr in arrays))

        broadcasted = (
            broadcast_to(arr, shape=target_shape, chunks=target_chunks)
            for arr in arrays
        )
        # NumPy >= 2 returns a tuple; older versions return a list.
        return tuple(broadcasted) if NUMPY_GE_200 else list(broadcasted)
    
    
    def offset_func(func, offset, *args):
        """Offsets inputs by offset

        Returns a wrapper that adds ``offset[i]`` to its ``i``-th positional
        argument before calling ``func``. Arguments and offsets are paired up
        to the shorter of the two, so any surplus on either side is dropped.
        The extra ``*args`` accepted by ``offset_func`` itself are unused.

        >>> double = lambda x: x * 2
        >>> f = offset_func(double, (10,))
        >>> f(1)
        22
        >>> f(300)
        620
        """

        def _offset(*call_args):
            shifted = [add(value, delta) for value, delta in zip(call_args, offset)]
            return func(*shifted)

        # Best effort: naming the wrapper is cosmetic, so any failure is ignored.
        try:
            _offset.__name__ = "offset_" + func.__name__
        except Exception:
            pass

        return _offset
    
    
    def chunks_from_arrays(arrays):
        """Chunks tuple from nested list of arrays
    
        >>> x = np.array([1, 2])
        >>> chunks_from_arrays([x, x])
        ((2, 2),)
    
        >>> x = np.array([[1, 2]])
        >>> chunks_from_arrays([[x], [x]])
        ((1, 1), (2,))
    
        >>> x = np.array([[1, 2]])
        >>> chunks_from_arrays([[x, x]])
        ((1,), (2, 2))
    
        >>> chunks_from_arrays([1, 1])
        ((1, 1),)
        """
        if not arrays:
            return ()
        result = []
        dim = 0
    
        def shape(x):
            try:
                return x.shape if x.shape else (1,)
            except AttributeError:
                return (1,)
    
        while isinstance(arrays, (list, tuple)):
>           result.append(tuple(shape(deepfirst(a))[dim] for a in arrays))
E           IndexError: tuple index out of range

../../../miniconda3/envs/dask-distributed/lib/python3.10/site-packages/dask/array/core.py:5281: IndexError

Check warning on line 0 in distributed.tests.test_client

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

1 out of 12 runs failed: test_call_stack_future (distributed.tests.test_client)

artifacts/windows-latest-3.11-default-ci1/pytest.xml [took 1s]
Raw output
KeyError: 'slowinc-1ea6f50d354d90f0a53d5f5250e2c25f'
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:61467', workers: 0, cores: 0, tasks: 0>
a = <Worker 'tcp://127.0.0.1:61468', name: 0, status: closed, stored: 0, running: 1/4, ready: 0, comm: 0, waiting: 0>
b = <Worker 'tcp://127.0.0.1:61471', name: 1, status: closed, stored: 0, running: 1/4, ready: 0, comm: 0, waiting: 0>

    @gen_cluster([("127.0.0.1", 4)] * 2, client=True)
    async def test_call_stack_future(c, s, a, b):
        x = c.submit(slowdec, 1, delay=0.5)
        future = c.submit(slowinc, 1, delay=0.5)
        await asyncio.sleep(0.1)
>       results = await asyncio.gather(
            c.call_stack(future), c.call_stack(keys=[future.key])
        )

distributed\tests\test_client.py:5467: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed\core.py:1259: in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
distributed\core.py:1043: in send_recv
    raise exc.with_traceback(tb)
distributed\core.py:834: in _handle_comm
    result = await result
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import dataclasses
    import heapq
    import inspect
    import itertools
    import json
    import logging
    import math
    import operator
    import os
    import pickle
    import random
    import textwrap
    import uuid
    import warnings
    import weakref
    from abc import abstractmethod
    from collections import defaultdict, deque
    from collections.abc import (
        Callable,
        Collection,
        Container,
        Hashable,
        Iterable,
        Iterator,
        Mapping,
        Sequence,
        Set,
    )
    from contextlib import suppress
    from functools import partial
    from typing import TYPE_CHECKING, Any, ClassVar, Literal, NamedTuple, cast, overload
    
    import psutil
    import tornado.web
    from sortedcontainers import SortedDict, SortedSet
    from tlz import (
        concat,
        first,
        groupby,
        merge,
        merge_sorted,
        merge_with,
        partition,
        pluck,
        second,
        take,
        valmap,
    )
    from tornado.ioloop import IOLoop
    
    import dask
    import dask.utils
    from dask._task_spec import DependenciesMapping, GraphNode, convert_legacy_graph
    from dask.base import TokenizationError, normalize_token, tokenize
    from dask.core import istask, validate_key
    from dask.typing import Key, no_default
    from dask.utils import (
        _deprecated,
        _deprecated_kwarg,
        ensure_dict,
        format_bytes,
        format_time,
        key_split,
        parse_bytes,
        parse_timedelta,
        tmpfile,
    )
    from dask.widgets import get_template
    
    from distributed import cluster_dump, preloading, profile
    from distributed import versions as version_module
    from distributed._asyncio import RLock
    from distributed._stories import scheduler_story
    from distributed.active_memory_manager import ActiveMemoryManagerExtension, RetireWorker
    from distributed.batched import BatchedSend
    from distributed.broker import Broker
    from distributed.client import SourceCode
    from distributed.collections import HeapSet
    from distributed.comm import (
        Comm,
        CommClosedError,
        get_address_host,
        normalize_address,
        resolve_address,
        unparse_host_port,
    )
    from distributed.comm.addressing import addresses_from_user_args
    from distributed.compatibility import PeriodicCallback
    from distributed.core import (
        ErrorMessage,
        OKMessage,
        Status,
        clean_exception,
        error_message,
        rpc,
        send_recv,
    )
    from distributed.diagnostics.memory_sampler import MemorySamplerExtension
    from distributed.diagnostics.plugin import SchedulerPlugin, _get_plugin_name
    from distributed.event import EventExtension
    from distributed.gc import disable_gc_diagnosis, enable_gc_diagnosis
    from distributed.http import get_handlers
    from distributed.metrics import monotonic, time
    from distributed.multi_lock import MultiLockExtension
    from distributed.node import ServerNode
    from distributed.proctitle import setproctitle
    from distributed.protocol import deserialize
    from distributed.protocol.pickle import dumps, loads
    from distributed.protocol.serialize import Serialized, ToPickle, serialize
    from distributed.publish import PublishExtension
    from distributed.pubsub import PubSubSchedulerExtension
    from distributed.queues import QueueExtension
    from distributed.recreate_tasks import ReplayTaskScheduler
    from distributed.security import Security
    from distributed.semaphore import SemaphoreExtension
    from distributed.shuffle import ShuffleSchedulerPlugin
    from distributed.spans import SpanMetadata, SpansSchedulerExtension
    from distributed.stealing import WorkStealing
    from distributed.utils import (
        All,
        Deadline,
        TimeoutError,
        format_dashboard_link,
        get_fileno_limit,
        key_split_group,
        log_errors,
        offload,
        recursive_to_dict,
        wait_for,
    )
    from distributed.utils_comm import (
        gather_from_workers,
        retry_operation,
        scatter_to_workers,
    )
    from distributed.variable import VariableExtension
    
    if TYPE_CHECKING:
        # TODO import TypeAlias from typing (requires Python >=3.10)
        # TODO import Self from typing (requires Python >=3.11)
        from typing_extensions import Self, TypeAlias
    
        from dask.highlevelgraph import HighLevelGraph
    
    # Not to be confused with distributed.worker_state_machine.TaskStateState
    # Literal type listing every task state name used by the scheduler.
    TaskStateState: TypeAlias = Literal[
        "released",
        "waiting",
        "no-worker",
        "queued",
        "processing",
        "memory",
        "erred",
        "forgotten",
    ]

    # The same state names as a runtime set, read from the Literal's type arguments
    # so the two can never drift apart.
    ALL_TASK_STATES: Set[TaskStateState] = set(TaskStateState.__args__)  # type: ignore

    # {task key -> finish state}
    # Not to be confused with distributed.worker_state_machine.Recs
    Recs: TypeAlias = dict[Key, TaskStateState]
    # {client or worker address: [{op: <key>, ...}, ...]}
    Msgs: TypeAlias = dict[str, list[dict[str, Any]]]
    # (recommendations, client messages, worker messages)
    RecsMsgs: TypeAlias = tuple[Recs, Msgs, Msgs]

    # Alias for a task's run specification, expressed as a dask._task_spec GraphNode
    # (presumably the type of TaskState.run_spec; verify against TaskState).
    T_runspec: TypeAlias = GraphNode
    
    logger = logging.getLogger(__name__)
    # Read once, at import time, from the dask config; later config changes are not
    # picked up by these module-level constants.
    LOG_PDB = dask.config.get("distributed.admin.pdb-on-err")
    # Configured default data size, converted to a byte count by parse_bytes.
    DEFAULT_DATA_SIZE = parse_bytes(
        dask.config.get("distributed.scheduler.default-data-size")
    )
    # Placeholder stimulus id; presumably used where no real stimulus_id was
    # supplied — verify at the call sites.
    STIMULUS_ID_UNSET = "<stimulus_id unset>"

    # Mapping of extension name -> extension class. Presumably the scheduler
    # instantiates these when no explicit extensions are given — confirm in Scheduler.
    DEFAULT_EXTENSIONS = {
        "multi_locks": MultiLockExtension,
        "publish": PublishExtension,
        "replay-tasks": ReplayTaskScheduler,
        "queues": QueueExtension,
        "variables": VariableExtension,
        "pubsub": PubSubSchedulerExtension,
        "semaphores": SemaphoreExtension,
        "events": EventExtension,
        "amm": ActiveMemoryManagerExtension,
        "memory_sampler": MemorySamplerExtension,
        "shuffle": ShuffleSchedulerPlugin,
        "spans": SpansSchedulerExtension,
        "stealing": WorkStealing,
    }
    
    
    class ClientState:
        """Scheduler-side record of a single connected client.

        Equality and hashing are both based solely on :attr:`client_key`.
        """

        #: Opaque identifier of the client, normally generated by the client itself.
        client_key: str

        #: ``hash(client_key)``, computed once in ``__init__``.
        _hash: int

        #: Tasks this client wants kept in memory so that it can fetch their results.
        #: This mirrors :attr:`TaskState.who_wants`. Entries usually disappear when the
        #: matching client-side object (e.g. a ``Future`` or a Dask collection) is
        #: garbage-collected.
        wants_what: set[TaskState]

        #: Local scheduler time at which this client's most recent heartbeat arrived.
        last_seen: float

        #: The client's :func:`distributed.versions.get_versions` report.
        versions: dict[str, Any]

        __slots__ = tuple(__annotations__)

        def __init__(self, client: str, *, versions: dict[str, Any] | None = None):
            self.client_key = client
            self._hash = hash(client)
            self.versions = versions if versions else {}
            self.wants_what = set()
            self.last_seen = time()

        def __hash__(self) -> int:
            return self._hash

        def __eq__(self, other: object) -> bool:
            # Anything that is not a ClientState compares unequal.
            return isinstance(other, ClientState) and self.client_key == other.client_key

        def __repr__(self) -> str:
            return f"<Client {self.client_key!r}>"

        def __str__(self) -> str:
            return self.client_key

        def _to_dict_no_nest(self, *, exclude: Container[str] = ()) -> dict:
            """Debugging-only dict view of this client (not type stable, no roundtrip).

            ``versions`` is always omitted, in addition to anything in ``exclude``.

            See also
            --------
            Client.dump_cluster_state
            distributed.utils.recursive_to_dict
            TaskState._to_dict
            """
            skipped = {"versions", *exclude}  # type: ignore
            return recursive_to_dict(self, exclude=skipped, members=True)
    
    
    class MemoryState:
        """Memory readings on a worker or on the whole cluster.

        See :doc:`worker-memory`.

        Attributes / properties:

        managed_total
            Sum of the output of sizeof() for all dask keys held by the worker in memory,
            plus number of bytes spilled to disk

        managed
            Sum of the output of sizeof() for the dask keys held in RAM. Note that this may
            be inaccurate, which may cause inaccurate unmanaged memory (see below).

        spilled
            Number of bytes for the dask keys spilled to the hard drive.
            This is the size on disk; the size in memory may differ because of
            compression and inaccuracies in sizeof(). In other words, for the same keys,
            'managed' changes depending on whether the keys are in memory or spilled.

        process
            Total RSS memory measured by the OS on the worker process.
            This is always exactly equal to managed + unmanaged.

        unmanaged
            process - managed. This is the sum of

            - Python interpreter and modules
            - global variables
            - memory temporarily allocated by the dask tasks that are currently running
            - memory fragmentation
            - memory leaks
            - memory not yet garbage collected
            - memory not yet free()'d by the Python memory manager to the OS

        unmanaged_old
            Minimum of the 'unmanaged' measures over the last
            ``distributed.memory.recent-to-old-time`` seconds

        unmanaged_recent
            unmanaged - unmanaged_old; in other words process memory that has been recently
            allocated but is not accounted for by dask; hopefully it's mostly a temporary
            spike.

        optimistic
            managed + unmanaged_old; in other words the memory held long-term by
            the process under the hopeful assumption that all unmanaged_recent memory is a
            temporary spike
        """

        process: int
        unmanaged_old: int
        managed: int
        spilled: int

        __slots__ = tuple(__annotations__)

        def __init__(
            self,
            *,
            process: int,
            unmanaged_old: int,
            managed: int,
            spilled: int,
        ):
            # Readings come from different sources (heartbeat vs. realtime task
            # updates) and sizeof() can be wrong, so a part may exceed the whole.
            # Clamp the parts so that everything adds up exactly by construction.
            clamped_managed = min(process, managed)
            self.process = process
            self.managed = clamped_managed
            self.spilled = spilled
            # Both differences below are >= 0 thanks to the clamping above.
            self.unmanaged_old = min(unmanaged_old, process - clamped_managed)

        @staticmethod
        def sum(*infos: MemoryState) -> MemoryState:
            """Aggregate several readings into one (all zeros when called with none)."""
            total_process = total_unmanaged_old = total_managed = total_spilled = 0
            for info in infos:
                total_process += info.process
                total_unmanaged_old += info.unmanaged_old
                total_managed += info.managed
                total_spilled += info.spilled
            return MemoryState(
                process=total_process,
                unmanaged_old=total_unmanaged_old,
                managed=total_managed,
                spilled=total_spilled,
            )

        @property
        def managed_total(self) -> int:
            return self.managed + self.spilled

        @property
        def unmanaged(self) -> int:
            # Non-negative because __init__ clamps managed to process
            return self.process - self.managed

        @property
        def unmanaged_recent(self) -> int:
            # Non-negative because __init__ clamps unmanaged_old to process - managed
            return self.process - self.managed - self.unmanaged_old

        @property
        def optimistic(self) -> int:
            return self.managed + self.unmanaged_old

        @property
        def managed_in_memory(self) -> int:
            """Deprecated alias of :attr:`managed`."""
            warnings.warn("managed_in_memory has been renamed to managed", FutureWarning)
            return self.managed

        @property
        def managed_spilled(self) -> int:
            """Deprecated alias of :attr:`spilled`."""
            warnings.warn("managed_spilled has been renamed to spilled", FutureWarning)
            return self.spilled

        def __repr__(self) -> str:
            rows = (
                ("Process memory (RSS)  : ", self.process),
                ("  - managed by Dask   : ", self.managed),
                ("  - unmanaged (old)   : ", self.unmanaged_old),
                ("  - unmanaged (recent): ", self.unmanaged_recent),
                ("Spilled to disk       : ", self.spilled),
            )
            return "".join(f"{label}{format_bytes(nbytes)}\n" for label, nbytes in rows)

        def _to_dict(self, *, exclude: Container[str] = ()) -> dict:
            """Dictionary representation for debugging purposes.

            Contains every public attribute and property, except the ``sum`` helper
            and the deprecated aliases (which would emit warnings).

            See also
            --------
            Client.dump_cluster_state
            distributed.utils.recursive_to_dict
            """
            hidden = {"sum", "managed_in_memory", "managed_spilled"}
            public_names = (name for name in dir(self) if not name.startswith("_"))
            return {name: getattr(self, name) for name in public_names if name not in hidden}
    
    
    class WorkerState:
        """A simple object holding information about a worker.
    
        Not to be confused with :class:`distributed.worker_state_machine.WorkerState`.
        """
    
        #: This worker's unique key. This can be its connected address
        #: (such as ``"tcp://127.0.0.1:8891"``) or an alias (such as ``"alice"``).
        address: str
    
        pid: int
        name: Hashable
    
        #: The number of CPU threads made available on this worker
        nthreads: int
    
        #: Memory available to the worker, in bytes
        memory_limit: int
    
        local_directory: str
        services: dict[str, int]
    
        #: Output of :meth:`distributed.versions.get_versions` on the worker
        versions: dict[str, Any]
    
        #: Address of the associated :class:`~distributed.nanny.Nanny`, if present
        nanny: str | None
    
        #: Read-only worker status, synced one way from the remote Worker object
        status: Status
    
        #: Cached hash of :attr:`~WorkerState.server_id`
        _hash: int
    
        #: The total memory size, in bytes, used by the tasks this worker holds in memory
        #: (i.e. the tasks in this worker's :attr:`~WorkerState.has_what`).
        nbytes: int
    
        #: Worker memory unknown to the worker, in bytes, which has been there for more than
        #: 30 seconds. See :class:`MemoryState`.
        _memory_unmanaged_old: int
    
        #: History of the last 30 seconds' worth of unmanaged memory. Used to differentiate
        #: between "old" and "new" unmanaged memory.
        #: Format: ``[(timestamp, bytes), (timestamp, bytes), ...]``
        _memory_unmanaged_history: deque[tuple[float, int]]
    
        metrics: dict[str, Any]
    
        #: The last time we received a heartbeat from this worker, in local scheduler time.
        last_seen: float
    
        time_delay: float
        bandwidth: float
    
        #: A set of all TaskStates on this worker that are actors. This only includes those
        #: actors whose state actually lives on this worker, not actors to which this worker
        #: has a reference.
        actors: set[TaskState]
    
        #: Underlying data of :meth:`WorkerState.has_what`
        _has_what: dict[TaskState, None]
    
        #: A set of tasks that have been submitted to this worker. Multiple tasks may be
        #: submitted to a worker in advance and the worker will run them eventually,
        #: depending on its execution resources (but see :doc:`work-stealing`).
        #:
        #: All the tasks here are in the "processing" state.
        #: This attribute is kept in sync with :attr:`TaskState.processing_on`.
        processing: set[TaskState]
    
        #: Running tasks that invoked :func:`distributed.secede`
        long_running: set[TaskState]
    
        #: A dictionary of tasks that are currently being run on this worker.
        #: Each task state is associated with the duration in seconds which the task has
        #: been running.
        executing: dict[TaskState, float]
    
        #: The available resources on this worker, e.g. ``{"GPU": 2}``.
        #: These are abstract quantities that constrain certain tasks from running at the
        #: same time on this worker.
        resources: dict[str, float]
    
        #: The sum of each resource used by all tasks allocated to this worker.
        #: The numbers in this dictionary can only be less or equal than those in this
        #: worker's :attr:`~WorkerState.resources`.
        used_resources: dict[str, float]
    
        #: Arbitrary additional metadata to be added to :meth:`~WorkerState.identity`
        extra: dict[str, Any]
    
        # The unique server ID this WorkerState is referencing
        server_id: str
    
        # Weak reference to the owning SchedulerState (None when no scheduler was given)
        scheduler_ref: weakref.ref[SchedulerState] | None
        task_prefix_count: defaultdict[str, int]
        _network_occ: float
        _occupancy_cache: float | None
    
        #: Keys that may need to be fetched to this worker, and the number of tasks that need them.
        #: All tasks are currently in `memory` on a worker other than this one.
        #: Much like `processing`, this does not exactly reflect worker state:
        #: keys here may be queued to fetch, in flight, or already in memory
        #: on the worker.
        needs_what: dict[TaskState, int]
    
        __slots__ = tuple(__annotations__)
    
        def __init__(
            self,
            *,
            address: str,
            status: Status,
            pid: int,
            name: object,
            nthreads: int = 0,
            memory_limit: int,
            local_directory: str,
            nanny: str | None,
            server_id: str,
            services: dict[str, int] | None = None,
            versions: dict[str, Any] | None = None,
            extra: dict[str, Any] | None = None,
            scheduler: SchedulerState | None = None,
        ):
            # Identity: hashing is keyed on server_id, so cache its hash right away.
            self.server_id = server_id
            self._hash = hash(server_id)
            self.address = address
            self.pid = pid
            self.name = name
            self.nanny = nanny
            self.status = status

            # Static configuration reported by the worker; missing/empty mappings
            # become fresh empty dicts.
            self.nthreads = nthreads
            self.memory_limit = memory_limit
            self.local_directory = local_directory
            self.services = services if services else {}
            self.versions = versions if versions else {}
            self.extra = extra if extra else {}
            self.resources = {}
            self.used_resources = {}

            # Memory, heartbeat and network bookkeeping
            self.nbytes = 0
            self._memory_unmanaged_old = 0
            self._memory_unmanaged_history = deque()
            self.metrics = {}
            self.last_seen = time()
            self.time_delay = 0
            self.bandwidth = parse_bytes(dask.config.get("distributed.scheduler.bandwidth"))

            # Task tracking, all initially empty
            self.actors = set()
            self._has_what = {}
            self.processing = set()
            self.long_running = set()
            self.executing = {}
            self.needs_what = {}
            self.task_prefix_count = defaultdict(int)
            self._network_occ = 0
            self._occupancy_cache = None

            # Only a weak reference, so this object does not keep the scheduler alive.
            self.scheduler_ref = None if not scheduler else weakref.ref(scheduler)
    
        def __hash__(self) -> int:
            # Hash of server_id, cached in __init__; consistent with __eq__, which
            # compares server_id.
            return self._hash
    
        def __eq__(self, other: object) -> bool:
            """Equal to itself, and to any other WorkerState with the same server_id."""
            if self is other:
                return True
            if not isinstance(other, WorkerState):
                return False
            return other.server_id == self.server_id
    
        @property
        def has_what(self) -> Set[TaskState]:
            """An insertion-sorted set-like of tasks which currently reside on this worker.
            All the tasks here are in the "memory" state.
            This is the reverse mapping of :attr:`TaskState.who_has`.

            This is a read-only public accessor. The data is implemented as a dict without
            values, because rebalance() relies on dicts being insertion-sorted.
            """
            # A live keys() view: it reflects later changes to _has_what without copying.
            return self._has_what.keys()
    
        @property
        def host(self) -> str:
            """Host part of :attr:`address`, as extracted by ``get_address_host``."""
            return get_address_host(self.address)
    
        @property
        def memory(self) -> MemoryState:
            """Polished memory metrics for the worker.

            **Design note on managed memory**

            Managed memory can be read from two places: ``self.nbytes`` and
            ``self.metrics["managed_bytes"]``. The two agree at rest, but ``self.nbytes``
            is updated through the batched comms as soon as a task lands in memory on
            the worker, whereas ``self.metrics`` only changes with the heartbeat, which
            can lag several seconds behind.

            This property deliberately combines the fresher managed figure from
            ``self.nbytes`` with process and spilled memory taken from the heartbeat.
            Fresh managed memory matters because it feeds directly into optimistic
            memory, which Active Memory Manager heuristics rely on (and more uses will
            likely follow). It also makes unit tests much simpler, since the scheduler
            sees managed memory change as soon as it learns a task completed.

            The trade-off is occasional noise:

            - Deleting 100MB of in-memory data lowers managed memory before the
              heartbeat lowers process memory, so unmanaged_recent briefly jumps by
              100MB and then falls again when the heartbeat arrives. This is
              acceptable: the unmanaged_recent / unmanaged_old split exists exactly to
              concentrate such noise in unmanaged_recent and keep it out of optimistic
              memory.
            - Less acceptable, but also rarer: deleting 100MB of *spilled* data lowers
              the managed total before the heartbeat lowers the spilled figure. Managed
              memory briefly plummets and is replaced by unmanaged_recent while spilled
              stays unchanged. The next heartbeat restores managed, lowers
              unmanaged_recent, and drops spilled by 100MB as it should have to begin
              with.

            :issue:`6002` will let us solve this.
            """
            heartbeat = self.metrics
            process = heartbeat["memory"]
            spilled_bytes = heartbeat["spilled_bytes"]
            # Remove the spilled keys' in-memory size from nbytes (presumably what
            # spilled_bytes["memory"] measures — verify), never going below zero.
            in_ram = max(0, self.nbytes - spilled_bytes["memory"])
            return MemoryState(
                process=process,
                managed=in_ram,
                spilled=spilled_bytes["disk"],
                unmanaged_old=self._memory_unmanaged_old,
            )
    
        def clean(self) -> WorkerState:
            """Return a version of this object that is appropriate for serialization.

            The copy is built from this worker's identity and configuration only;
            ``versions`` and the scheduler reference are not passed on. The current
            occupancy is frozen into the copy's cache, and ``executing`` is re-keyed
            by task key instead of by TaskState.
            """
            snapshot = WorkerState(
                address=self.address,
                status=self.status,
                pid=self.pid,
                name=self.name,
                nthreads=self.nthreads,
                memory_limit=self.memory_limit,
                local_directory=self.local_directory,
                services=self.services,
                nanny=self.nanny,
                extra=self.extra,
                server_id=self.server_id,
            )
            snapshot._occupancy_cache = self.occupancy

            # Replace TaskState keys by plain task keys so the mapping serializes cleanly
            snapshot.executing = {ts.key: secs for ts, secs in self.executing.items()}  # type: ignore
            return snapshot
    
        def __repr__(self) -> str:
            # The name is only shown when it differs from the address.
            parts = [f"<WorkerState {self.address!r}"]
            if self.name != self.address:
                parts.append(f"name: {self.name}")
            parts.append(f"status: {self.status.name}")
            parts.append(f"memory: {len(self.has_what)}")
            parts.append(f"processing: {len(self.processing)}>")
            return ", ".join(parts)
    
        def _repr_html_(self) -> str:
            """HTML rendering (e.g. for Jupyter) using the worker_state template."""
            context = {
                "address": self.address,
                "name": self.name,
                "status": self.status.name,
                "has_what": self.has_what,
                "processing": self.processing,
            }
            template = get_template("worker_state.html.j2")
            return template.render(**context)
    
        def identity(self) -> dict[str, Any]:
            """Describe this worker as a plain dict.

            Entries from :attr:`extra` are merged in last, so they override the
            built-in fields on key collisions.
            """
            info: dict[str, Any] = {
                "type": "Worker",
                "id": self.name,
                "host": self.host,
                "resources": self.resources,
                "local_directory": self.local_directory,
                "name": self.name,
                "nthreads": self.nthreads,
                "memory_limit": self.memory_limit,
                "last_seen": self.last_seen,
                "services": self.services,
                "metrics": self.metrics,
                "status": self.status.name,
                "nanny": self.nanny,
            }
            info.update(self.extra)
            return info
    
        def _to_dict_no_nest(self, *, exclude: Container[str] = ()) -> dict[str, Any]:
            """Debugging-only dict view of this worker (not type stable, no roundtrip).

            ``versions`` is always omitted, in addition to anything in ``exclude``.

            See also
            --------
            Client.dump_cluster_state
            distributed.utils.recursive_to_dict
            TaskState._to_dict
            """
            skipped = {"versions", *exclude}  # type: ignore
            return recursive_to_dict(self, exclude=skipped, members=True)
    
        @property
        def scheduler(self) -> SchedulerState:
            """The owning SchedulerState, dereferenced from :attr:`scheduler_ref`.

            Fails an assertion if no scheduler was attached, or if the weak reference
            no longer resolves to a live object.
            """
            ref = self.scheduler_ref
            assert ref
            state = ref()
            assert state
            return state
    
        def add_to_processing(self, ts: TaskState) -> None:
            """Assign a task to this worker for compute.

            Increments the per-worker and scheduler-wide counters for the task's
            prefix and adds the task to :attr:`processing`. It also registers a
            replica need for every dependency that this worker does not hold yet.
            """
            scheduler = self.scheduler
            if scheduler.validate:
                assert ts not in self.processing

            prefix_name = ts.prefix.name
            self.task_prefix_count[prefix_name] += 1
            scheduler._task_prefix_count_global[prefix_name] += 1
            self.processing.add(ts)

            for dep in ts.dependencies:
                # Every dependency must already be held by at least one worker
                assert dep.who_has
                if self not in dep.who_has:
                    self._inc_needs_replica(dep)
    
        def add_to_long_running(self, ts: TaskState) -> None:
            """Mark a task that is processing on this worker as long-running.

            The task's prefix is removed from the task-prefix counters and the task is
            added to :attr:`long_running`, but it stays in :attr:`processing`.
            Presumably called when a task invokes :func:`distributed.secede` (see the
            ``long_running`` attribute comment) — verify at the call site.
            """
            if self.scheduler.validate:
                assert ts in self.processing
                assert ts not in self.long_running

            self._remove_from_task_prefix_count(ts)
            # Cannot remove from processing since we're using this for things like
            # idleness detection. Idle workers are typically targeted for
            # downscaling but we should not downscale workers with long running
            # tasks
            self.long_running.add(ts)
    
        def remove_from_processing(self, ts: TaskState) -> None:
            """Remove a task from this worker's processing set.

            Long-running tasks already had their prefix counters decremented by
            :meth:`add_to_long_running`, so for them only the ``long_running`` entry is
            dropped; other tasks decrement the counters here. Replica needs registered
            for the task's dependencies are released as well.
            """
            if self.scheduler.validate:
                assert ts in self.processing

            if ts not in self.long_running:
                self._remove_from_task_prefix_count(ts)
            else:
                self.long_running.discard(ts)
            self.processing.remove(ts)

            for dep in ts.dependencies:
                if dep in self.needs_what:
                    self._dec_needs_replica(dep)
    
        def _remove_from_task_prefix_count(self, ts: TaskState) -> None:
            """Decrement the counters for ``ts``'s prefix, on this worker and scheduler-wide.

            A counter that would reach zero is deleted from its dict rather than
            stored as 0.
            """
            prefix_name = ts.prefix.name
            # task_prefix_count is a defaultdict(int) (see __init__): a missing prefix
            # reads as 0 here, so count becomes -1 and is stored below instead of
            # raising.
            count = self.task_prefix_count[prefix_name] - 1
            tp_count = self.task_prefix_count
            # NOTE(review): the container type of _task_prefix_count_global is not
            # visible here; if it is a plain dict, a missing prefix raises KeyError below.
            tp_count_global = self.scheduler._task_prefix_count_global
            if count:
                tp_count[prefix_name] = count
            else:
                del tp_count[prefix_name]

            count = tp_count_global[prefix_name] - 1
            if count:
                tp_count_global[prefix_name] = count
            else:
                del tp_count_global[prefix_name]
    
        def remove_replica(self, ts: TaskState) -> None:
            """The worker no longer has a task in memory.

            Subtracts the task's size from :attr:`nbytes`, drops it from
            :attr:`has_what` and removes this worker from ``ts.who_has``. An emptied
            ``who_has`` is normalized to ``None``.
            """
            if self.scheduler.validate:
                assert ts.who_has
                assert self in ts.who_has
                assert ts in self.has_what
                assert ts not in self.needs_what

            self.nbytes -= ts.get_nbytes()
            del self._has_what[ts]
            ts.who_has.remove(self)  # type: ignore
            # An empty holder set is represented as None rather than set()
            if not ts.who_has:
                ts.who_has = None
    
        def _inc_needs_replica(self, ts: TaskState) -> None:
            """Assign a task fetch to this worker and update network occupancies"""
            if self.scheduler.validate:
                assert ts.who_has
                assert self not in ts.who_has
                assert ts not in self.has_what
            if ts not in self.needs_what:
                self.needs_what[ts] = 1
                nbyte…              # See definition of recipients above
                        heapq.heapreplace(
                            recipients,
                            (rec_bytes_max, rec_bytes_min, id(rec_ws), rec_ws),
                        )
                    else:
                        heapq.heappop(recipients)
    
                    # Move to next sender with the most data to lose.
                    # It may or may not be the same sender again.
                    break
    
                else:  # for ts in ts_iter
                    # Exhausted tasks on this sender
                    heapq.heappop(senders)
    
            return msgs
    
        async def _rebalance_move_data(
            self, msgs: list[tuple[WorkerState, WorkerState, TaskState]], stimulus_id: str
        ) -> dict:
            """Execute the transfers planned by ``_rebalance_find_msgs()``.

            ``msgs`` is a list of ``(sender, recipient, task)`` tuples. Each
            recipient is asked to gather its keys from the listed senders. Each
            sender is then asked to delete the keys whose gather did not fail.
            Events are logged per recipient and under ``"all"``.

            Returns ``{"status": "OK"}`` when every key arrived. Otherwise
            returns ``{"status": "partial-fail", "keys": [...]}`` with the keys
            that some recipient failed to gather.

            FIXME this method is not robust when the cluster is not idle.
            """
            # {recipient address: {key: [sender address, ...]}}
            to_recipients: defaultdict[str, defaultdict[Key, list[str]]] = defaultdict(
                lambda: defaultdict(list)
            )
            for sender, recipient, ts in msgs:
                to_recipients[recipient.address][ts.key].append(sender.address)

            # gather_on_worker never raises; each call yields the keys that
            # recipient failed to fetch. gather() keeps the input order, so the
            # results line up with recipient_addrs.
            recipient_addrs = list(to_recipients)
            gather_results = await asyncio.gather(
                *(
                    self.gather_on_worker(addr, to_recipients[addr])
                    for addr in recipient_addrs
                )
            )
            failed_keys_by_recipient = dict(zip(recipient_addrs, gather_results))

            to_senders = defaultdict(list)
            for sender, recipient, ts in msgs:
                if ts.key in failed_keys_by_recipient[recipient.address]:
                    continue
                to_senders[sender.address].append(ts.key)

            # delete_worker_data never raises either
            await asyncio.gather(
                *(
                    self.delete_worker_data(addr, keys, stimulus_id)
                    for addr, keys in to_senders.items()
                )
            )

            for addr, who_has in to_recipients.items():
                self.log_event(addr, {"action": "rebalance", "who_has": who_has})
            self.log_event(
                "all",
                {
                    "action": "rebalance",
                    "senders": {addr: len(keys) for addr, keys in to_senders.items()},
                    "recipients": {
                        addr: len(who_has) for addr, who_has in to_recipients.items()
                    },
                    "moved_keys": len(msgs),
                },
            )

            missing_keys = set()
            for failed in failed_keys_by_recipient.values():
                missing_keys.update(failed)
            if not missing_keys:
                return {"status": "OK"}
            return {"status": "partial-fail", "keys": list(missing_keys)}
    
        async def replicate(
            self,
            comm=None,
            keys=None,
            n=None,
            workers=None,
            branching_factor=2,
            delete=True,
            stimulus_id=None,
        ):
            """Replicate data throughout cluster

            This performs a tree copy of the data throughout the network
            individually on each piece of data.

            Parameters
            ----------
            comm:
                Not used in this body.
            keys: Iterable
                list of keys to replicate. Every key must be present in
                ``self.tasks``, and the object must support ``len()``, which is
                called when logging.
            n: int
                Number of replications we expect to see within the cluster.
                Defaults to the number of target workers and is capped at it.
            workers: optional
                Workers to replicate onto, resolved via ``self.workers_list``.
                Only workers with ``Status.running`` are kept. Defaults to
                ``self.running``.
            branching_factor: int, optional
                The number of workers that can copy data in each generation.
                The larger the branching factor, the more data we copy in
                a single step, but the more a given worker risks being
                swamped by data requests.
            delete: bool
                If True, first delete replicas in excess of ``n`` from the
                target workers.
            stimulus_id: str, optional
                Defaults to ``"replicate-<timestamp>"``.

            Returns
            -------
            ``{"status": "partial-fail", "keys": [...]}`` if some requested keys
            have no replica anywhere; otherwise nothing (``None``).

            Raises
            ------
            ValueError
                If ``n`` evaluates to 0 (e.g. there are no running target workers).

            See also
            --------
            Scheduler.rebalance
            """
            stimulus_id = stimulus_id or f"replicate-{time()}"
            assert branching_factor > 0
            # Downgrade reentrant lock to non-reentrant
            # (the fresh object() makes the lock token unique to this call)
            async with self._replica_lock(("replicate", object())):
                if workers is not None:
                    workers = {self.workers[w] for w in self.workers_list(workers)}
                    workers = {ws for ws in workers if ws.status == Status.running}
                else:
                    workers = self.running

                if n is None:
                    n = len(workers)
                else:
                    n = min(n, len(workers))
                if n == 0:
                    raise ValueError("Can not use replicate to delete data")

                tasks = {self.tasks[k] for k in keys}
                # Keys with no replica at all cannot be copied anywhere
                missing_data = [ts.key for ts in tasks if not ts.who_has]
                if missing_data:
                    return {"status": "partial-fail", "keys": missing_data}

                # Delete extraneous data
                if delete:
                    # {worker: tasks to delete from it}; replicas chosen at random
                    del_worker_tasks = defaultdict(set)
                    for ts in tasks:
                        del_candidates = tuple(ts.who_has & workers)
                        if len(del_candidates) > n:
                            for ws in random.sample(
                                del_candidates, len(del_candidates) - n
                            ):
                                del_worker_tasks[ws].add(ts)

                    # Note: this never raises exceptions
                    await asyncio.gather(
                        *[
                            # NOTE: ``tasks`` here is the comprehension variable
                            # (one worker's set), not the outer ``tasks`` set.
                            self.delete_worker_data(
                                ws.address, [t.key for t in tasks], stimulus_id
                            )
                            for ws, tasks in del_worker_tasks.items()
                        ]
                    )

                # Copy not-yet-filled data
                # Each round, every under-replicated task gets at most
                # branching_factor new copies per current holder, capped at its
                # shortfall. A task leaves ``tasks`` only once it is forgotten or
                # has n replicas among ``workers``.
                # NOTE(review): if gathers keep failing, this loop keeps retrying.
                while tasks:
                    # {recipient address: {key: [holder addresses]}}
                    gathers = defaultdict(dict)
                    for ts in list(tasks):
                        if ts.state == "forgotten":
                            # task is no longer needed by any client or dependent task
                            tasks.remove(ts)
                            continue
                        assert ts.who_has is not None
                        n_missing = n - len(ts.who_has & workers)
                        if n_missing <= 0:
                            # Already replicated enough
                            tasks.remove(ts)
                            continue

                        count = min(n_missing, branching_factor * len(ts.who_has))
                        assert count > 0

                        for ws in random.sample(tuple(workers - ts.who_has), count):
                            gathers[ws.address][ts.key] = [
                                wws.address for wws in ts.who_has
                            ]

                    await asyncio.gather(
                        *(
                            # Note: this never raises exceptions
                            self.gather_on_worker(w, who_has)
                            for w, who_has in gathers.items()
                        )
                    )
                    for r, v in gathers.items():
                        self.log_event(r, {"action": "replicate-add", "who_has": v})

                self.log_event(
                    "all",
                    {
                        "action": "replicate",
                        "workers": list(workers),
                        "key-count": len(keys),
                        "branching-factor": branching_factor,
                    },
                )
    
        @log_errors
        def workers_to_close(
            self,
            memory_ratio: int | float | None = None,
            n: int | None = None,
            key: Callable[[WorkerState], Hashable] | bytes | None = None,
            minimum: int | None = None,
            target: int | None = None,
            attribute: str = "address",
        ) -> list[str]:
            """
            Find workers that we can close with low cost

            This returns a list of workers that are good candidates to retire.
            These workers are not running anything and are storing
            relatively little data relative to their peers.  If all workers are
            idle then we still maintain enough workers to have enough RAM to store
            our data, with a comfortable buffer.

            This is for use with systems like ``distributed.deploy.adaptive``.

            Parameters
            ----------
            memory_ratio : Number
                Amount of extra space we want to have for our stored data.
                Defaults to 2 when ``n`` is not given (directly or via
                ``target``), meaning we want to have twice as much memory as we
                currently have data.
            n : int
                Number of workers to close. Negative values are treated as 0.
            minimum : int
                Minimum number of workers to keep around
            key : Callable(WorkerState)
                An optional callable mapping a WorkerState object to a group
                affiliation. Groups will be closed together. This is useful when
                closing workers must be done collectively, such as by hostname.
                May also be given as a pickled callable (``bytes``).
            target : int
                Target number of workers to have after we close. Used to derive
                ``n`` only when ``n`` is not given.
            attribute : str
                The attribute of the WorkerState object to return, like "address"
                or "name".  Defaults to "address".

            Examples
            --------
            >>> scheduler.workers_to_close()
            ['tcp://192.168.0.1:1234', 'tcp://192.168.0.2:1234']

            Group workers by hostname prior to closing

            >>> scheduler.workers_to_close(key=lambda ws: ws.host)
            ['tcp://192.168.0.1:1234', 'tcp://192.168.0.1:4567']

            Remove two workers

            >>> scheduler.workers_to_close(n=2)

            Keep enough workers to have twice as much memory as we need.

            >>> scheduler.workers_to_close(memory_ratio=2)

            Returns
            -------
            to_close: list of the ``attribute`` value (the address by default)
                of each worker that is OK to close

            See Also
            --------
            Scheduler.retire_workers
            """
            # Derive n from target when only target is given; when n is known,
            # clamp it at 0 and recompute target from it.
            if target is not None and n is None:
                n = len(self.workers) - target
            if n is not None:
                if n < 0:
                    n = 0
                target = len(self.workers) - n

            if n is None and memory_ratio is None:
                memory_ratio = 2

            # Nothing to suggest when no count was requested and every worker is busy
            if not n and all([ws.processing for ws in self.workers.values()]):
                return []

            if key is None:
                key = operator.attrgetter("address")
            if isinstance(key, bytes):
                # NOTE(review): unpickling caller-supplied bytes executes arbitrary
                # code; only safe if the caller is trusted.
                key = pickle.loads(key)

            # Long running tasks typically use a worker_client to schedule
            # other tasks. We should never shut down the worker they're
            # running on, as it would cause them to restart from scratch
            # somewhere else.
            valid_workers = [ws for ws in self.workers.values() if not ws.long_running]
            for plugin in list(self.plugins.values()):
                valid_workers = plugin.valid_workers_downscaling(self, valid_workers)

            groups = groupby(key, valid_workers)

            limit_bytes = {k: sum(ws.memory_limit for ws in v) for k, v in groups.items()}
            group_bytes = {k: sum(ws.nbytes for ws in v) for k, v in groups.items()}

            limit = sum(limit_bytes.values())
            total = sum(group_bytes.values())

            # Ascending sort on (is_idle, -stored bytes): idle.pop() below yields
            # idle groups first and, among them, the one storing the least data.
            def _key(group):
                is_idle = not any([wws.processing for wws in groups[group]])
                bytes = -group_bytes[group]
                return is_idle, bytes

            idle = sorted(groups, key=_key)

            to_close = []
            # Counts all workers, including those excluded from valid_workers
            n_remain = len(self.workers)

            while idle:
                group = idle.pop()
                # Without an explicit count, stop at the first busy group
                if n is None and any([ws.processing for ws in groups[group]]):
                    break

                if minimum and n_remain - len(groups[group]) < minimum:
                    break

                # Memory that would remain if this group were closed
                limit -= limit_bytes[group]

                if (n is not None and n_remain - len(groups[group]) >= (target or 0)) or (
                    memory_ratio is not None and limit >= memory_ratio * total
                ):
                    to_close.append(group)
                    n_remain -= len(groups[group])

                else:
                    break

            result = [getattr(ws, attribute) for g in to_close for ws in groups[g]]
            if result:
                logger.debug("Suggest closing workers: %s", result)

            return result
    
        # Typing overload: retire workers given by address.
        @overload
        async def retire_workers(
            self,
            workers: list[str],
            *,
            close_workers: bool = False,
            remove: bool = True,
            stimulus_id: str | None = None,
        ) -> list[str]: ...

        # Typing overload: retire workers given by name.
        @overload
        async def retire_workers(
            self,
            *,
            names: list,
            close_workers: bool = False,
            remove: bool = True,
            stimulus_id: str | None = None,
        ) -> list[str]: ...

        # Typing overload: let workers_to_close() choose the workers, driven by
        # the keyword arguments below.
        @overload
        async def retire_workers(
            self,
            *,
            close_workers: bool = False,
            remove: bool = True,
            stimulus_id: str | None = None,
            # Parameters for workers_to_close()
            memory_ratio: int | float | None = None,
            n: int | None = None,
            key: Callable[[WorkerState], Hashable] | bytes | None = None,
            minimum: int | None = None,
            target: int | None = None,
            attribute: str = "address",
        ) -> list[str]: ...
    
        @log_errors
        async def retire_workers(
            self,
            workers: list[str] | None = None,
            *,
            names: list | None = None,
            close_workers: bool = False,
            remove: bool = True,
            stimulus_id: str | None = None,
            **kwargs: Any,
        ) -> list[str]:
            """Gracefully retire workers from cluster. Any key that is in memory exclusively
            on the retired workers is replicated somewhere else.

            Parameters
            ----------
            workers: list[str] (optional)
                List of worker addresses to retire. Addresses the scheduler does not
                know about are ignored.
            names: list (optional)
                List of worker names to retire. Names are compared as strings.
                Mutually exclusive with ``workers``.
                If neither ``workers`` nor ``names`` are provided, we call
                ``workers_to_close`` which finds a good set.
            close_workers: bool (defaults to False)
                Whether to actually close the worker explicitly from here.
                Otherwise, we expect some external job scheduler to finish off the worker.
            remove: bool (defaults to True)
                Whether to remove the worker metadata immediately or else wait for the
                worker to contact us.

                If close_workers=False and remove=False, this method just flushes the tasks
                in memory out of the workers and then returns.
                If close_workers=True and remove=False, this method will return while the
                workers are still in the cluster, although they won't accept new tasks.
                If close_workers=False or for whatever reason a worker doesn't accept the
                close command, it will be left permanently unable to accept new tasks and
                it is expected to be closed in some other way.
            stimulus_id: str (optional)
                Identifier attached to the status-change messages, log lines and events.
                Defaults to ``"retire-workers-<timestamp>"``.

            **kwargs: dict
                Extra options to pass to workers_to_close to determine which
                workers we should drop. Only accepted if ``workers`` and ``names`` are
                omitted.

            Returns
            -------
            List of the addresses of the workers that were successfully retired.
            An empty list is returned if no matching worker was found.

            If there are keys that exist in memory only on the workers being retired and it
            was impossible to replicate them somewhere else (e.g. because there aren't
            any other running workers), the workers holding such keys won't be retired and
            won't appear in the returned list; each of them is asked to revert to its
            previous status.

            Raises
            ------
            TypeError
                If both ``names`` and ``workers`` are given, or if workers_to_close()
                parameters are combined with either of them.

            See Also
            --------
            Scheduler.workers_to_close
            """
            if names is not None and workers is not None:
                raise TypeError("names and workers are mutually exclusive")
            if (names is not None or workers is not None) and kwargs:
                raise TypeError(
                    "Parameters for workers_to_close() are mutually exclusive with "
                    f"names and workers: {kwargs}"
                )

            stimulus_id = stimulus_id or f"retire-workers-{time()}"
            # This lock makes retire_workers, rebalance, and replicate mutually
            # exclusive and will no longer be necessary once rebalance and replicate are
            # migrated to the Active Memory Manager.
            # However, it allows multiple instances of retire_workers to run in parallel.
            async with self._replica_lock("retire-workers"):
                if names is not None:
                    logger.info("Retire worker names %s", names)
                    # Support cases where names are passed through a CLI and become strings
                    names_set = {str(name) for name in names}
                    wss = {ws for ws in self.workers.values() if str(ws.name) in names_set}
                elif workers is not None:
                    logger.info(
                        "Retire worker addresses (stimulus_id='%s') %s",
                        stimulus_id,
                        workers,
                    )
                    wss = {
                        self.workers[address]
                        for address in workers
                        if address in self.workers
                    }
                else:
                    wss = {
                        self.workers[address] for address in self.workers_to_close(**kwargs)
                    }
                if not wss:
                    return []

                # Without a running AMM, spin up a temporary one (no policies of its
                # own) just to drive the RetireWorker policies; it is stopped below.
                stop_amm = False
                amm: ActiveMemoryManagerExtension | None = self.extensions.get("amm")
                if not amm or not amm.running:
                    amm = ActiveMemoryManagerExtension(
                        self, policies=set(), register=False, start=True, interval=2.0
                    )
                    stop_amm = True

                try:
                    coros = []
                    for ws in wss:
                        policy = RetireWorker(ws.address)
                        amm.add_policy(policy)

                        # Change Worker.status to closing_gracefully. Immediately set
                        # the same on the scheduler to prevent race conditions.
                        prev_status = ws.status
                        self.handle_worker_status_change(
                            Status.closing_gracefully, ws, stimulus_id
                        )
                        # FIXME: We should send a message to the nanny first;
                        # eventually workers won't be able to close their own nannies.
                        self.stream_comms[ws.address].send(
                            {
                                "op": "worker-status-change",
                                "status": ws.status.name,
                                "stimulus_id": stimulus_id,
                            }
                        )

                        coros.append(
                            self._track_retire_worker(
                                ws,
                                policy,
                                prev_status=prev_status,
                                close=close_workers,
                                remove=remove,
                                stimulus_id=stimulus_id,
                            )
                        )

                    # Give the AMM a kick, in addition to its periodic running. This is
                    # to avoid unnecessarily waiting for a potentially arbitrarily long
                    # time (depending on interval settings)
                    amm.run_once()

                    # Each tracker resolves to (address, "OK" | "no-recipients")
                    workers_info_ok = []
                    workers_info_abort = []
                    for addr, result in await asyncio.gather(*coros):
                        if result == "OK":
                            workers_info_ok.append(addr)
                        else:
                            workers_info_abort.append(addr)

                finally:
                    if stop_amm:
                        amm.stop()

            self.log_event(
                "all",
                {
                    "action": "retire-workers",
                    "retired": list(workers_info_ok),
                    "could-not-retire": list(workers_info_abort),
                    "stimulus_id": stimulus_id,
                },
            )
            self.log_event(
                list(workers_info_ok),
                {"action": "retired", "stimulus_id": stimulus_id},
            )
            self.log_event(
                list(workers_info_abort),
                {"action": "could-not-retire", "stimulus_id": stimulus_id},
            )

            return workers_info_ok
    
        async def _track_retire_worker(
            self,
            ws: WorkerState,
            policy: RetireWorker,
            prev_status: Status,
            close: bool,
            remove: bool,
            stimulus_id: str,
        ) -> tuple[str, Literal["OK", "no-recipients"]]:
            """Wait for ``policy`` to finish with ``ws``, then complete or abort.

            If the policy could not place the worker's unique data elsewhere,
            the worker is asked to go back to ``prev_status`` and
            ``(ws.address, "no-recipients")`` is returned. Otherwise the worker
            is removed via ``remove_worker`` (when ``remove``) or closed via
            ``close_worker`` (when only ``close``), and ``(ws.address, "OK")``
            is returned.
            """
            while not policy.done():
                # Poll interval grows linearly with the keys still held:
                # 0.01s at 4 keys or fewer, capped at 0.5s from 200 keys upward.
                await asyncio.sleep(max(0.01, min(0.5, len(ws.has_what) / 400)))

            if not policy.no_recipients:
                logger.debug(
                    f"All unique keys on worker {ws.address!r} have been replicated elsewhere"
                )
                if remove:
                    await self.remove_worker(
                        ws.address, expected=True, close=close, stimulus_id=stimulus_id
                    )
                elif close:
                    self.close_worker(ws.address)
                logger.info(f"Retired worker {ws.address!r} ({stimulus_id=!r})")
                return ws.address, "OK"

            # Abort retirement. Unlike the initial status change, there is no race
            # to guard against here, so a scheduler->worker->scheduler round-trip
            # is acceptable.
            self.stream_comms[ws.address].send(
                {
                    "op": "worker-status-change",
                    "status": prev_status.name,
                    "stimulus_id": stimulus_id,
                }
            )
            logger.warning(
                f"Could not retire worker {ws.address!r}: unique data could not be "
                f"moved to any other worker ({stimulus_id=!r})"
            )
            return ws.address, "no-recipients"
    
        def add_keys(
            self, worker: str, keys: Collection[Key] = (), stimulus_id: str | None = None
        ) -> Literal["OK", "not found"]:
            """
            Learn that a worker has certain keys

            This should not be used in practice and is mostly here for legacy
            reasons.  However, it is sent by workers from time to time.

            Keys whose task is known and in the ``"memory"`` state are recorded as
            replicas on ``worker``. All other keys are sent back to the worker in a
            single ``"remove-replicas"`` message. Returns ``"not found"`` for an
            unknown worker and ``"OK"`` otherwise.
            """
            if worker not in self.workers:
                return "not found"
            ws = self.workers[worker]

            unwanted = []
            for key in keys:
                ts = self.tasks.get(key)
                if ts is None or ts.state != "memory":
                    unwanted.append(key)
                    continue
                self.add_replica(ts, ws)

            if not unwanted:
                return "OK"

            self.worker_send(
                worker,
                {
                    "op": "remove-replicas",
                    "keys": unwanted,
                    "stimulus_id": stimulus_id or f"redundant-replicas-{time()}",
                },
            )
            return "OK"
    
        @log_errors
        def update_data(
            self,
            *,
            who_has: dict[Key, list[str]],
            nbytes: dict[Key, int],
            client: str | None = None,
        ) -> None:
            """Learn that new data has entered the network from an external source

            Each key in ``who_has`` is marked as in memory (a new task is created
            for unknown keys) and registered as a replica on every listed worker,
            whose addresses are coerced first. A non-negative ``nbytes`` entry
            sets the task size. A ``"key-in-memory"`` report is emitted per key.
            If ``client`` is given, it is recorded as desiring all of the keys.
            """
            coerce = self.coerce_address
            who_has = {
                key: [coerce(addr) for addr in addrs] for key, addrs in who_has.items()
            }
            logger.debug("Update data %s", who_has)

            for key, addresses in who_has.items():
                ts = self.tasks.get(key)
                if ts is None:
                    ts = self.new_task(key, None, "memory")
                ts.state = "memory"

                size = nbytes.get(key, -1)
                if size >= 0:
                    ts.set_nbytes(size)

                for addr in addresses:
                    self.add_replica(ts, self.workers[addr])
                self.report(
                    {"op": "key-in-memory", "key": key, "workers": list(addresses)}
                )

            if client:
                self.client_desires_keys(keys=list(who_has), client=client)
    
        @overload
        def report_on_key(self, key: Key, *, client: str | None = None) -> None: ...

        @overload
        def report_on_key(self, *, ts: TaskState, client: str | None = None) -> None: ...

        def report_on_key(self, key=None, *, ts=None, client=None):
            """Report the current state of one task.

            Exactly one of ``key`` or ``ts`` must be passed. An unknown key is
            reported as ``"cancelled-keys"``. For a known task, the message comes
            from ``_task_to_report_msg``, and nothing is sent when that returns
            ``None``.
            """
            if (key is None) == (ts is None):
                raise ValueError(  # pragma: nocover
                    f"ts and key are mutually exclusive; received {key=!r}, {ts=!r}"
                )

            if ts is not None:
                key = ts.key
            else:
                assert key is not None
                ts = self.tasks.get(key)

            if ts is None:
                msg = {"op": "cancelled-keys", "keys": [key]}
            else:
                msg = _task_to_report_msg(ts)
                if msg is None:
                    return
            self.report(msg, ts=ts, client=client)
    
        @log_errors
        async def feed(
            self,
            comm: Comm,
            function: bytes | None = None,
            setup: bytes | None = None,
            teardown: bytes | None = None,
            interval: str | float = "1s",
            **kwargs: Any,
        ) -> None:
            """
            Provides a data Comm to external requester

            Caution: this runs arbitrary Python code on the scheduler.  This should
            eventually be phased out.  It is mostly used by diagnostics.

            ``function``, ``setup`` and ``teardown`` are pickled callables.

            - ``setup(scheduler)`` runs once. It may return an awaitable, which is
              awaited, and its result is kept as ``state``.
            - While the scheduler is running, ``function(scheduler)`` is called
              (or ``function(scheduler, state)`` when ``state`` is not None). Its
              result is written to ``comm`` every ``interval``.
            - When the loop ends, ``teardown(scheduler, state)`` runs. Like
              ``setup``, it may return an awaitable, which is awaited.

            An ``OSError`` from the comm ends the loop quietly.
            """

            interval = parse_timedelta(interval)
            # SECURITY: pickle.loads executes arbitrary code; these payloads must
            # only come from trusted clients.
            if function:
                function = pickle.loads(function)
            if setup:
                setup = pickle.loads(setup)

            if teardown:
                teardown = pickle.loads(teardown)
            state = setup(self) if setup else None  # type: ignore
            if inspect.isawaitable(state):
                state = await state
            try:
                while self.status == Status.running:
                    if state is None:
                        response = function(self)  # type: ignore
                    else:
                        response = function(self, state)  # type: ignore
                    await comm.write(response)
                    await asyncio.sleep(interval)
            except OSError:
                pass
            finally:
                if teardown:
                    result = teardown(self, state)  # type: ignore
                    # Mirror setup(): an async teardown would otherwise create a
                    # coroutine that is never run.
                    if inspect.isawaitable(result):
                        await result
    
        def log_worker_event(
            self, worker: str, topic: str | Collection[str], msg: Any
        ) -> None:
            """Log ``msg`` under ``topic`` on behalf of ``worker``.

            A dict message is tagged in place with ``"worker": worker``, unless
            ``topic`` is the worker itself. Any other message is logged as-is.
            """
            if not isinstance(msg, dict) or worker == topic:
                self.log_event(topic, msg)
                return
            msg["worker"] = worker
            self.log_event(topic, msg)
    
        def subscribe_worker_status(self, comm: Comm) -> dict[str, Any]:
            """Attach a ``WorkerStatusPlugin`` for ``comm`` and return the identity.

            The returned scheduler identity has the ``metrics`` and ``last_seen``
            entries stripped from every worker's info. The plugin instance is not
            kept here; presumably it registers itself with the scheduler.
            """
            WorkerStatusPlugin(self, comm)
            identity = self.identity()
            for worker_info in identity["workers"].values():
                for field in ("metrics", "last_seen"):
                    del worker_info[field]
            return identity
    
        def get_processing(
            self, workers: Iterable[str] | None = None
        ) -> dict[str, list[Key]]:
            """Map worker addresses to the keys each one is currently processing.

            When ``workers`` is given, only those workers are included. Their
            addresses are coerced and de-duplicated, and an unknown address raises
            ``KeyError``. Otherwise every known worker is included.
            """
            if workers is None:
                selected = self.workers.items()
            else:
                selected = (
                    (addr, self.workers[addr])
                    for addr in set(map(self.coerce_address, workers))
                )
            return {addr: [ts.key for ts in ws.processing] for addr, ws in selected}
    
        def get_who_has(self, keys: Iterable[Key] | None = None) -> dict[Key, list[str]]:
            """Map keys to the addresses of the workers holding them in memory.

            Unknown keys, and keys with no holder, map to an empty list. Without
            ``keys``, every task known to the scheduler is included.
            """

            def holders(ts) -> list[str]:
                return [ws.address for ws in ts.who_has or ()]

            if keys is None:
                return {key: holders(ts) for key, ts in self.tasks.items()}
            tasks = self.tasks
            return {key: holders(tasks[key]) if key in tasks else [] for key in keys}
    
        def get_has_what(
            self, workers: Iterable[str] | None = None
        ) -> dict[str, list[Key]]:
            """Map worker addresses to the keys each one holds in memory.

            When ``workers`` is given, each address is coerced first, and unknown
            workers map to an empty list. Otherwise every known worker is included.
            """
            if workers is None:
                return {
                    addr: [ts.key for ts in ws.has_what]
                    for addr, ws in self.workers.items()
                }

            known = self.workers
            result = {}
            for raw_addr in workers:
                addr = self.coerce_address(raw_addr)
                if addr in known:
                    result[addr] = [ts.key for ts in known[addr].has_what]
                else:
                    result[addr] = []
            return result
    
        def get_ncores(self, workers: Iterable[str] | None = None) -> dict[str, int]:
            """Map worker addresses to their thread counts.

            When ``workers`` is given, each address is coerced and unknown workers
            are skipped. Otherwise every known worker is included.
            """
            if workers is None:
                return {addr: ws.nthreads for addr, ws in self.workers.items()}

            result = {}
            for addr in map(self.coerce_address, workers):
                if addr not in self.workers:
                    continue
                result[addr] = self.workers[addr].nthreads
            return result
    
        def get_ncores_running(
            self, workers: Iterable[str] | None = None
        ) -> dict[str, int]:
            """Like ``get_ncores``, restricted to workers whose status is running."""
            running = {}
            for addr, nthreads in self.get_ncores(workers=workers).items():
                if self.workers[addr].status == Status.running:
                    running[addr] = nthreads
            return running
    
        async def get_call_stack(self, keys: Iterable[Key] | None = None) -> dict[str, Any]:
            workers: dict[str, list[Key] | None]
            if keys is not None:
                stack = list(keys)
                processing = set()
                while stack:
                    key = stack.pop()
>                   ts = self.tasks[key]
E                   KeyError: 'slowinc-1ea6f50d354d90f0a53d5f5250e2c25f'

distributed\scheduler.py:7903: KeyError

Check warning on line 0 in distributed.tests.test_client

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

11 out of 12 runs failed: test_release_persisted_collection (distributed.tests.test_client)

artifacts/macos-latest-3.12-default-ci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-ci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-ci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-ci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-ci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-ci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-ci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-ci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-ci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-ci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-ci1/pytest.xml [took 0s]
Raw output
IndexError: tuple index out of range
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:36527', workers: 0, cores: 0, tasks: 0>
a = <Worker 'tcp://127.0.0.1:46067', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>
b = <Worker 'tcp://127.0.0.1:38837', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>

    @gen_cluster(client=True)
    async def test_release_persisted_collection(c, s, a, b):
        np = pytest.importorskip("numpy")
        da = pytest.importorskip("dask.array")
    
        arr = c.persist(da.random.random((10,), chunks=(10,)))
    
        await wait(arr)
    
        _release_persisted(arr)
        while s.tasks:
            await asyncio.sleep(0.01)
    
        with pytest.raises(CancelledError):
>           await c.compute(arr)

distributed/tests/test_client.py:8235: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
distributed/client.py:410: in _result
    raise exc.with_traceback(tb)
../../../miniconda3/envs/dask-distributed/lib/python3.10/site-packages/dask/array/core.py:1328: in finalize
    return concatenate3(results)
../../../miniconda3/envs/dask-distributed/lib/python3.10/site-packages/dask/array/core.py:5496: in concatenate3
    chunks = chunks_from_arrays(arrays)
../../../miniconda3/envs/dask-distributed/lib/python3.10/site-packages/dask/array/core.py:5281: in chunks_from_arrays
    result.append(tuple(shape(deepfirst(a))[dim] for a in arrays))
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import contextlib
    import math
    import operator
    import os
    import pickle
    import re
    import sys
    import traceback
    import uuid
    import warnings
    from bisect import bisect
    from collections import defaultdict
    from collections.abc import Collection, Iterable, Iterator, Mapping, Sequence
    from functools import lru_cache, partial, reduce, wraps
    from itertools import product, zip_longest
    from numbers import Integral, Number
    from operator import add, mul
    from threading import Lock
    from typing import Any, Literal, TypeVar, Union, cast
    
    import numpy as np
    from numpy.typing import ArrayLike
    from packaging.version import Version
    from tlz import accumulate, concat, first, groupby, partition
    from tlz.curried import pluck
    from toolz import frequencies
    
    from dask import compute, config, core
    from dask.array import chunk
    from dask.array.chunk import getitem
    from dask.array.chunk_types import is_valid_array_chunk, is_valid_chunk_type
    
    # Keep einsum_lookup and tensordot_lookup here for backwards compatibility
    from dask.array.dispatch import (  # noqa: F401
        concatenate_lookup,
        einsum_lookup,
        tensordot_lookup,
    )
    from dask.array.numpy_compat import NUMPY_GE_200, _Recurser
    from dask.array.slicing import replace_ellipsis, setitem_array, slice_array
    from dask.array.utils import compute_meta, meta_from_array
    from dask.base import (
        DaskMethodsMixin,
        compute_as_if_collection,
        dont_optimize,
        is_dask_collection,
        named_schedulers,
        persist,
        tokenize,
    )
    from dask.blockwise import blockwise as core_blockwise
    from dask.blockwise import broadcast_dimensions
    from dask.context import globalmethod
    from dask.core import quote
    from dask.delayed import Delayed, delayed
    from dask.highlevelgraph import HighLevelGraph, MaterializedLayer
    from dask.layers import ArraySliceDep, reshapelist
    from dask.sizeof import sizeof
    from dask.typing import Graph, Key, NestedKeys
    from dask.utils import (
        IndexCallable,
        SerializableLock,
        cached_cumsum,
        cached_property,
        concrete,
        derived_from,
        format_bytes,
        funcname,
        has_keyword,
        is_arraylike,
        is_dataframe_like,
        is_index_like,
        is_integer,
        is_series_like,
        maybe_pluralize,
        ndeepmap,
        ndimlist,
        parse_bytes,
        typename,
    )
    from dask.widgets import get_template
    
    # Alias for "an int, or a float NaN" (see the note on the right).
    T_IntOrNaN = Union[int, float]  # Should be Union[int, Literal[np.nan]]

    # The "threads" scheduler if registered, otherwise the "sync" one.
    DEFAULT_GET = named_schedulers.get("threads", named_schedulers["sync"])

    # Help text pointing users at ways to compute unknown chunk sizes;
    # presumably appended to related error messages — confirm at call sites.
    unknown_chunk_message = "".join(
        [
            "\n\n",
            "A possible solution: ",
            "https://docs.dask.org/en/latest/array-chunks.html#unknown-chunks\n",
            "Summary: to compute chunks sizes, use\n\n",
            "   x.compute_chunk_sizes()  # for Dask Array `x`\n",
            "   ddf.to_dask_array(lengths=True)  # for Dask DataFrame `ddf`",
        ]
    )
    
    
    class PerformanceWarning(Warning):
        """Warning category for chunking choices that are likely to perform poorly."""
    
    
    def getter(a, b, asarray=True, lock=None):
        """Index ``a`` with ``b``, optionally holding ``lock`` while doing so.

        ``None`` entries (new axes) in a tuple index are handled in two steps:
        index with them removed, then insert the new axes into the result.
        When ``asarray`` is true, results that are not array-like, or that are
        ``np.matrix`` instances, are converted with ``np.asarray``.
        """
        if isinstance(b, tuple) and any(idx is None for idx in b):
            without_new_axes = tuple(idx for idx in b if idx is not None)
            # Integer entries consume their dimension in the first pass, so they
            # get no slot in the follow-up index; the others become ``:`` or
            # stay ``None`` (a new axis).
            add_new_axes = tuple(
                slice(None, None) if idx is not None else None
                for idx in b
                if not isinstance(idx, Integral)
            )
            sliced = getter(a, without_new_axes, asarray=asarray, lock=lock)
            return sliced[add_new_axes]

        if lock:
            lock.acquire()
        try:
            out = a[b]
            # ``np.matrix`` counts as array-like, but is deliberately turned into
            # a plain ``np.ndarray`` to keep getter's historical behavior.
            if asarray and (not is_arraylike(out) or isinstance(out, np.matrix)):
                out = np.asarray(out)
        finally:
            if lock:
                lock.release()
        return out
    
    
    def getter_nofancy(a, b, asarray=True, lock=None):
        """Thin wrapper that simply delegates to ``getter``.

        Exists as a distinct function so optimization passes can tell that the
        backend does not support fancy indexing.
        """
        return getter(a, b, lock=lock, asarray=asarray)
    
    
    def getter_inline(a, b, asarray=True, lock=None):
        """Delegate to ``getter``, marking the slice as safe for optimizations to inline.

        Slicing tasks that use this function may be inlined into their
        consumers by graph optimizations, for example

        **Before**

        >>> a = x[:10]  # doctest: +SKIP
        >>> b = a + 1  # doctest: +SKIP
        >>> c = a * 2  # doctest: +SKIP

        **After**

        >>> b = x[:10] + 1  # doctest: +SKIP
        >>> c = x[:10] * 2  # doctest: +SKIP

        Such inlining can matter when data is read from disk.
        """
        return getter(a, b, lock=lock, asarray=asarray)
    
    
    from dask.array.optimization import fuse_slice, optimize
    
    # Registry used for __array_function__: NumPy function -> dask implementation
    # (filled by the ``implements`` decorator; also covers aliases and names that
    # differ between NumPy and dask).
    _HANDLED_FUNCTIONS: dict = dict()
    
    
    def implements(*numpy_functions):
        """Decorator registering a dask function as the __array_function__ handler
        for one or more NumPy functions (several when NumPy has aliases).

        Parameters
        ----------
        \\*numpy_functions : callables
            NumPy functions that the decorated ``dask.array`` function should
            handle; each is stored as a key in ``_HANDLED_FUNCTIONS``.

        The decorated function is returned unchanged.
        """

        def register(dask_func):
            _HANDLED_FUNCTIONS.update(dict.fromkeys(numpy_functions, dask_func))
            return dask_func

        return register
    
    
    def _should_delegate(self, other) -> bool:
        """Decide whether a binary operation should defer to ``other`` (NEP-13).

        See https://numpy.org/neps/nep-0013-ufunc-overrides.html#behavior-in-combination-with-python-s-binary-operations
        """
        if not hasattr(other, "__array_ufunc__"):
            return False
        if other.__array_ufunc__ is None:
            # ``other`` explicitly opts out of ufuncs: let it handle the op.
            return True
        # Defer to foreign array types, but never to our own parent classes.
        return (
            not is_valid_array_chunk(other)
            and not isinstance(self, type(other))
            and type(self) is not type(other)
        )
    
    
    def check_if_handled_given_other(f):
        """Wrap a binary dunder so it returns ``NotImplemented`` when Dask should defer.

        Deferral is decided by ``_should_delegate``; unknown types are not
        assumed to be downcastable.
        """

        @wraps(f)
        def wrapper(self, other):
            if not _should_delegate(self, other):
                return f(self, other)
            return NotImplemented

        return wrapper
    
    
    def slices_from_chunks(chunks):
        """Return one tuple of slices per block, enumerated in product order.

        >>> slices_from_chunks(((2, 2), (3, 3, 3)))  # doctest: +NORMALIZE_WHITESPACE
         [(slice(0, 2, None), slice(0, 3, None)),
          (slice(0, 2, None), slice(3, 6, None)),
          (slice(0, 2, None), slice(6, 9, None)),
          (slice(2, 4, None), slice(0, 3, None)),
          (slice(2, 4, None), slice(3, 6, None)),
          (slice(2, 4, None), slice(6, 9, None))]
        """
        offsets = [cached_cumsum(sizes, initial_zero=True) for sizes in chunks]
        per_axis = []
        for starts, sizes in zip(offsets, chunks):
            per_axis.append(
                [slice(start, start + size) for start, size in zip(starts, sizes)]
            )
        return list(product(*per_axis))
    
    
    def graph_from_arraylike(
        arr,  # Any array-like which supports slicing
        chunks,
        shape,
        name,
        getitem=getter,
        lock=False,
        asarray=True,
        dtype=None,
        inline_array=False,
    ) -> HighLevelGraph:
        """
        HighLevelGraph for slicing chunks from an array-like according to a chunk pattern.

        If ``inline_array`` is True, this makes a Blockwise layer of slicing tasks where the
        array-like is embedded into every task.

        If ``inline_array`` is False, this inserts the array-like as a standalone value in
        a MaterializedLayer, then generates a Blockwise layer of slicing tasks that refer
        to it.

        ``chunks`` is normalized with ``normalize_chunks(chunks, shape, dtype=dtype)``.
        ``asarray`` and ``lock`` are forwarded to ``getitem`` only when ``getitem``
        accepts both keywords and at least one of them differs from its default
        (``asarray=False`` or a truthy ``lock``).

        >>> dict(graph_from_arraylike(arr, chunks=(2, 3), shape=(4, 6), name="X", inline_array=True))  # doctest: +SKIP
        {('X', 0, 0): (getter, arr, (slice(0, 2), slice(0, 3))),
         ('X', 1, 0): (getter, arr, (slice(2, 4), slice(0, 3))),
         ('X', 1, 1): (getter, arr, (slice(2, 4), slice(3, 6))),
         ('X', 0, 1): (getter, arr, (slice(0, 2), slice(3, 6)))}

        >>> dict(  # doctest: +SKIP
                graph_from_arraylike(arr, chunks=((2, 2), (3, 3)), shape=(4,6), name="X", inline_array=False)
            )
        {"original-X": arr,
         ('X', 0, 0): (getter, 'original-X', (slice(0, 2), slice(0, 3))),
         ('X', 1, 0): (getter, 'original-X', (slice(2, 4), slice(0, 3))),
         ('X', 1, 1): (getter, 'original-X', (slice(2, 4), slice(3, 6))),
         ('X', 0, 1): (getter, 'original-X', (slice(0, 2), slice(3, 6)))}
        """
        chunks = normalize_chunks(chunks, shape, dtype=dtype)
        out_ind = tuple(range(len(shape)))

        if (
            has_keyword(getitem, "asarray")
            and has_keyword(getitem, "lock")
            and (not asarray or lock)
        ):
            kwargs = {"asarray": asarray, "lock": lock}
        else:
            # Common case, drop extra parameters
            kwargs = {}

        # The slicing tasks read either the array-like itself (inlined) or the
        # key of a separate layer that holds it; everything else is shared.
        source = arr if inline_array else "original-" + name
        layer = core_blockwise(
            getitem,
            name,
            out_ind,
            source,
            None,
            ArraySliceDep(chunks),
            out_ind,
            numblocks={},
            **kwargs,
        )
        if inline_array:
            return HighLevelGraph.from_collections(name, layer)

        original_name = source
        layers = {
            original_name: MaterializedLayer({original_name: arr}),
            name: layer,
        }
        deps = {
            original_name: set(),
            name: {original_name},
        }
        return HighLevelGraph(layers, deps)
    
    
    def dotmany(A, B, leftfunc=None, rightfunc=None, **kwargs):
        """Sum of ``np.dot`` over pairs of aligned chunks from ``A`` and ``B``

        >>> x = np.array([[1, 2], [1, 2]])
        >>> y = np.array([[10, 20], [10, 20]])
        >>> dotmany([x, x, x], [y, y, y])
        array([[ 90, 180],
               [ 90, 180]])

        ``leftfunc`` / ``rightfunc``, when given, are applied to each left /
        right chunk first; extra keyword arguments go to ``np.dot``.

        >>> dotmany([x, x, x], [y, y, y], rightfunc=np.transpose)
        array([[150, 150],
               [150, 150]])
        """
        lhs = map(leftfunc, A) if leftfunc else A
        rhs = map(rightfunc, B) if rightfunc else B
        products = map(partial(np.dot, **kwargs), lhs, rhs)
        return sum(products)
    
    
    def _concatenate2(arrays, axes=None):
        """Recursively concatenate nested lists of arrays along axes

        Each entry in axes corresponds to each level of the nested list.  The
        length of axes should correspond to the level of nesting of arrays.
        If axes is ``None`` or an empty list or tuple, return arrays, or
        arrays[0] if arrays is a list.

        >>> x = np.array([[1, 2], [3, 4]])
        >>> _concatenate2([x, x], axes=[0])
        array([[1, 2],
               [3, 4],
               [1, 2],
               [3, 4]])

        >>> _concatenate2([x, x], axes=[1])
        array([[1, 2, 1, 2],
               [3, 4, 3, 4]])

        >>> _concatenate2([[x, x], [x, x]], axes=[0, 1])
        array([[1, 2, 1, 2],
               [3, 4, 3, 4],
               [1, 2, 1, 2],
               [3, 4, 3, 4]])

        Supports Iterators
        >>> _concatenate2(iter([x, x]), axes=[1])
        array([[1, 2, 1, 2],
               [3, 4, 3, 4]])

        Special Case
        >>> _concatenate2([x, x], axes=())
        array([[1, 2],
               [3, 4]])
        >>> _concatenate2([x, x], axes=[])
        array([[1, 2],
               [3, 4]])
        """
        if axes is None:
            axes = []

        # Nothing to concatenate along.  ``len`` (rather than ``axes == ()``) makes
        # an empty *list* -- including the ``None`` default -- take this path as
        # documented, instead of failing later on ``axes[0]``.
        if len(axes) == 0:
            if isinstance(arrays, list):
                return arrays[0]
            else:
                return arrays

        if isinstance(arrays, Iterator):
            arrays = list(arrays)
        if not isinstance(arrays, (list, tuple)):
            return arrays
        if len(axes) > 1:
            arrays = [_concatenate2(a, axes=axes[1:]) for a in arrays]
        # Dispatch on the type with the highest __array_priority__ so mixed
        # inputs are concatenated by the "strongest" array library.
        concatenate = concatenate_lookup.dispatch(
            type(max(arrays, key=lambda x: getattr(x, "__array_priority__", 0)))
        )
        if isinstance(arrays[0], dict):
            # Handle concatenation of `dict`s, used as a replacement for structured
            # arrays when that's not supported by the array library (e.g., CuPy).
            keys = list(arrays[0].keys())
            assert all(list(a.keys()) == keys for a in arrays)
            ret = dict()
            for k in keys:
                ret[k] = concatenate(list(a[k] for a in arrays), axis=axes[0])
            return ret
        else:
            return concatenate(arrays, axis=axes[0])
    
    
    def apply_infer_dtype(func, args, kwargs, funcname, suggest_dtype="dtype", nout=None):
        """
        Tries to infer output dtype of ``func`` for a small set of input arguments.

        Parameters
        ----------
        func: Callable
            Function for which output dtype is to be determined

        args: List of array like
            Arguments to the function, which would usually be used. Only attributes
            ``ndim`` and ``dtype`` are used.

        kwargs: dict
            Additional ``kwargs`` to the ``func``

        funcname: String
            Name of calling function to improve potential error messages

        suggest_dtype: None/False or String
            If not ``None`` adds suggestion to potential error message to specify a dtype
            via the specified kwarg. Defaults to ``'dtype'``.

        nout: None or Int
            ``None`` if function returns single output, integer if many.
            Defaults to ``None``.

        Returns
        -------
        : dtype or List of dtype
            One or many dtypes (depending on ``nout``)

        Raises
        ------
        ValueError
            If calling ``func`` on the substitute arguments raises any exception;
            the message embeds the original error and its traceback.
        """
        from dask.array.utils import meta_from_array

        # Swap every array-like argument for a concrete array of ones with the
        # same array type and dtype, and every dimension of length 1.
        trial_args = []
        for arg in args:
            if is_arraylike(arg):
                arg = np.ones_like(
                    meta_from_array(arg), shape=((1,) * arg.ndim), dtype=arg.dtype
                )
            trial_args.append(arg)

        failure = None
        try:
            with np.errstate(all="ignore"):
                result = func(*trial_args, **kwargs)
        except Exception as err:
            tb_text = "".join(traceback.format_tb(err.__traceback__))
            if suggest_dtype:
                hint = (
                    "Please specify the dtype explicitly using the "
                    "`{dtype}` kwarg.\n\n".format(dtype=suggest_dtype)
                )
            else:
                hint = ""
            failure = (
                f"`dtype` inference failed in `{funcname}`.\n\n"
                f"{hint}"
                "Original error is below:\n"
                "------------------------\n"
                f"{err!r}\n\n"
                "Traceback:\n"
                "---------\n"
                f"{tb_text}"
            )
        # Raised outside the ``except`` block so the ValueError is not chained to
        # the original exception, which is already embedded in the message.
        if failure is not None:
            raise ValueError(failure)
        if nout is None:
            return getattr(result, "dtype", type(result))
        return tuple(out.dtype for out in result)
    
    
    def normalize_arg(x):
        """Normalize user provided arguments to blockwise or map_blocks

        Dask collections are returned unchanged.  Otherwise the argument is
        wrapped with ``dask.delayed`` when:

        1.  it is a string matching ``_\\d+`` at its start, which might collide
            with a blockwise token;
        2.  it is a list with ten or more elements; or
        3.  it is large, i.e. ``sizeof(x) > 1e6``, so it gets its own graph entry.

        Anything else is returned as-is.
        """
        if is_dask_collection(x):
            return x
        wrap = (
            (isinstance(x, str) and re.match(r"_\d+", x))
            or (isinstance(x, list) and len(x) >= 10)
            or sizeof(x) > 1e6
        )
        return delayed(x) if wrap else x
    
    
    def _pass_extra_kwargs(func, keys, *args, **kwargs):
        """Helper for :func:`dask.array.map_blocks` to pass `block_info` or `block_id`.

        The leading positional arguments are paired with ``keys`` in order and
        turned into keyword arguments of those names; the remaining positional
        arguments, plus all keywords, are then passed on to ``func``.
        """
        n_keys = len(keys)
        moved = dict(zip(keys, args[:n_keys]))
        kwargs.update(moved)
        remaining = args[n_keys:]
        return func(*remaining, **kwargs)
    
    
    def map_blocks(
        func,
        *args,
        name=None,
        token=None,
        dtype=None,
        chunks=None,
        drop_axis=None,
        new_axis=None,
        enforce_ndim=False,
        meta=None,
        **kwargs,
    ):
        """Map a function across all blocks of a dask array.
    
        Note that ``map_blocks`` will attempt to automatically determine the output
        array type by calling ``func`` on 0-d versions of the inputs. Please refer to
        the ``meta`` keyword argument below if you expect that the function will not
        succeed when operating on 0-d arrays.
    
        Parameters
        ----------
        func : callable
            Function to apply to every block in the array.
            If ``func`` accepts ``block_info=`` or ``block_id=``
            as keyword arguments, these will be passed dictionaries
            containing information about input and output chunks/arrays
            during computation. See examples for details.
        args : dask arrays or other objects
        dtype : np.dtype, optional
            The ``dtype`` of the output array. It is recommended to provide this.
            If not provided, will be inferred by applying the function to a small
            set of fake data.
        chunks : tuple, optional
            Chunk shape of resulting blocks if the function does not preserve
            shape. If not provided, the resulting array is assumed to have the same
            block structure as the first input array.
        drop_axis : number or iterable, optional
            Dimensions lost by the function.
        new_axis : number or iterable, optional
            New dimensions created by the function. Note that these are applied
            after ``drop_axis`` (if present). The size of each chunk along this
            dimension will be set to 1. Please specify ``chunks`` if the individual
            chunks have a different size.
        enforce_ndim : bool, default False
            Whether to enforce at runtime that the dimensionality of the array
            produced by ``func`` actually matches that of the array returned by
            ``map_blocks``.
            If True, this will raise an error when there is a mismatch.
        token : string, optional
            The key prefix to use for the output array. If not provided, will be
            determined from the function name.
        name : string, optional
            The key name to use for the output array. Note that this fully
            specifies the output key name, and must be unique. If not provided,
            will be determined by a hash of the arguments.
        meta : array-like, optional
            The ``meta`` of the output array, when specified is expected to be an
            array of the same type and dtype of that returned when calling ``.compute()``
            on the array returned by this function. When not provided, ``meta`` will be
            inferred by applying the function to a small set of fake data, usually a
            0-d array. It's important to ensure that ``func`` can successfully complete
            computation without raising exceptions when 0-d is passed to it, providing
            ``meta`` will be required otherwise. If the output type is known beforehand
            (e.g., ``np.ndarray``, ``cupy.ndarray``), an empty array of such type dtype
            can be passed, for example: ``meta=np.array((), dtype=np.int32)``.
        **kwargs :
            Other keyword arguments to pass to function. Values must be constants
            (not dask.arrays)
    
        See Also
        --------
        dask.array.map_overlap : Generalized operation with overlap between neighbors.
        dask.array.blockwise : Generalized operation with control over block alignment.
    
        Examples
        --------
        >>> import dask.array as da
        >>> x = da.arange(6, chunks=3)
    
        >>> x.map_blocks(lambda x: x * 2).compute()
        array([ 0,  2,  4,  6,  8, 10])
    
        The ``da.map_blocks`` function can also accept multiple arrays.
    
        >>> d = da.arange(5, chunks=2)
        >>> e = da.arange(5, chunks=2)
    
        >>> f = da.map_blocks(lambda a, b: a + b**2, d, e)
        >>> f.compute()
        array([ 0,  2,  6, 12, 20])
    
        If the function changes shape of the blocks then you must provide chunks
        explicitly.
    
        >>> y = x.map_blocks(lambda x: x[::2], chunks=((2, 2),))
    
        You have a bit of freedom in specifying chunks.  If all of the output chunk
        sizes are the same, you can provide just that chunk size as a single tuple.
    
        >>> a = da.arange(18, chunks=(6,))
        >>> b = a.map_blocks(lambda x: x[:3], chunks=(3,))
    
        If the function changes the dimension of the blocks you must specify the
        created or destroyed dimensions.
    
        >>> b = a.map_blocks(lambda x: x[None, :, None], chunks=(1, 6, 1),
        ...                  new_axis=[0, 2])
    
        If ``chunks`` is specified but ``new_axis`` is not, then it is inferred to
        add the necessary number of axes on the left.
    
        Note that ``map_blocks()`` will concatenate chunks along axes specified by
        the keyword parameter ``drop_axis`` prior to applying the function.
        This is illustrated in the figure below:
    
        .. image:: /images/map_blocks_drop_axis.png
    
        Due to memory-size-constraints, it is often not advisable to use ``drop_axis``
        on an axis that is chunked.  In that case, it is better not to use
        ``map_blocks`` but rather
        ``dask.array.reduction(..., axis=dropped_axes, concatenate=False)`` which
        maintains a leaner memory footprint while it drops any axis.
    
        Map_blocks aligns blocks by block positions without regard to shape. In the
        following example we have two arrays with the same number of blocks but
        with different shape and chunk sizes.
    
        >>> x = da.arange(1000, chunks=(100,))
        >>> y = da.arange(100, chunks=(10,))
    
        The relevant attribute to match is numblocks.
    
        >>> x.numblocks
        (10,)
        >>> y.numblocks
        (10,)
    
        If these match (up to broadcasting rules) then we can map arbitrary
        functions across blocks
    
        >>> def func(a, b):
        ...     return np.array([a.max(), b.max()])
    
        >>> da.map_blocks(func, x, y, chunks=(2,), dtype='i8')
        dask.array<func, shape=(20,), dtype=int64, chunksize=(2,), chunktype=numpy.ndarray>
    
        >>> _.compute()
        array([ 99,   9, 199,  19, 299,  29, 399,  39, 499,  49, 599,  59, 699,
                69, 799,  79, 899,  89, 999,  99])
    
        Your block function can get information about where it is in the array by
        accepting a special ``block_info`` or ``block_id`` keyword argument.
        During computation, they will contain information about each of the input
        and output chunks (and dask arrays) relevant to each call of ``func``.
    
        >>> def func(block_info=None):
        ...     pass
    
        This will receive the following information:
    
        >>> block_info  # doctest: +SKIP
        {0: {'shape': (1000,),
             'num-chunks': (10,),
             'chunk-location': (4,),
             'array-location': [(400, 500)]},
         None: {'shape': (1000,),
                'num-chunks': (10,),
                'chunk-location': (4,),
                'array-location': [(400, 500)],
                'chunk-shape': (100,),
                'dtype': dtype('float64')}}
    
        The keys to the ``block_info`` dictionary indicate which is the input and
        output Dask array:
    
        - **Input Dask array(s):** ``block_info[0]`` refers to the first input Dask array.
          The dictionary key is ``0`` because that is the argument index corresponding
          to the first input Dask array.
          In cases where multiple Dask arrays have been passed as input to the function,
          you can access them with the number corresponding to the input argument,
          eg: ``block_info[1]``, ``block_info[2]``, etc.
          (Note that if you pass multiple Dask arrays as input to map_blocks,
          the arrays must match each other by having matching numbers of chunks,
          along corresponding dimensions up to broadcasting rules.)
        - **Output Dask array:** ``block_info[None]`` refers to the output Dask array,
          and contains information about the output chunks.
          The output chunk shape and dtype may may be different than the input chunks.
    
        For each dask array, ``block_info`` describes:
    
        - ``shape``: the shape of the full Dask array,
        - ``num-chunks``: the number of chunks of the full array in each dimension,
        - ``chunk-location``: the chunk location (for example the fourth chunk over
          in the first dimension), and
        - ``array-location``: the array location within the full Dask array
          (for example the slice corresponding to ``40:50``).
    
        In addition to these, there are two extra parameters described by
        ``block_info`` for the output array (in ``block_info[None]``):
    
        - ``chunk-shape``: the output chunk shape, and
        - ``dtype``: the output dtype.
    
        These features can be combined to synthesize an array from scratch, for
        example:
    
        >>> def func(block_info=None):
        ...     loc = block_info[None]['array-location'][0]
        ...     return np.arange(loc[0], loc[1])
    
        >>> da.map_blocks(func, chunks=((4, 4),), dtype=np.float64)
        dask.array<func, shape=(8,), dtype=float64, chunksize=(4,), chunktype=numpy.ndarray>
    
        >>> _.compute()
        array([0, 1, 2, 3, 4, 5, 6, 7])
    
        ``block_id`` is similar to ``block_info`` but contains only the ``chunk_location``:
    
        >>> def func(block_id=None):
        ...     pass
    
        This will receive the following information:
    
        >>> block_id  # doctest: +SKIP
        (4, 3)
    
        You may specify the key name prefix of the resulting task in the graph with
        the optional ``token`` keyword argument.
    
        >>> x.map_blocks(lambda x: x + 1, name='increment')
        dask.array<increment, shape=(1000,), dtype=int64, chunksize=(100,), chunktype=numpy.ndarray>
    
        For functions that may not handle 0-d arrays, it's also possible to specify
        ``meta`` with an empty array matching the type of the expected result. In
        the example below, ``func`` will result in an ``IndexError`` when computing
        ``meta``:
    
        >>> rng = da.random.default_rng()
        >>> da.map_blocks(lambda x: x[2], rng.random(5), meta=np.array(()))
        dask.array<lambda, shape=(5,), dtype=float64, chunksize=(5,), chunktype=numpy.ndarray>
    
        Similarly, it's possible to specify a non-NumPy array to ``meta``, and provide
        a ``dtype``:
    
        >>> import cupy  # doctest: +SKIP
        >>> rng = da.random.default_rng(cupy.random.default_rng())  # doctest: +SKIP
        >>> dt = np.float32
        >>> da.map_blocks(lambda x: x[2], rng.random(5, dtype=dt), meta=cupy.array((), dtype=dt))  # doctest: +SKIP
        dask.array<lambda, shape=(5,), dtype=float32, chunksize=(5,), chunktype=cupy.ndarray>
        """
        if drop_axis is None:
            drop_axis = []
    
        if not callable(func):
            msg = (
                "First argument must be callable function, not %s\n"
                "Usage:   da.map_blocks(function, x)\n"
                "   or:   da.map_blocks(function, x, y, z)"
            )
            raise TypeError(msg % type(func).__name__)
        if token:
            warnings.warn(
                "The `token=` keyword to `map_blocks` has been moved to `name=`. "
                "Please use `name=` instead as the `token=` keyword will be removed "
                "in a future release.",
                category=FutureWarning,
            )
            name = token
    
        name = f"{name or funcname(func)}-{tokenize(func, dtype, chunks, drop_axis, new_axis, *args, **kwargs)}"
        new_axes = {}
    
        if isinstance(drop_axis, Number):
            drop_axis = [drop_axis]
        if isinstance(new_axis, Number):
            new_axis = [new_axis]  # TODO: handle new_axis
    
        arrs = [a for a in args if isinstance(a, Array)]
    
        argpa…tack
        """
        from dask.array import wrap
    
        seq = [asarray(a, allow_unknown_chunksizes=allow_unknown_chunksizes) for a in seq]
    
        if not seq:
            raise ValueError("Need array(s) to concatenate")
    
        if axis is None:
            seq = [a.flatten() for a in seq]
            axis = 0
    
        seq_metas = [meta_from_array(s) for s in seq]
        _concatenate = concatenate_lookup.dispatch(
            type(max(seq_metas, key=lambda x: getattr(x, "__array_priority__", 0)))
        )
        meta = _concatenate(seq_metas, axis=axis)
    
        # Promote types to match meta
        seq = [a.astype(meta.dtype) for a in seq]
    
        # Find output array shape
        ndim = len(seq[0].shape)
        shape = tuple(
            sum(a.shape[i] for a in seq) if i == axis else seq[0].shape[i]
            for i in range(ndim)
        )
    
        # Drop empty arrays
        seq2 = [a for a in seq if a.size]
        if not seq2:
            seq2 = seq
    
        if axis < 0:
            axis = ndim + axis
        if axis >= ndim:
            msg = (
                "Axis must be less than than number of dimensions"
                "\nData has %d dimensions, but got axis=%d"
            )
            raise ValueError(msg % (ndim, axis))
    
        n = len(seq2)
        if n == 0:
            try:
                return wrap.empty_like(meta, shape=shape, chunks=shape, dtype=meta.dtype)
            except TypeError:
                return wrap.empty(shape, chunks=shape, dtype=meta.dtype)
        elif n == 1:
            return seq2[0]
    
        if not allow_unknown_chunksizes and not all(
            i == axis or all(x.shape[i] == seq2[0].shape[i] for x in seq2)
            for i in range(ndim)
        ):
            if any(map(np.isnan, seq2[0].shape)):
                raise ValueError(
                    "Tried to concatenate arrays with unknown"
                    " shape %s.\n\nTwo solutions:\n"
                    "  1. Force concatenation pass"
                    " allow_unknown_chunksizes=True.\n"
                    "  2. Compute shapes with "
                    "[x.compute_chunk_sizes() for x in seq]" % str(seq2[0].shape)
                )
            raise ValueError("Shapes do not align: %s", [x.shape for x in seq2])
    
        inds = [list(range(ndim)) for i in range(n)]
        for i, ind in enumerate(inds):
            ind[axis] = -(i + 1)
    
        uc_args = list(concat(zip(seq2, inds)))
        _, seq2 = unify_chunks(*uc_args, warn=False)
    
        bds = [a.chunks for a in seq2]
    
        chunks = (
            seq2[0].chunks[:axis]
            + (sum((bd[axis] for bd in bds), ()),)
            + seq2[0].chunks[axis + 1 :]
        )
    
        cum_dims = [0] + list(accumulate(add, [len(a.chunks[axis]) for a in seq2]))
    
        names = [a.name for a in seq2]
    
        name = "concatenate-" + tokenize(names, axis)
        keys = list(product([name], *[range(len(bd)) for bd in chunks]))
    
        values = [
            (names[bisect(cum_dims, key[axis + 1]) - 1],)
            + key[1 : axis + 1]
            + (key[axis + 1] - cum_dims[bisect(cum_dims, key[axis + 1]) - 1],)
            + key[axis + 2 :]
            for key in keys
        ]
    
        dsk = dict(zip(keys, values))
        graph = HighLevelGraph.from_collections(name, dsk, dependencies=seq2)
    
        return Array(graph, name, chunks, meta=meta)
    
    
    def load_store_chunk(
        x: Any,
        out: Any,
        index: slice,
        lock: Any,
        return_stored: bool,
        load_stored: bool,
    ):
        """
        A function inserted in a Dask graph for storing a chunk.

        Parameters
        ----------
        x: array-like or None
            The chunk to write (potentially a NumPy array). Non-array-like
            values are converted with ``np.asanyarray`` first. ``None`` skips
            the write entirely, so the function can also act as a pure loader.
        out: array-like
            Where to store results.
        index: slice-like
            Where to store result from ``x`` in ``out``.
        lock: Lock-like or False
            Lock acquired around the write/read of ``out``; a falsy value
            disables locking.
        return_stored: bool
            Whether to return ``out``.
        load_stored: bool
            Whether to return the array stored in ``out``.
            Ignored if ``return_stored`` is not ``True``.

        Returns
        -------

        If return_stored=True and load_stored=False
            out
        If return_stored=True and load_stored=True
            out[index]
        If return_stored=False
            None

        Examples
        --------

        >>> a = np.ones((5, 6))
        >>> b = np.empty(a.shape)
        >>> load_store_chunk(a, b, (slice(None), slice(None)), False, False, False)
        """
        if lock:
            lock.acquire()
        try:
            if x is not None:
                # Convert non-array chunks (e.g. Python scalars or lists) up
                # front, so the emptiness check below does not require the
                # chunk to already expose a ``.size`` attribute.
                if not is_arraylike(x):
                    x = np.asanyarray(x)
                # Empty chunks have nothing to write.
                if x.size != 0:
                    out[index] = x

            if not return_stored:
                return None
            return out[index] if load_stored else out
        finally:
            if lock:
                lock.release()
    
    
    def store_chunk(
        x: ArrayLike, out: ArrayLike, index: slice, lock: Any, return_stored: bool
    ):
        """Write chunk ``x`` into ``out[index]`` without reading it back.

        Delegates to ``load_store_chunk`` with ``load_stored`` fixed to
        ``False``: the result is ``out`` when ``return_stored`` is truthy and
        ``None`` otherwise.
        """
        load_stored = False
        return load_store_chunk(x, out, index, lock, return_stored, load_stored)
    
    
    # Type variable bound to ``ArrayLike``; used below by ``load_chunk`` so its
    # return type is annotated as the same type as its ``out`` argument.
    A = TypeVar("A", bound=ArrayLike)
    
    
    def load_chunk(out: A, index: slice, lock: Any) -> A:
        """Return ``out[index]`` (under ``lock`` if one is given) without writing."""
        # ``x=None`` means nothing is stored; both flags on means "load it back".
        return load_store_chunk(
            None, out, index, lock, return_stored=True, load_stored=True
        )
    
    
    def insert_to_ooc(
        keys: list,
        chunks: tuple[tuple[int, ...], ...],
        out: ArrayLike,
        name: str,
        *,
        lock: Lock | bool = True,
        region: tuple[slice, ...] | slice | None = None,
        return_stored: bool = False,
        load_stored: bool = False,
    ) -> dict:
        """
        Creates a Dask graph that stores the chunks referenced by ``keys`` in ``out``.

        Parameters
        ----------
        keys: list
            Dask keys of the input array (possibly nested; they are flattened)
        chunks: tuple
            Dask chunks of the input array
        out: array-like
            Where to store results to
        name: str
            First element of the keys of the returned graph
        lock: Lock-like or bool, optional
            Whether to lock or with what (default is ``True``,
            which means a new :class:`threading.Lock` instance).
        region: slice-like, optional
            Where in ``out`` to store the input array's chunks
            (default is ``None``, meaning all of ``out``).
        return_stored: bool, optional
            Whether to return ``out``
            (default is ``False``, meaning ``None`` is returned).
        load_stored: bool, optional
            Whether to handle loading from ``out`` at the same time.
            Ignored if ``return_stored`` is not ``True``.
            (default is ``False``, meaning defer to ``return_stored``).

        Returns
        -------
        dask graph of store operation

        Examples
        --------
        >>> import dask.array as da
        >>> d = da.ones((5, 6), chunks=(2, 3))
        >>> a = np.empty(d.shape)
        >>> insert_to_ooc(d.__dask_keys__(), d.chunks, a, "store-123")  # doctest: +SKIP
        """

        if lock is True:
            lock = Lock()

        slices = slices_from_chunks(chunks)
        if region:
            # Offset every chunk's slice into the requested region of ``out``.
            slices = [fuse_slice(region, slc) for slc in slices]

        # Only the combined store-and-load case needs the trailing
        # ``load_stored`` argument; otherwise use the plain ``store_chunk``.
        if return_stored and load_stored:
            task_func = load_store_chunk
            extra_args: tuple = (load_stored,)
        else:
            task_func = store_chunk  # type: ignore
            extra_args = ()

        graph = {}
        for in_key, slc in zip(core.flatten(keys), slices):
            out_key = (name,) + in_key[1:]
            graph[out_key] = (
                task_func,
                in_key,
                out,
                slc,
                lock,
                return_stored,
            ) + extra_args
        return graph
    
    
    def retrieve_from_ooc(
        keys: Collection[Key], dsk_pre: Graph, dsk_post: Graph
    ) -> dict[tuple, Any]:
        """
        Creates a Dask graph for loading back the chunks stored by ``keys``.

        For every store key ``(name, *index)`` a load task keyed
        ``("load-" + name, *index)`` is created that reads the stored chunk
        with ``load_chunk``.

        Parameters
        ----------
        keys: Collection
            A sequence containing Dask graph keys to load
        dsk_pre: Mapping
            A Dask graph corresponding to a Dask Array before computation;
            its values for ``keys`` are store tasks as built by
            ``insert_to_ooc``.
        dsk_post: Mapping
            A Dask graph corresponding to a Dask Array after computation

        Returns
        -------
        dict
            Graph of the load tasks.

        Examples
        --------
        >>> import dask.array as da
        >>> d = da.ones((5, 6), chunks=(2, 3))
        >>> a = np.empty(d.shape)
        >>> g = insert_to_ooc(d.__dask_keys__(), d.chunks, a, "store-123")
        >>> retrieve_from_ooc(g.keys(), g, {k: k for k in g.keys()})  # doctest: +SKIP
        """
        load_dsk = {}
        for k in keys:
            # Store tasks from ``insert_to_ooc`` are laid out as
            # ``(func, key, out, index, lock, return_stored[, load_stored])``.
            # Forward exactly ``(index, lock)``, the remaining parameters of
            # ``load_chunk(out, index, lock)``, so the optional trailing
            # ``load_stored`` element cannot shift extra arguments into it.
            index_and_lock = dsk_pre[k][3:5]  # type: ignore
            load_key = ("load-" + k[0],) + k[1:]  # type: ignore
            load_dsk[load_key] = (load_chunk, dsk_post[k]) + index_and_lock
        return load_dsk
    
    
    def _as_dtype(a, dtype):
        if dtype is None:
            return a
        else:
            return a.astype(dtype)
    
    
    def asarray(
        a, allow_unknown_chunksizes=False, dtype=None, order=None, *, like=None, **kwargs
    ):
        """Convert the input to a dask array.

        Parameters
        ----------
        a : array-like
            Input data, in any form that can be converted to a dask array. This
            includes lists, lists of tuples, tuples, tuples of tuples, tuples of
            lists and ndarrays.
        allow_unknown_chunksizes: bool
            Allow unknown chunksizes, such as come from converting from dask
            dataframes.  Dask.array is unable to verify that chunks line up.  If
            data comes from differently aligned sources then this can cause
            unexpected results.
        dtype : data-type, optional
            By default, the data-type is inferred from the input data. When
            given, the result has this data-type.
        order : {‘C’, ‘F’, ‘A’, ‘K’}, optional
            Memory layout. ‘A’ and ‘K’ depend on the order of input array a.
            ‘C’ row-major (C-style), ‘F’ column-major (Fortran-style) memory
            representation. ‘A’ (any) means ‘F’ if a is Fortran contiguous, ‘C’
            otherwise. ‘K’ (keep) preserves input order. Defaults to ‘C’.
        like: array-like
            Reference object to allow the creation of Dask arrays with chunks
            that are not NumPy arrays. If an array-like passed in as ``like``
            supports the ``__array_function__`` protocol, the chunk type of the
            resulting array will be defined by it. In this case, it ensures the
            creation of a Dask array compatible with that passed in via this
            argument. If ``like`` is a Dask array, the chunk type of the
            resulting array will be defined by the chunk type of ``like``.
            Requires NumPy 1.20.0 or higher.
        **kwargs
            Passed on to ``from_array`` when a new dask array is created.

        Returns
        -------
        out : dask array
            Dask array interpretation of a.

        Examples
        --------
        >>> import dask.array as da
        >>> import numpy as np
        >>> x = np.arange(3)
        >>> da.asarray(x)
        dask.array<array, shape=(3,), dtype=int64, chunksize=(3,), chunktype=numpy.ndarray>

        >>> y = [[1, 2, 3], [4, 5, 6]]
        >>> da.asarray(y)
        dask.array<array, shape=(2, 3), dtype=int64, chunksize=(2, 3), chunktype=numpy.ndarray>

        .. warning::
            `order` is ignored if `a` is an `Array`, has the attribute ``to_dask_array``,
            or is a list or tuple of `Array`'s.
        """
        if like is None:
            if isinstance(a, Array):
                return _as_dtype(a, dtype)
            elif hasattr(a, "to_dask_array"):
                return _as_dtype(a.to_dask_array(), dtype)
            elif type(a).__module__.split(".")[0] == "xarray" and hasattr(a, "data"):
                return _as_dtype(asarray(a.data, order=order), dtype)
            elif isinstance(a, (list, tuple)) and any(isinstance(i, Array) for i in a):
                return _as_dtype(
                    stack(a, allow_unknown_chunksizes=allow_unknown_chunksizes), dtype
                )
            elif not isinstance(getattr(a, "shape", None), Iterable):
                a = np.asarray(a, dtype=dtype, order=order)
        else:
            like_meta = meta_from_array(like)
            if isinstance(a, Array):
                return a.map_blocks(np.asarray, like=like_meta, dtype=dtype, order=order)
            else:
                a = np.asarray(a, like=like_meta, dtype=dtype, order=order)
        # Inputs that already expose a ``shape`` (NumPy arrays, h5py/zarr
        # datasets, ...) skip the ``np.asarray`` conversion above, so ``dtype``
        # has not been applied to them yet; cast the wrapped array here.
        # Inputs converted above already carry ``dtype``, so this cast is
        # presumably a no-op for them.
        return _as_dtype(from_array(a, getitem=getter_inline, **kwargs), dtype)
    
    
    def asanyarray(a, dtype=None, order=None, *, like=None, inline_array=False):
        """Convert the input to a dask array.

        Subclasses of ``np.ndarray`` will be passed through as chunks unchanged.

        Parameters
        ----------
        a : array-like
            Input data, in any form that can be converted to a dask array. This
            includes lists, lists of tuples, tuples, tuples of tuples, tuples of
            lists and ndarrays.
        dtype : data-type, optional
            By default, the data-type is inferred from the input data. When
            given, the result has this data-type.
        order : {‘C’, ‘F’, ‘A’, ‘K’}, optional
            Memory layout. ‘A’ and ‘K’ depend on the order of input array a.
            ‘C’ row-major (C-style), ‘F’ column-major (Fortran-style) memory
            representation. ‘A’ (any) means ‘F’ if a is Fortran contiguous, ‘C’
            otherwise. ‘K’ (keep) preserves input order. Defaults to ‘C’.
        like: array-like
            Reference object to allow the creation of Dask arrays with chunks
            that are not NumPy arrays. If an array-like passed in as ``like``
            supports the ``__array_function__`` protocol, the chunk type of the
            resulting array will be defined by it. In this case, it ensures the
            creation of a Dask array compatible with that passed in via this
            argument. If ``like`` is a Dask array, the chunk type of the
            resulting array will be defined by the chunk type of ``like``.
            Requires NumPy 1.20.0 or higher.
        inline_array:
            Whether to inline the array in the resulting dask graph. For more information,
            see the documentation for ``dask.array.from_array()``.

        Returns
        -------
        out : dask array
            Dask array interpretation of a.

        Examples
        --------
        >>> import dask.array as da
        >>> import numpy as np
        >>> x = np.arange(3)
        >>> da.asanyarray(x)
        dask.array<array, shape=(3,), dtype=int64, chunksize=(3,), chunktype=numpy.ndarray>

        >>> y = [[1, 2, 3], [4, 5, 6]]
        >>> da.asanyarray(y)
        dask.array<array, shape=(2, 3), dtype=int64, chunksize=(2, 3), chunktype=numpy.ndarray>

        .. warning::
            `order` is ignored if `a` is an `Array`, has the attribute ``to_dask_array``,
            or is a list or tuple of `Array`'s.
        """
        if like is None:
            if isinstance(a, Array):
                return _as_dtype(a, dtype)
            elif hasattr(a, "to_dask_array"):
                return _as_dtype(a.to_dask_array(), dtype)
            elif type(a).__module__.split(".")[0] == "xarray" and hasattr(a, "data"):
                return _as_dtype(asarray(a.data, order=order), dtype)
            elif isinstance(a, (list, tuple)) and any(isinstance(i, Array) for i in a):
                return _as_dtype(stack(a), dtype)
            elif not isinstance(getattr(a, "shape", None), Iterable):
                a = np.asanyarray(a, dtype=dtype, order=order)
        else:
            like_meta = meta_from_array(like)
            if isinstance(a, Array):
                return a.map_blocks(np.asanyarray, like=like_meta, dtype=dtype, order=order)
            else:
                a = np.asanyarray(a, like=like_meta, dtype=dtype, order=order)
        # Inputs that already expose a ``shape`` skip the ``np.asanyarray``
        # conversion above, so ``dtype`` has not been applied to them yet;
        # cast the wrapped array here (presumably a no-op for converted inputs).
        wrapped = from_array(
            a,
            chunks=a.shape,
            getitem=getter_inline,
            asarray=False,
            inline_array=inline_array,
        )
        return _as_dtype(wrapped, dtype)
    
    
    def is_scalar_for_elemwise(arg):
        """Decide whether ``elemwise`` should pass ``arg`` through as a scalar.

        ``arg`` is treated as a scalar when it has no iterable ``shape``, when
        its shape contains dask collections (as for dask Series / DataFrames),
        when NumPy considers it a scalar, when it is a ``np.dtype``, or when it
        is a 0-d NumPy array.

        >>> is_scalar_for_elemwise(42)
        True
        >>> is_scalar_for_elemwise('foo')
        True
        >>> is_scalar_for_elemwise(True)
        True
        >>> is_scalar_for_elemwise(np.array(42))
        True
        >>> is_scalar_for_elemwise([1, 2, 3])
        True
        >>> is_scalar_for_elemwise(np.array([1, 2, 3]))
        False
        >>> is_scalar_for_elemwise(from_array(np.array(0), chunks=()))
        False
        >>> is_scalar_for_elemwise(np.dtype('i4'))
        True
        """
        # The shape test comes first. Objects without an iterable shape count as
        # scalars, and so do dask series / frame, whose shapes hold dask
        # collections.
        dims = getattr(arg, "shape", None)
        if not isinstance(dims, Iterable):
            return True
        if any(is_dask_collection(d) for d in dims):
            return True

        if np.isscalar(arg) or isinstance(arg, np.dtype):
            return True
        return isinstance(arg, np.ndarray) and arg.ndim == 0
    
    
    def broadcast_shapes(*shapes):
        """
        Compute the shape that results from broadcasting arrays together.

        Dimensions are aligned from the trailing end. A NaN (unknown) extent
        in any input makes that output extent NaN, and a zero-length extent
        makes it 0.

        Parameters
        ----------
        shapes : tuples
            The shapes of the arguments.

        Returns
        -------
        output_shape : tuple

        Raises
        ------
        ValueError
            If the input shapes cannot be successfully broadcast together.
        """
        if len(shapes) == 1:
            return shapes[0]

        result = []
        # Shorter shapes are padded with -1, which is accepted next to any size.
        for sizes in zip_longest(*(reversed(s) for s in shapes), fillvalue=-1):
            if np.isnan(sizes).any():
                dim = np.nan
            elif 0 in sizes:
                dim = 0
            else:
                dim = np.max(sizes).item()

            allowed = [-1, 0, 1, dim]
            for size in sizes:
                if size not in allowed and not np.isnan(size):
                    raise ValueError(
                        "operands could not be broadcast together with "
                        "shapes {}".format(" ".join(map(str, shapes)))
                    )
            result.append(dim)
        return tuple(result[::-1])
    
    
    def elemwise(op, *args, out=None, where=True, dtype=None, name=None, **kwargs):
        """Apply an elementwise ufunc-like function blockwise across arguments.

        Like numpy ufuncs, broadcasting rules are respected.

        Parameters
        ----------
        op : callable
            The function to apply. Should be numpy ufunc-like in the parameters
            that it accepts.
        *args : Any
            Arguments to pass to `op`. Non-dask array-like objects are first
            converted to dask arrays, then all arrays are broadcast together before
            applying the function blockwise across all arguments. Any scalar
            arguments are passed as-is following normal numpy ufunc behavior.
        out : dask array, optional
            If out is a dask.array then this overwrites the contents of that array
            with the result.
        where : array_like, optional
            An optional boolean mask marking locations where the ufunc should be
            applied. Can be a scalar, dask array, or any other array-like object.
            Mirrors the ``where`` argument to numpy ufuncs, see e.g. ``numpy.add``
            for more information.
        dtype : dtype, optional
            If provided, overrides the output array dtype.
        name : str, optional
            A unique key name to use when building the backing dask graph. If not
            provided, one will be automatically generated based on the input
            arguments.

        Returns
        -------
        result : dask array
            The blockwise result, or ``out`` re-pointed at that result when
            ``out`` is a dask array. ``NotImplemented`` is returned when
            ``dtype`` is not given and inferring it from ``op`` fails.

        Raises
        ------
        TypeError
            If any other keyword arguments are passed.
        ValueError
            If the argument shapes cannot be broadcast together.
        NotImplementedError
            If ``out`` is not ``None``, a dask array, or a 0/1-element tuple.

        Examples
        --------
        >>> elemwise(add, x, y)  # doctest: +SKIP
        >>> elemwise(sin, x)  # doctest: +SKIP
        >>> elemwise(sin, x, out=dask_array)  # doctest: +SKIP

        See Also
        --------
        blockwise
        """
        if kwargs:
            # ``funcname`` is the naming helper used further down as well; it is
            # meant to cope with callables lacking ``__name__`` (e.g.
            # ``functools.partial``), where ``op.__name__`` would raise
            # AttributeError instead of the intended TypeError.
            raise TypeError(
                f"{funcname(op)} does not take the following keyword arguments "
                f"{sorted(kwargs)}"
            )

        out = _elemwise_normalize_out(out)
        where = _elemwise_normalize_where(where)
        # Lists/tuples are materialized so that they broadcast like arrays.
        args = [np.asarray(a) if isinstance(a, (list, tuple)) else a for a in args]

        # Gather every shape that takes part in broadcasting.
        shapes = []
        for arg in args:
            shape = getattr(arg, "shape", ())
            # Non-iterable shapes must be checked before iterating them; they
            # are treated as scalars, matching the normalization below.
            if not isinstance(shape, Iterable) or any(
                is_dask_collection(x) for x in shape
            ):
                # Want to exclude Delayed shapes and dd.Scalar
                shape = ()
            shapes.append(shape)
        if isinstance(where, Array):
            shapes.append(where.shape)
        if isinstance(out, Array):
            shapes.append(out.shape)

        shapes = [s if isinstance(s, Iterable) else () for s in shapes]
        out_ndim = len(
            broadcast_shapes(*shapes)
        )  # Raises ValueError if dimensions mismatch
        # Output index labels, right-aligned: the last dimension is label 0.
        expr_inds = tuple(range(out_ndim))[::-1]

        if dtype is not None:
            need_enforce_dtype = True
        else:
            # We follow NumPy's rules for dtype promotion, which special cases
            # scalars and 0d ndarrays (which it considers equivalent) by using
            # their values to compute the result dtype:
            # https://github.com/numpy/numpy/issues/6240
            # We don't inspect the values of 0d dask arrays, because these could
            # hold potentially very expensive calculations. Instead, we treat
            # them just like other arrays, and if necessary cast the result of op
            # to match.
            vals = [
                (
                    np.empty((1,) * max(1, a.ndim), dtype=a.dtype)
                    if not is_scalar_for_elemwise(a)
                    else a
                )
                for a in args
            ]
            try:
                dtype = apply_infer_dtype(op, vals, {}, "elemwise", suggest_dtype=False)
            except Exception:
                # Let Python fall back to the reflected operator, if any.
                return NotImplemented
            need_enforce_dtype = any(
                not is_scalar_for_elemwise(a) and a.ndim == 0 for a in args
            )

        if not name:
            name = f"{funcname(op)}-{tokenize(op, dtype, *args, where)}"

        blockwise_kwargs = dict(dtype=dtype, name=name, token=funcname(op).strip("_"))

        if where is not True:
            # Wrap ``op`` so each block also receives the ``where`` mask and the
            # ``out`` block as its two trailing positional arguments.
            blockwise_kwargs["elemwise_where_function"] = op
            op = _elemwise_handle_where
            args.extend([where, out])

        if need_enforce_dtype:
            # Wrap (possibly already wrapped) ``op`` to cast each block's result.
            blockwise_kwargs["enforce_dtype"] = dtype
            blockwise_kwargs["enforce_dtype_function"] = op
            op = _enforce_dtype

        result = blockwise(
            op,
            expr_inds,
            *concat(
                (a, tuple(range(a.ndim)[::-1]) if not is_scalar_for_elemwise(a) else None)
                for a in args
            ),
            **blockwise_kwargs,
        )

        return handle_out(out, result)
    
    
    def _elemwise_normalize_where(where):
        if where is True:
            return True
        elif where is False or where is None:
            return False
        return asarray(where)
    
    
    def _elemwise_handle_where(*args, **kwargs):
        function = kwargs.pop("elemwise_where_function")
        *args, where, out = args
        if hasattr(out, "copy"):
            out = out.copy()
        return function(*args, where=where, out=out, **kwargs)
    
    
    def _elemwise_normalize_out(out):
        if isinstance(out, tuple):
            if len(out) == 1:
                out = out[0]
            elif len(out) > 1:
                raise NotImplementedError("The out parameter is not fully supported")
            else:
                out = None
        if not (out is None or isinstance(out, Array)):
            raise NotImplementedError(
                f"The out parameter is not fully supported."
                f" Received type {type(out).__name__}, expected Dask Array"
            )
        return out
    
    
    def handle_out(out, result):
        """Handle out parameters

        If ``out`` (after normalization) is a dask array, it is re-pointed at
        ``result``'s graph, chunks, meta and name and then returned. Otherwise
        ``result`` is returned unchanged.

        Raises
        ------
        ValueError
            If ``out`` is a dask array whose shape differs from ``result``'s.
        """
        out = _elemwise_normalize_out(out)
        if not isinstance(out, Array):
            return result

        if out.shape != result.shape:
            raise ValueError(
                "Mismatched shapes between result and out parameter. "
                "out=%s, result=%s" % (str(out.shape), str(result.shape))
            )
        out._chunks = result.chunks
        out.dask = result.dask
        out._meta = result._meta
        out._name = result.name
        return out
    
    
    def _enforce_dtype(*args, **kwargs):
        """Calls a function and converts its result to the given dtype.
    
        The parameters have deliberately been given unwieldy names to avoid
        clashes with keyword arguments consumed by blockwise
    
        A dtype of `object` is treated as a special case and not enforced,
        because it is used as a dummy value in some places when the result will
        not be a block in an Array.
    
        Parameters
        ----------
        enforce_dtype : dtype
            Result dtype
        enforce_dtype_function : callable
            The wrapped function, which will be passed the remaining arguments
        """
        dtype = kwargs.pop("enforce_dtype")
        function = kwargs.pop("enforce_dtype_function")
    
        result = function(*args, **kwargs)
        if hasattr(result, "dtype") and dtype != result.dtype and dtype != object:
            if not np.can_cast(result, dtype, casting="same_kind"):
                raise ValueError(
                    "Inferred dtype from function %r was %r "
                    "but got %r, which can't be cast using "
                    "casting='same_kind'"
                    % (funcname(function), str(dtype), str(result.dtype))
                )
            if np.isscalar(result):
                # scalar astype method doesn't take the keyword arguments, so
                # have to convert via 0-dimensional array and back.
                result = result.astype(dtype)
            else:
                try:
                    result = result.astype(dtype, copy=False)
                except TypeError:
                    # Missing copy kwarg
                    result = result.astype(dtype)
        return result
    
    
    def broadcast_to(x, shape, chunks=None, meta=None):
        """Broadcast an array to a new shape.

        Parameters
        ----------
        x : array_like
            The array to broadcast; converted with ``asarray`` first.
        shape : tuple
            The shape of the desired array.
        chunks : tuple, optional
            If provided, then the result will use these chunks instead of the same
            chunks as the source array. Setting chunks explicitly as part of
            broadcast_to is more efficient than rechunking afterwards. Chunks are
            only allowed to differ from the original chunks along dimensions that
            are new on the result or have size 1 in the input array.
        meta : empty ndarray
            empty ndarray created with same NumPy backend, ndim and dtype as the
            Dask Array being created (overrides dtype)

        Returns
        -------
        broadcast : dask array
            ``x`` itself when neither its shape nor its chunks need to change.

        Raises
        ------
        ValueError
            If ``x.shape`` cannot be broadcast to ``shape``, or if ``chunks``
            changes the chunking of an existing dimension whose chunks are not
            ``(1,)``.

        See Also
        --------
        :func:`numpy.broadcast_to`
        """
        x = asarray(x)
        shape = tuple(shape)

        if meta is None:
            meta = meta_from_array(x)

        # Nothing to do when neither the shape nor the chunking changes.
        if x.shape == shape and (chunks is None or chunks == x.chunks):
            return x

        # Number of leading dimensions the broadcast prepends.
        n_lead = len(shape) - x.ndim
        if n_lead < 0 or any(
            new != old for new, old in zip(shape[n_lead:], x.shape) if old != 1
        ):
            raise ValueError(f"cannot broadcast shape {x.shape} to shape {shape}")

        if chunks is not None:
            chunks = normalize_chunks(
                chunks, shape, dtype=x.dtype, previous_chunks=x.chunks
            )
            for old_bd, new_bd in zip(x.chunks, chunks[n_lead:]):
                if old_bd != new_bd and old_bd != (1,):
                    raise ValueError(
                        "cannot broadcast chunks %s to chunks %s: "
                        "new chunks must either be along a new "
                        "dimension or a dimension of size 1" % (x.chunks, chunks)
                    )
        else:
            # New leading dimensions form a single chunk each. Existing
            # dimensions keep their chunking when larger than 1; otherwise
            # they become one chunk of the target size.
            lead_chunks = tuple((s,) for s in shape[:n_lead])
            kept_chunks = tuple(
                bd if old > 1 else (new,)
                for bd, old, new in zip(x.chunks, x.shape, shape[n_lead:])
            )
            chunks = lead_chunks + kept_chunks

        name = "broadcast_to-" + tokenize(x, shape, chunks)
        dsk = {}

        for combo in product(*(enumerate(bds) for bds in chunks)):
            new_index, chunk_shape = zip(*combo)
            # Blocks along size-1 input dimensions are all sourced from block 0.
            source_index = tuple(
                0 if bd == (1,) else i
                for bd, i in zip(x.chunks, new_index[n_lead:])
            )
            dsk[(name,) + new_index] = (
                np.broadcast_to,
                (x.name,) + source_index,
                quote(chunk_shape),
            )

        graph = HighLevelGraph.from_collections(name, dsk, dependencies=[x])
        return Array(graph, name, chunks, dtype=x.dtype, meta=meta)
    
    
    @derived_from(np)
    def broadcast_arrays(*args, subok=False):
        subok = bool(subok)

        # ``subok`` keeps ndarray subclasses by using ``asanyarray``.
        convert = asanyarray if subok else asarray
        arrays = tuple(convert(e) for e in args)

        # Unify uneven chunking. Index labels are right-aligned (the last
        # dimension of every array gets label 0), as broadcasting lines them up.
        index_labels = [list(reversed(range(arr.ndim))) for arr in arrays]
        _, arrays = unify_chunks(*concat(zip(arrays, index_labels)), warn=False)

        target_shape = broadcast_shapes(*(arr.shape for arr in arrays))
        target_chunks = broadcast_chunks(*(arr.chunks for arr in arrays))

        broadcasted = (
            broadcast_to(arr, shape=target_shape, chunks=target_chunks)
            for arr in arrays
        )
        # Mirror NumPy: a tuple on NumPy >= 2.0, a list on older versions.
        return tuple(broadcasted) if NUMPY_GE_200 else list(broadcasted)
    
    
    def offset_func(func, offset, *args):
        """Offsets inputs by offset

        Returns a wrapper that adds ``offset`` element-wise to its positional
        arguments before calling ``func``. Pairing stops at the shorter of the
        two sequences. The extra positional ``args`` are accepted but unused.

        >>> double = lambda x: x * 2
        >>> f = offset_func(double, (10,))
        >>> f(1)
        22
        >>> f(300)
        620
        """

        def _offset(*inputs):
            shifted = [value + delta for value, delta in zip(inputs, offset)]
            return func(*shifted)

        # Best effort only: some callables (e.g. partials) have no __name__.
        with contextlib.suppress(Exception):
            _offset.__name__ = "offset_" + func.__name__

        return _offset
    
    
    def chunks_from_arrays(arrays):
        """Chunks tuple from nested list of arrays
    
        >>> x = np.array([1, 2])
        >>> chunks_from_arrays([x, x])
        ((2, 2),)
    
        >>> x = np.array([[1, 2]])
        >>> chunks_from_arrays([[x], [x]])
        ((1, 1), (2,))
    
        >>> x = np.array([[1, 2]])
        >>> chunks_from_arrays([[x, x]])
        ((1,), (2, 2))
    
        >>> chunks_from_arrays([1, 1])
        ((1, 1),)
        """
        if not arrays:
            return ()
        result = []
        dim = 0
    
        def shape(x):
            try:
                return x.shape if x.shape else (1,)
            except AttributeError:
                return (1,)
    
        while isinstance(arrays, (list, tuple)):
>           result.append(tuple(shape(deepfirst(a))[dim] for a in arrays))
E           IndexError: tuple index out of range

../../../miniconda3/envs/dask-distributed/lib/python3.10/site-packages/dask/array/core.py:5281: IndexError

Check warning on line 0 in distributed.tests.test_stress

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

11 out of 12 runs failed: test_stress_creation_and_deletion (distributed.tests.test_stress)

artifacts/macos-latest-3.12-default-ci1/pytest.xml [took 7s]
artifacts/ubuntu-latest-3.10-default-ci1/pytest.xml [took 7s]
artifacts/ubuntu-latest-3.10-no_expr-ci1/pytest.xml [took 7s]
artifacts/ubuntu-latest-3.10-no_queue-ci1/pytest.xml [took 8s]
artifacts/ubuntu-latest-3.11-default-ci1/pytest.xml [took 7s]
artifacts/ubuntu-latest-3.12-default-ci1/pytest.xml [took 7s]
artifacts/ubuntu-latest-mindeps-numpy-ci1/pytest.xml [took 7s]
artifacts/ubuntu-latest-mindeps-pandas-ci1/pytest.xml [took 12s]
artifacts/windows-latest-3.10-default-ci1/pytest.xml [took 7s]
artifacts/windows-latest-3.11-default-ci1/pytest.xml [took 8s]
artifacts/windows-latest-3.12-default-ci1/pytest.xml [took 8s]
Raw output
AssertionError: assert 'round-bb724227e9bb17988ea1ad61e539e2b8' == 8000884.93
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:33099', workers: 0, cores: 0, tasks: 0>

    @pytest.mark.slow
    @gen_cluster(
        nthreads=[],
        client=True,
        scheduler_kwargs={"allowed_failures": 100_000},
    )
    async def test_stress_creation_and_deletion(c, s):
        # Assertions are handled by the validate mechanism in the scheduler
        pytest.importorskip("numpy")
        da = pytest.importorskip("dask.array")
    
        rng = da.random.RandomState(0)
        x = rng.random(size=(2000, 2000), chunks=(100, 100))
        y = ((x + 1).T + (x * 2) - x.mean(axis=1)).sum().round(2)
        z = c.persist(y)
    
        async def create_and_destroy_worker(delay):
            start = time()
            while time() < start + 5:
                async with Worker(s.address, nthreads=2) as n:
                    await asyncio.sleep(delay)
    
        await asyncio.gather(*(create_and_destroy_worker(0.1 * i) for i in range(20)))
    
        async with Worker(s.address, nthreads=2):
>           assert await c.compute(z) == 8000884.93
E           AssertionError: assert 'round-bb724227e9bb17988ea1ad61e539e2b8' == 8000884.93

distributed/tests/test_stress.py:125: AssertionError

Check warning on line 0 in distributed.cli.tests.test_dask_spec

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

1 out of 12 runs failed: test_errors (distributed.cli.tests.test_dask_spec)

artifacts/ubuntu-latest-mindeps-numpy-notci1/pytest.xml [took 10s]
Raw output
subprocess.TimeoutExpired: Command '['/home/runner/miniconda3/envs/dask-distributed/bin/dask', 'spec', '--spec', '{"foo": "bar"}', '--spec-file', 'foo.yaml']' timed out after 10 seconds
def test_errors():
>       with popen(
            [
                "dask",
                "spec",
                "--spec",
                '{"foo": "bar"}',
                "--spec-file",
                "foo.yaml",
            ],
            capture_output=True,
        ) as proc:

distributed/cli/tests/test_dask_spec.py:73: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
../../../miniconda3/envs/dask-distributed/lib/python3.10/contextlib.py:142: in __exit__
    next(self.gen)
distributed/utils_test.py:1204: in popen
    _terminate_process(proc, terminate_timeout)
distributed/utils_test.py:1130: in _terminate_process
    proc.communicate(timeout=terminate_timeout)
../../../miniconda3/envs/dask-distributed/lib/python3.10/subprocess.py:1154: in communicate
    stdout, stderr = self._communicate(input, endtime, timeout)
../../../miniconda3/envs/dask-distributed/lib/python3.10/subprocess.py:2022: in _communicate
    self._check_timeout(endtime, orig_timeout, stdout, stderr)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <Popen: returncode: -9 args: ['/home/runner/miniconda3/envs/dask-distributed...>
endtime = 338.631460995, orig_timeout = 10
stdout_seq = [b'Exception ignored in atexit callback: <function _close_global_client at 0x7f23fc83a4d0>\nTraceback (most recent cal...ackages/coverage/collector.py", line 252, in lock_data\n', b'    self.data_lock.acquire()\nKeyboardInterrupt: \n', b'']
stderr_seq = None, skip_check_and_raise = False

    def _check_timeout(self, endtime, orig_timeout, stdout_seq, stderr_seq,
                       skip_check_and_raise=False):
        """Convenience for checking if a timeout has expired."""
        if endtime is None:
            return
        if skip_check_and_raise or _time() > endtime:
>           raise TimeoutExpired(
                    self.args, orig_timeout,
                    output=b''.join(stdout_seq) if stdout_seq else None,
                    stderr=b''.join(stderr_seq) if stderr_seq else None)
E           subprocess.TimeoutExpired: Command '['/home/runner/miniconda3/envs/dask-distributed/bin/dask', 'spec', '--spec', '{"foo": "bar"}', '--spec-file', 'foo.yaml']' timed out after 10 seconds

../../../miniconda3/envs/dask-distributed/lib/python3.10/subprocess.py:1198: TimeoutExpired

Check warning on line 0 in distributed.cli.tests.test_dask_worker

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

1 out of 12 runs failed: test_nanny_worker_port_range_too_many_workers_raises (distributed.cli.tests.test_dask_worker)

artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 10s]
Raw output
subprocess.TimeoutExpired: Command '['/home/runner/miniconda3/envs/dask-distributed/bin/dask', 'worker', 'tcp://127.0.0.1:34913', '--nworkers', '3', '--host', '127.0.0.1', '--worker-port', '9684:9685', '--nanny-port', '9686:9687', '--no-dashboard']' timed out after 10 seconds
s = <Scheduler 'tcp://127.0.0.1:34913', workers: 0, cores: 0, tasks: 0>

    @gen_cluster(nthreads=[])
    async def test_nanny_worker_port_range_too_many_workers_raises(s):
>       with popen(
            [
                "dask",
                "worker",
                s.address,
                "--nworkers",
                "3",
                "--host",
                "127.0.0.1",
                "--worker-port",
                "9684:9685",
                "--nanny-port",
                "9686:9687",
                "--no-dashboard",
            ],
            capture_output=True,
        ) as worker:

distributed/cli/tests/test_dask_worker.py:244: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
../../../miniconda3/envs/dask-distributed/lib/python3.11/contextlib.py:144: in __exit__
    next(self.gen)
distributed/utils_test.py:1204: in popen
    _terminate_process(proc, terminate_timeout)
distributed/utils_test.py:1130: in _terminate_process
    proc.communicate(timeout=terminate_timeout)
../../../miniconda3/envs/dask-distributed/lib/python3.11/subprocess.py:1209: in communicate
    stdout, stderr = self._communicate(input, endtime, timeout)
../../../miniconda3/envs/dask-distributed/lib/python3.11/subprocess.py:2116: in _communicate
    self._check_timeout(endtime, orig_timeout, stdout, stderr)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <Popen: returncode: -9 args: ['/home/runner/miniconda3/envs/dask-distributed...>
endtime = 533.565681173, orig_timeout = 10
stdout_seq = [b'Exception ignored in atexit callback: <function close_clusters at 0x7f704b422200>\nTraceback (most recent call last...verage/collector.py", line 254, in unlock_data\n', b'    def unlock_data(self) -> None:\n\nKeyboardInterrupt: \n', b'']
stderr_seq = None, skip_check_and_raise = False

    def _check_timeout(self, endtime, orig_timeout, stdout_seq, stderr_seq,
                       skip_check_and_raise=False):
        """Convenience for checking if a timeout has expired."""
        if endtime is None:
            return
        if skip_check_and_raise or _time() > endtime:
>           raise TimeoutExpired(
                    self.args, orig_timeout,
                    output=b''.join(stdout_seq) if stdout_seq else None,
                    stderr=b''.join(stderr_seq) if stderr_seq else None)
E           subprocess.TimeoutExpired: Command '['/home/runner/miniconda3/envs/dask-distributed/bin/dask', 'worker', 'tcp://127.0.0.1:34913', '--nworkers', '3', '--host', '127.0.0.1', '--worker-port', '9684:9685', '--nanny-port', '9686:9687', '--no-dashboard']' timed out after 10 seconds

../../../miniconda3/envs/dask-distributed/lib/python3.11/subprocess.py:1253: TimeoutExpired

Check warning on line 0 in distributed.comm.tests.test_comms

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

3 out of 12 runs failed: test_tls_comm_closed_implicit[tornado] (distributed.comm.tests.test_comms)

artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
ssl.SSLError: [SYS] unknown error (_ssl.c:2580)
tcp = <module 'distributed.comm.tcp' from '/home/runner/work/distributed/distributed/distributed/comm/tcp.py'>

    @gen_test()
    async def test_tls_comm_closed_implicit(tcp):
>       await check_comm_closed_implicit("tls://127.0.0.1", **tls_kwargs)

distributed/comm/tests/test_comms.py:777: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
distributed/comm/tests/test_comms.py:763: in check_comm_closed_implicit
    await comm.read()
distributed/comm/tcp.py:225: in read
    frames_nosplit_nbytes_bin = await stream.read_bytes(fmt_size)
../../../miniconda3/envs/dask-distributed/lib/python3.11/site-packages/tornado/iostream.py:422: in read_bytes
    self._try_inline_read()
../../../miniconda3/envs/dask-distributed/lib/python3.11/site-packages/tornado/iostream.py:836: in _try_inline_read
    pos = self._read_to_buffer_loop()
../../../miniconda3/envs/dask-distributed/lib/python3.11/site-packages/tornado/iostream.py:750: in _read_to_buffer_loop
    if self._read_to_buffer() == 0:
../../../miniconda3/envs/dask-distributed/lib/python3.11/site-packages/tornado/iostream.py:861: in _read_to_buffer
    bytes_read = self.read_from_fd(buf)
../../../miniconda3/envs/dask-distributed/lib/python3.11/site-packages/tornado/iostream.py:1552: in read_from_fd
    return self.socket.recv_into(buf, len(buf))
../../../miniconda3/envs/dask-distributed/lib/python3.11/ssl.py:1314: in recv_into
    return self.read(nbytes, buffer)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <ssl.SSLSocket [closed] fd=-1, family=2, type=1, proto=0>, len = 65536
buffer = bytearray(b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x...0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')

    def read(self, len=1024, buffer=None):
        """Read up to LEN bytes and return them.
        Return zero-length string on EOF."""
    
        self._checkClosed()
        if self._sslobj is None:
            raise ValueError("Read on closed or unwrapped SSL socket.")
        try:
            if buffer is not None:
>               return self._sslobj.read(len, buffer)
E               ssl.SSLError: [SYS] unknown error (_ssl.c:2580)

../../../miniconda3/envs/dask-distributed/lib/python3.11/ssl.py:1166: SSLError

Check warning on line 0 in distributed.shuffle.tests.test_rechunk

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

All 11 runs failed: test_rechunk_configuration[p2p-tasks] (distributed.shuffle.tests.test_rechunk)

artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 5eac77ca35612a3e8c9f0e275e3e0020 failed during transfer phase
from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
            assert isinstance(barrier_task_spec, P2PBarrierTask)
            return barrier_task_spec.spec
    
        def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
            # FIXME: The current implementation relies on the barrier task to be
            # known by its name. If the name has been mangled, we cannot guarantee
            # that the shuffle works as intended and should fail instead.
            self._raise_if_barrier_unknown(shuffle_id)
            self._raise_if_task_not_processing(key)
            spec = self._retrieve_spec(shuffle_id)
            worker_for = self._calculate_worker_for(spec)
            self._ensure_output_tasks_are_non_rootish(spec)
            state = spec.create_new_run(
                worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
            )
            self.active_shuffles[shuffle_id] = state
            self._shuffles[shuffle_id].add(state)
            state.participating_workers.add(worker)
            logger.warning(
                "Shuffle %s initialized by task %r executed on worker %s",
                shuffle_id,
                key,
                worker,
            )
            return state.run_spec
    
        def get_or_create(
            self,
            shuffle_id: ShuffleId,
            key: Key,
            worker: str,
        ) -> RunSpecMessage | ErrorMessage:
            try:
>               run_spec = self._get(shuffle_id, worker)

distributed\shuffle\_scheduler_plugin.py:229: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
>           state = self.active_shuffles[id]
E           KeyError: '5eac77ca35612a3e8c9f0e275e3e0020'

distributed\shuffle\_scheduler_plugin.py:190: KeyError

During handling of the above exception, another exception occurred:

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards
    
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)
    
            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and … await self._refresh(
distributed\shuffle\_worker_plugin.py:222: in _refresh
    result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed\shuffle\_worker_plugin.py:190: in _fetch
    response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed\core.py:1259: in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
distributed\core.py:1043: in send_recv
    raise exc.with_traceback(tb)
distributed\core.py:832: in _handle_comm
    result = handler(**msg)
distributed\shuffle\_scheduler_plugin.py:234: in get_or_create
    run_spec = self._create(shuffle_id, key, worker)
distributed\shuffle\_scheduler_plugin.py:205: in _create
    spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
>           assert isinstance(barrier_task_spec, P2PBarrierTask)
E           AssertionError

distributed\shuffle\_scheduler_plugin.py:196: AssertionError

The above exception was the direct cause of the following exception:

c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:64012', workers: 0, cores: 0, tasks: 0>
config_value = 'tasks', keyword = 'p2p'
ws = (<Worker 'tcp://127.0.0.1:64013', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:64016', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
a = array([[0.25269457, 0.21923049, 0.09456582, 0.41318158, 0.26615119,
        0.53192806, 0.44881814, 0.80442171, 0.7545..., 0.03274378, 0.17641201, 0.04475437, 0.56749567,
        0.69504577, 0.59694585, 0.23471839, 0.63257509, 0.62788706]])

    @pytest.mark.parametrize("config_value", ["tasks", "p2p", None])
    @pytest.mark.parametrize("keyword", ["tasks", "p2p", None])
    @gen_cluster(client=True)
    async def test_rechunk_configuration(c, s, *ws, config_value, keyword):
        """Try rechunking a random 1d matrix
    
        See Also
        --------
        dask.array.tests.test_rechunk.test_rechunk_1d
        """
        a = np.random.default_rng().uniform(0, 1, 100).reshape((10, 10))
        x = da.from_array(a, chunks=(10, 1))
        new = ((1,) * 10, (10,))
        config = {"array.rechunk.method": config_value} if config_value is not None else {}
        with dask.config.set(config):
            x2 = rechunk(x, chunks=new, method=keyword)
        expected_algorithm = keyword if keyword is not None else config_value
        if expected_algorithm == "p2p":
            assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
        elif expected_algorithm == "tasks":
            assert not any(
                key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__()
            )
        # Neither is specified, so we choose the best one (see test_rechunk_heuristic for a full test of the heuristic)
        else:
            assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
    
        assert x2.chunks == new
>       assert np.all(await c.compute(x2) == a)

distributed\shuffle\tests\test_rechunk.py:210: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed\client.py:410: in _result
    raise exc.with_traceback(tb)
distributed\shuffle\_rechunk.py:169: in rechunk_transfer
    with handle_transfer_errors(id):
C:\Users\runneradmin\miniconda3\envs\dask-distributed\Lib\contextlib.py:158: in __exit__
    self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            """Report to the scheduler barrier whether all ``run_ids`` match this run.

            ``consistent`` is sent as True only when every id in ``run_ids`` equals
            this run's ``run_id``. Returns this run's ``run_id``.
            """
            self.raise_if_closed()
            consistent = all(rid == self.run_id for rid in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id,
                run_id=self.run_id,
                consistent=consistent,
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Deliver ``shards`` (a list, or a pickled blob) to the peer at ``address``."""
            self.raise_if_closed()
            peer = self.rpc(address)
            return await peer.shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            """Send ``shards`` to the worker at ``address``, retrying on failure.

            If the mean shard size is below 65536 (64 KiB, presumably bytes) the
            list is pickled into a single ``bytes`` payload, which ``receive``
            unpickles. Retries are bounded by ``RETRY_COUNT`` with delays between
            ``RETRY_DELAY_MIN`` and ``RETRY_DELAY_MAX``.
            """
            payload: list | bytes
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                payload = pickle.dumps(shards)
            else:
                payload = shards

            def attempt() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                # Presumably invoked by retry() once per attempt, giving each
                # attempt a fresh coroutine.
                return self._send(address, payload)

            return await retry(
                attempt,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            """Run ``func(*args, **kwargs)`` on this run's executor, metered as "offload"."""
            self.raise_if_closed()
            with context_meter.meter("offload"):
                result = await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
            return result
    
        def heartbeat(self) -> dict[str, Any]:
            """Build the progress report for this run.

            Combines the disk- and comm-buffer heartbeats with the run's start
            time; the comm entry is extended with ``"read"`` = ``total_recvd``.
            """
            comm = self._comm_buffer.heartbeat()
            comm["read"] = self.total_recvd
            disk = self._disk_buffer.heartbeat()
            return {"disk": disk, "comm": comm, "start": self.start_time}
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            """Hand sharded ``data`` to the comm buffer for sending.

            Keys are presumably destination worker addresses (see
            ``_shard_partition``) — TODO confirm.
            """
            self.raise_if_closed()
            buffer = self._comm_buffer
            await buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            """Write ``data`` to the disk buffer, naming each entry by its index
            elements joined with ``"_"``."""
            self.raise_if_closed()
            by_name = {}
            for index, value in data.items():
                by_name["_".join(map(str, index))] = value
            await self._disk_buffer.write(by_name)
    
        def raise_if_closed(self) -> None:
            """Raise if this run is closed.

            Re-raises the recorded failure (``_exception``) when there is one,
            otherwise raises ``ShuffleClosedError``. Does nothing while open.
            """
            if not self.closed:
                return
            if self._exception:
                raise self._exception
            raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            """Mark all inputs as transferred and flush outgoing shards.

            An exception reported by the comm buffer is stored in ``_exception``
            (so ``raise_if_closed`` re-raises it once the run is closed) and then
            propagated to the caller.
            """
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as exc:
                self._exception = exc
                raise
    
        async def _flush_comm(self) -> None:
            """Flush the comm buffer after checking that the run is still open."""
            self.raise_if_closed()
            comm_buffer = self._comm_buffer
            await comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            """Flush the disk buffer holding received shards (run must be open)."""
            self.raise_if_closed()
            disk_buffer = self._disk_buffer
            await disk_buffer.flush()
    
        async def close(self) -> None:
            """Close this run's comm and disk buffers exactly once.

            ``closed`` is set before the first await, so any later call only waits
            on ``_closed_event`` until the first call has finished.

            The event is set in a ``finally`` clause. A failure while closing a
            buffer therefore cannot leave later callers waiting forever. The disk
            buffer is still closed if closing the comm buffer raised. Any such
            exception propagates to the first caller.
            """
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return

            self.closed = True
            try:
                await self._comm_buffer.close()
            finally:
                try:
                    await self._disk_buffer.close()
                finally:
                    self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            """Record ``exception`` as this run's failure, unless already closed."""
            if self.closed:
                return
            self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            """Read back the entry stored for index ``id`` (elements joined by ``"_"``)."""
            self.raise_if_closed()
            name = "_".join(map(str, id))
            return self._disk_buffer.read(name)
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Accept shards pushed by a peer.

            ``data`` is either the list of shards or the pickled blob produced by
            ``send``. A ``P2PConsistencyError`` is returned as an error message;
            otherwise ``{"status": "OK"}`` is returned.
            """
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    # NOTE(review): unpickles bytes received from a peer; only safe
                    # if all cluster members are trusted.
                    shards = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                else:
                    shards = data
                await self._receive(shards)
            except P2PConsistencyError as e:
                return error_message(e)
            return {"status": "OK"}
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            """Make sure output partition ``i`` is produced on its assigned worker.

            If another worker is assigned, ask the scheduler to restrict task
            ``key`` to that worker and raise ``Reschedule``. A scheduler error
            reply is raised as ``RuntimeError``.
            """
            assigned_worker = self._get_assigned_worker(i)
            if assigned_worker == self.local_address:
                return

            result = await self.scheduler.shuffle_restrict_task(
                id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
            )
            if result["status"] == "error":
                raise RuntimeError(result["message"])
            assert result["status"] == "OK"
            raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition

            Compared against ``local_address`` by ``_ensure_output_worker`` to
            decide whether the output task must be rescheduled elsewhere.
            """
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run

            Called by ``receive`` with the already-unpickled list of
            ``(partition id, shard)`` pairs.
            """
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            """Shard input partition ``partition_id`` and queue its shards for sending.

            Returns this run's ``run_id``. Raises ``RuntimeError`` once inputs have
            been marked as transferred.
            """
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            self.validate_data(data)
            # Log metrics both in the "execute" and in the "p2p" contexts
            with self._capture_metrics("foreground"):
                with (
                    context_meter.meter("p2p-shard-partition-noncpu"),
                    context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
                ):
                    sharded = self._shard_partition(data, partition_id)
                # Block this thread until the shards are handed to the comm buffer
                # on the event loop.
                sync(self._loop, self._write_to_comm, sharded)
            return self.run_id
    
        @abc.abstractmethod
        def _shard_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> dict[str, tuple[_T_partition_id, Any]]:
            """Shard an input partition by the assigned output workers

            The returned mapping is passed straight to ``_write_to_comm``; its keys
            are presumably destination worker addresses — TODO confirm.
            """
    
        def get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Return output partition ``partition_id`` (computed by task ``key``).

            Runs, in this order: check that this worker is the one assigned to the
            partition (``_ensure_output_worker`` may raise ``Reschedule``); require
            that inputs were marked as transferred; flush the disk buffer; then
            build the partition with ``_get_output_partition`` under p2p metering.

            Raises
            ------
            RuntimeError
                If called before ``transferred`` has been set by the barrier.
            """
            self.raise_if_closed()
            # Worker placement is checked before `transferred`, so a task on the
            # wrong worker is rescheduled rather than failing the check below.
            sync(self._loop, self._ensure_output_worker, partition_id, key)
            if not self.transferred:
                raise RuntimeError("`get_output_partition` called before barrier task")
            sync(self._loop, self.flush_receive)
            with (
                # Log metrics both in the "execute" and in the "p2p" contexts
                self._capture_metrics("foreground"),
                context_meter.meter("p2p-get-output-noncpu"),
                context_meter.meter("p2p-get-output-cpu", func=thread_time),
            ):
                return self._get_output_partition(partition_id, key, **kwargs)
    
        @abc.abstractmethod
        def _get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Get an output partition to the shuffle run

            Called by ``get_output_partition`` after the disk buffer has been
            flushed; ``kwargs`` are forwarded unchanged.
            """
    
        @abc.abstractmethod
        def read(self, path: Path) -> tuple[Any, int]:
            """Read shards from disk

            NOTE(review): the returned pair is presumably (shards, size) — TODO
            confirm against implementations.
            """
    
        @abc.abstractmethod
        def deserialize(self, buffer: Any) -> Any:
            """Deserialize shards from ``buffer`` (format defined by the subclass)."""
    
        def validate_data(self, data: Any) -> None:
            """Validate payload data before shuffling

            Called by ``add_partition`` before sharding. The base implementation
            accepts everything; subclasses may override it to raise on bad input.
            """
    
    
    def get_worker_plugin() -> ShuffleWorkerPlugin:
        """Return the ``"shuffle"`` plugin of the worker running the current task.

        Raises
        ------
        RuntimeError
            If not executing on a distributed Worker, or if that worker has no
            P2P shuffle plugin registered.
        """
        from distributed import get_worker

        try:
            worker = get_worker()
        except ValueError as e:
            raise RuntimeError(
                "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
                "please confirm that you've created a distributed Client and are submitting this computation through it."
            ) from e

        plugins = worker.plugins
        try:
            return plugins["shuffle"]  # type: ignore
        except KeyError as e:
            raise RuntimeError(
                f"The worker {worker.address} does not have a P2P shuffle plugin."
            ) from e
    
    
    # Barrier task keys are this prefix followed by the shuffle id.
    _BARRIER_PREFIX = "shuffle-barrier-"


    def barrier_key(shuffle_id: ShuffleId) -> str:
        """Return the key of the barrier task belonging to ``shuffle_id``."""
        key = _BARRIER_PREFIX + shuffle_id
        return key


    def id_from_key(key: Key) -> ShuffleId | None:
        """Inverse of ``barrier_key``.

        Returns the shuffle id embedded in a barrier task key, or ``None`` if
        ``key`` is not a string starting with the barrier prefix.
        """
        if isinstance(key, str) and key.startswith(_BARRIER_PREFIX):
            return ShuffleId(key[len(_BARRIER_PREFIX) :])
        return None
    
    
    class ShuffleType(Enum):
        """Kinds of P2P shuffle, each identified by a string name."""

        DATAFRAME = "DataFrameShuffle"
        ARRAY_RECHUNK = "ArrayRechunk"
    
    
    @dataclass(frozen=True)
    class ShuffleRunSpec(Generic[_T_partition_id]):
        """Immutable description of one run of a shuffle.

        Bundles the shuffle's ``spec``, the partition-to-worker assignment
        ``worker_for`` and the ``span_id`` recorded when the run was created.
        """

        # Not an __init__ argument: drawn from a single itertools.count(1) that is
        # created once at class definition, so each new spec gets the next integer.
        run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
        spec: ShuffleSpec
        # Output partition id -> address of the worker assigned to it.
        worker_for: dict[_T_partition_id, str]
        span_id: str | None

        @property
        def id(self) -> ShuffleId:
            """Id of the shuffle this run belongs to (``spec.id``)."""
            return self.spec.id
    
    
    @dataclass(frozen=True)
    class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
        """Abstract, immutable description of a P2P shuffle, independent of any run.

        ``create_new_run`` turns a spec into scheduler-side state for a fresh run;
        ``create_run_on_worker`` builds the corresponding worker-side ``ShuffleRun``.
        """

        id: ShuffleId
        # NOTE(review): presumably selects disk-backed vs. in-memory shard
        # buffering on the workers — TODO confirm.
        disk: bool

        @property
        @abc.abstractmethod
        def output_partitions(self) -> Generator[_T_partition_id]:
            """Output partitions"""

        @abc.abstractmethod
        def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
            """Pick the worker address that should own output ``partition``;
            ``workers`` are the candidates."""

        def create_new_run(
            self,
            worker_for: dict[_T_partition_id, str],
            span_id: str | None,
        ) -> SchedulerShuffleState:
            """Create scheduler-side state for a new run of this shuffle.

            The new ``ShuffleRunSpec`` receives a fresh ``run_id``; the initial set
            of participating workers is every worker that appears in ``worker_for``.
            """
            return SchedulerShuffleState(
                run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
                participating_workers=set(worker_for.values()),
            )

        @abc.abstractmethod
        def create_run_on_worker(
            self,
            run_id: int,
            span_id: str | None,
            worker_for: dict[_T_partition_id, str],
            plugin: ShuffleWorkerPlugin,
        ) -> ShuffleRun:
            """Create the new shuffle run on the worker."""
    
    
    @dataclass(eq=False)
    class SchedulerShuffleState(Generic[_T_partition_id]):
        """Scheduler-side bookkeeping for one run of a shuffle.

        ``eq=False`` keeps identity-based equality, while ``__hash__`` below hashes
        by ``run_id``, so instances can be stored in sets.
        """

        run_spec: ShuffleRunSpec
        # Worker addresses taking part in this run; starts from run_spec.worker_for.
        participating_workers: set[str]
        # Set once the run is archived (presumably to the archiving stimulus id);
        # `archived` reports whether it has been set.
        _archived_by: str | None = field(default=None, init=False)
        _failed: bool = False

        @property
        def id(self) -> ShuffleId:
            """Id of the shuffle (``run_spec.id``)."""
            return self.run_spec.id

        @property
        def run_id(self) -> int:
            """Id of this run (``run_spec.run_id``)."""
            return self.run_spec.run_id

        @property
        def archived(self) -> bool:
            """True once ``_archived_by`` has been set."""
            return self._archived_by is not None

        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"

        def __hash__(self) -> int:
            return hash(self.run_id)
    
    
    @contextlib.contextmanager
    def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
        """Translate exceptions raised inside the P2P transfer phase of shuffle ``id``.

        - ``ShuffleClosedError`` is converted into ``Reschedule``.
        - ``P2PConsistencyError`` and ``P2POutOfDiskError`` propagate unchanged.
        - Any other exception is wrapped in a ``RuntimeError`` naming the shuffle,
          with the original chained as its cause.
        """
        try:
            yield
        except ShuffleClosedError:
            raise Reschedule()
        except P2PConsistencyError:
            raise
        except P2POutOfDiskError:
            raise
        except Exception as e:
>           raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E           RuntimeError: P2P 5eac77ca35612a3e8c9f0e275e3e0020 failed during transfer phase

distributed\shuffle\_core.py:531: RuntimeError

Check warning on line 0 in distributed.shuffle.tests.test_rechunk

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

All 11 runs failed: test_rechunk_configuration[p2p-p2p] (distributed.shuffle.tests.test_rechunk)

artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 7ed88ee600fafb6fa95cf3f00f42eaf5 failed during transfer phase
from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    # Module-level logger for the shuffle scheduler plugin.
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            """Attach the shuffle plugin to ``scheduler``.

            Registers the ``shuffle_*`` RPC handlers, adds this object as the
            scheduler plugin named "shuffle" and initialises the bookkeeping maps.
            """
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            # shuffle id -> worker address -> latest heartbeat fields.
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a pickled ``ShuffleWorkerPlugin`` named "shuffle" via the scheduler."""
            payload = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None, payload, name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return a new set holding the ids of all active shuffles."""
            return {shuffle_id for shuffle_id in self.active_shuffles}
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Complete the transfer barrier of run ``run_id`` of shuffle ``id``.

            If ``consistent`` is False the shuffle is restarted instead. Otherwise
            ``shuffle_inputs_done`` is broadcast to all participating workers.
            A ``None`` reply counts as success. Workers replying with an
            ``OSError`` are retried after 0.1s.

            Raises
            ------
            KeyError
                If ``id`` is not an active shuffle.
            ValueError
                If ``run_id`` is not the active run of shuffle ``id``.
            RuntimeError
                In any of these cases:

                - a worker replies with an error other than ``OSError``;
                - a failing worker has left the cluster and the shuffle is archived;
                - a third retry round fails to shrink the set of failing workers.
            P2PIllegalStateError
                If a failing worker is unknown to the scheduler but the shuffle
                has not been archived.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # Keep only workers whose broadcast failed with a retryable OSError;
                # any other error aborts the barrier. (The list built here is the
                # complete retry set; it used to be recomputed redundantly.)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            """Restrict task ``key`` to ``worker`` for run ``run_id`` of shuffle ``id``.

            Returns ``{"status": "OK"}`` on success. If ``run_id`` is older
            ("stale") or newer ("invalid") than the active run, an error message
            is returned instead. That message reports both the expected and the
            received run id.
            """
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected run_id={shuffle.run_id} but got {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected run_id={shuffle.run_id} but got {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            """Merge per-shuffle heartbeat ``data`` from worker ``ws`` into ``heartbeats``.

            Entries for shuffles that are not active are ignored.
            """
            # Test membership on the dict itself: going through shuffle_ids()
            # copied every active id into a new set on each loop iteration.
            active = self.active_shuffles
            for shuffle_id, d in data.items():
                if shuffle_id in active:
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the pickled run spec of active shuffle ``id`` for ``worker``.

            A missing shuffle or other consistency problem is reported as an
            error message rather than raised.
            """
            try:
                try:
                    spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(spec)}
                except KeyError as missing:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from missing
            except P2PConsistencyError as err:
                return error_message(err)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            """Return the active run spec of shuffle ``id``, recording ``worker``
            as a participant.

            Raises ``KeyError`` if ``id`` is not active.
            """
            if worker in self.scheduler.workers:
                state = self.active_shuffles[id]
                state.participating_workers.add(worker)
                return state.run_spec
            # This should never happen
            raise P2PConsistencyError(
                f"Scheduler is unaware of this worker {worker!r}"
            )  # pragma: nocover
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            """Return the ``ShuffleSpec`` carried by the barrier task of ``shuffle_id``."""
            barrier_ts = self.scheduler.tasks[barrier_key(shuffle_id)]
            task_spec = barrier_ts.run_spec
            assert isinstance(task_spec, P2PBarrierTask)
            return task_spec.spec
    
        def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
            """Start a new run of ``shuffle_id``, triggered by task ``key`` on ``worker``.

            Steps:

            1. Run the ``_raise_if_*`` precondition checks on the barrier task
               and on ``key``.
            2. Read the spec from the barrier task and compute the worker
               assignment.
            3. Create the new run state and register it as the active run.
            4. Add ``worker`` as a participant, log the start and return the
               run spec.
            """
            # FIXME: The current implementation relies on the barrier task to be
            # known by its name. If the name has been mangled, we cannot guarantee
            # that the shuffle works as intended and should fail instead.
            self._raise_if_barrier_unknown(shuffle_id)
            self._raise_if_task_not_processing(key)
            spec = self._retrieve_spec(shuffle_id)
            worker_for = self._calculate_worker_for(spec)
            self._ensure_output_tasks_are_non_rootish(spec)
            state = spec.create_new_run(
                worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
            )
            self.active_shuffles[shuffle_id] = state
            self._shuffles[shuffle_id].add(state)
            state.participating_workers.add(worker)
            logger.warning(
                "Shuffle %s initialized by task %r executed on worker %s",
                shuffle_id,
                key,
                worker,
            )
            return state.run_spec
    
        def get_or_create(
            self,
            shuffle_id: ShuffleId,
            key: Key,
            worker: str,
        ) -> RunSpecMessage | ErrorMessage:
            try:
>               run_spec = self._get(shuffle_id, worker)

distributed\shuffle\_scheduler_plugin.py:229: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    # Module-level logger for the shuffle scheduler plugin.
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            """Attach the shuffle plugin to ``scheduler``.

            Registers the ``shuffle_*`` RPC handlers, adds this object as the
            scheduler plugin named "shuffle" and initialises the bookkeeping maps.
            """
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            # shuffle id -> worker address -> latest heartbeat fields.
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a pickled ``ShuffleWorkerPlugin`` named "shuffle" via the scheduler."""
            payload = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None, payload, name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return a new set holding the ids of all active shuffles."""
            return {shuffle_id for shuffle_id in self.active_shuffles}
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Complete the transfer barrier of run ``run_id`` of shuffle ``id``.

            If ``consistent`` is False the shuffle is restarted instead. Otherwise
            ``shuffle_inputs_done`` is broadcast to all participating workers.
            A ``None`` reply counts as success. Workers replying with an
            ``OSError`` are retried after 0.1s.

            Raises
            ------
            KeyError
                If ``id`` is not an active shuffle.
            ValueError
                If ``run_id`` is not the active run of shuffle ``id``.
            RuntimeError
                In any of these cases:

                - a worker replies with an error other than ``OSError``;
                - a failing worker has left the cluster and the shuffle is archived;
                - a third retry round fails to shrink the set of failing workers.
            P2PIllegalStateError
                If a failing worker is unknown to the scheduler but the shuffle
                has not been archived.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # Keep only workers whose broadcast failed with a retryable OSError;
                # any other error aborts the barrier. (The list built here is the
                # complete retry set; it used to be recomputed redundantly.)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            """Restrict task ``key`` to ``worker`` for run ``run_id`` of shuffle ``id``.

            Returns ``{"status": "OK"}`` on success. If ``run_id`` is older
            ("stale") or newer ("invalid") than the active run, an error message
            is returned instead. That message reports both the expected and the
            received run id.
            """
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected run_id={shuffle.run_id} but got {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected run_id={shuffle.run_id} but got {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            """Merge per-shuffle heartbeat ``data`` from worker ``ws`` into ``heartbeats``.

            Entries for shuffles that are not active are ignored.
            """
            # Test membership on the dict itself: going through shuffle_ids()
            # copied every active id into a new set on each loop iteration.
            active = self.active_shuffles
            for shuffle_id, d in data.items():
                if shuffle_id in active:
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the pickled run spec of active shuffle ``id`` for ``worker``.

            A missing shuffle or other consistency problem is reported as an
            error message rather than raised.
            """
            try:
                try:
                    spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(spec)}
                except KeyError as missing:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from missing
            except P2PConsistencyError as err:
                return error_message(err)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
>           state = self.active_shuffles[id]
E           KeyError: '7ed88ee600fafb6fa95cf3f00f42eaf5'

distributed\shuffle\_scheduler_plugin.py:190: KeyError

During handling of the above exception, another exception occurred:

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards
    
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)
    
            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and … = await self._refresh(
distributed\shuffle\_worker_plugin.py:222: in _refresh
    result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed\shuffle\_worker_plugin.py:190: in _fetch
    response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed\core.py:1259: in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
distributed\core.py:1043: in send_recv
    raise exc.with_traceback(tb)
distributed\core.py:832: in _handle_comm
    result = handler(**msg)
distributed\shuffle\_scheduler_plugin.py:234: in get_or_create
    run_spec = self._create(shuffle_id, key, worker)
distributed\shuffle\_scheduler_plugin.py:205: in _create
    spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
>           assert isinstance(barrier_task_spec, P2PBarrierTask)
E           AssertionError

distributed\shuffle\_scheduler_plugin.py:196: AssertionError

The above exception was the direct cause of the following exception:

c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:64026', workers: 0, cores: 0, tasks: 0>
config_value = 'p2p', keyword = 'p2p'
ws = (<Worker 'tcp://127.0.0.1:64027', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:64030', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
a = array([[0.70050953, 0.59988581, 0.2873489 , 0.02919635, 0.75044654,
        0.34847269, 0.00570634, 0.79608935, 0.0933..., 0.19892204, 0.22475293, 0.7469204 , 0.13692289,
        0.71068667, 0.03039515, 0.84743146, 0.23023854, 0.91232044]])

    @pytest.mark.parametrize("config_value", ["tasks", "p2p", None])
    @pytest.mark.parametrize("keyword", ["tasks", "p2p", None])
    @gen_cluster(client=True)
    async def test_rechunk_configuration(c, s, *ws, config_value, keyword):
        """Try rechunking a random 1d matrix
    
        See Also
        --------
        dask.array.tests.test_rechunk.test_rechunk_1d
        """
        a = np.random.default_rng().uniform(0, 1, 100).reshape((10, 10))
        x = da.from_array(a, chunks=(10, 1))
        new = ((1,) * 10, (10,))
        config = {"array.rechunk.method": config_value} if config_value is not None else {}
        with dask.config.set(config):
            x2 = rechunk(x, chunks=new, method=keyword)
        expected_algorithm = keyword if keyword is not None else config_value
        if expected_algorithm == "p2p":
            assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
        elif expected_algorithm == "tasks":
            assert not any(
                key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__()
            )
        # Neither is specified, so we choose the best one (see test_rechunk_heuristic for a full test of the heuristic)
        else:
            assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
    
        assert x2.chunks == new
>       assert np.all(await c.compute(x2) == a)

distributed\shuffle\tests\test_rechunk.py:210: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed\client.py:410: in _result
    raise exc.with_traceback(tb)
distributed\shuffle\_rechunk.py:169: in rechunk_transfer
    with handle_transfer_errors(id):
C:\Users\runneradmin\miniconda3\envs\dask-distributed\Lib\contextlib.py:158: in __exit__
    self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards
    
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)
    
            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and in the "p2p" contexts
            self.validate_data(data)
            with self._capture_metrics("foreground"):
                with (
                    context_meter.meter("p2p-shard-partition-noncpu"),
                    context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
                ):
                    shards = self._shard_partition(data, partition_id)
                sync(self._loop, self._write_to_comm, shards)
            return self.run_id
    
        @abc.abstractmethod
        def _shard_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> dict[str, tuple[_T_partition_id, Any]]:
            """Shard an input partition by the assigned output workers"""
    
        def get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            self.raise_if_closed()
            sync(self._loop, self._ensure_output_worker, partition_id, key)
            if not self.transferred:
                raise RuntimeError("`get_output_partition` called before barrier task")
            sync(self._loop, self.flush_receive)
            with (
                # Log metrics both in the "execute" and in the "p2p" contexts
                self._capture_metrics("foreground"),
                context_meter.meter("p2p-get-output-noncpu"),
                context_meter.meter("p2p-get-output-cpu", func=thread_time),
            ):
                return self._get_output_partition(partition_id, key, **kwargs)
    
        @abc.abstractmethod
        def _get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Get an output partition to the shuffle run"""
    
        @abc.abstractmethod
        def read(self, path: Path) -> tuple[Any, int]:
            """Read shards from disk"""
    
        @abc.abstractmethod
        def deserialize(self, buffer: Any) -> Any:
            """Deserialize shards"""
    
        def validate_data(self, data: Any) -> None:
            """Validate payload data before shuffling"""
    
    
    def get_worker_plugin() -> ShuffleWorkerPlugin:
        from distributed import get_worker
    
        try:
            worker = get_worker()
        except ValueError as e:
            raise RuntimeError(
                "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
                "please confirm that you've created a distributed Client and are submitting this computation through it."
            ) from e
        try:
            return worker.plugins["shuffle"]  # type: ignore
        except KeyError as e:
            raise RuntimeError(
                f"The worker {worker.address} does not have a P2P shuffle plugin."
            ) from e
    
    
    _BARRIER_PREFIX = "shuffle-barrier-"
    
    
    def barrier_key(shuffle_id: ShuffleId) -> str:
        return _BARRIER_PREFIX + shuffle_id
    
    
    def id_from_key(key: Key) -> ShuffleId | None:
        if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
            return None
        return ShuffleId(key[len(_BARRIER_PREFIX) :])
    
    
    class ShuffleType(Enum):
        DATAFRAME = "DataFrameShuffle"
        ARRAY_RECHUNK = "ArrayRechunk"
    
    
    @dataclass(frozen=True)
    class ShuffleRunSpec(Generic[_T_partition_id]):
        run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
        spec: ShuffleSpec
        worker_for: dict[_T_partition_id, str]
        span_id: str | None
    
        @property
        def id(self) -> ShuffleId:
            return self.spec.id
    
    
    @dataclass(frozen=True)
    class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
        id: ShuffleId
        disk: bool
    
        @property
        @abc.abstractmethod
        def output_partitions(self) -> Generator[_T_partition_id]:
            """Output partitions"""
    
        @abc.abstractmethod
        def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
            """Pick a worker for a partition"""
    
        def create_new_run(
            self,
            worker_for: dict[_T_partition_id, str],
            span_id: str | None,
        ) -> SchedulerShuffleState:
            return SchedulerShuffleState(
                run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
                participating_workers=set(worker_for.values()),
            )
    
        @abc.abstractmethod
        def create_run_on_worker(
            self,
            run_id: int,
            span_id: str | None,
            worker_for: dict[_T_partition_id, str],
            plugin: ShuffleWorkerPlugin,
        ) -> ShuffleRun:
            """Create the new shuffle run on the worker."""
    
    
    @dataclass(eq=False)
    class SchedulerShuffleState(Generic[_T_partition_id]):
        run_spec: ShuffleRunSpec
        participating_workers: set[str]
        _archived_by: str | None = field(default=None, init=False)
        _failed: bool = False
    
        @property
        def id(self) -> ShuffleId:
            return self.run_spec.id
    
        @property
        def run_id(self) -> int:
            return self.run_spec.run_id
    
        @property
        def archived(self) -> bool:
            return self._archived_by is not None
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
    
        def __hash__(self) -> int:
            return hash(self.run_id)
    
    
    @contextlib.contextmanager
    def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
        try:
            yield
        except ShuffleClosedError:
            raise Reschedule()
        except P2PConsistencyError:
            raise
        except P2POutOfDiskError:
            raise
        except Exception as e:
>           raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E           RuntimeError: P2P 7ed88ee600fafb6fa95cf3f00f42eaf5 failed during transfer phase

distributed\shuffle\_core.py:531: RuntimeError

Check warning on line 0 in distributed.shuffle.tests.test_rechunk

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

All 11 runs failed: test_rechunk_configuration[p2p-None] (distributed.shuffle.tests.test_rechunk)

artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P c27782cfc676cfba0abb18c164e6cb85 failed during transfer phase
from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
            assert isinstance(barrier_task_spec, P2PBarrierTask)
            return barrier_task_spec.spec
    
        def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
            # FIXME: The current implementation relies on the barrier task to be
            # known by its name. If the name has been mangled, we cannot guarantee
            # that the shuffle works as intended and should fail instead.
            self._raise_if_barrier_unknown(shuffle_id)
            self._raise_if_task_not_processing(key)
            spec = self._retrieve_spec(shuffle_id)
            worker_for = self._calculate_worker_for(spec)
            self._ensure_output_tasks_are_non_rootish(spec)
            state = spec.create_new_run(
                worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
            )
            self.active_shuffles[shuffle_id] = state
            self._shuffles[shuffle_id].add(state)
            state.participating_workers.add(worker)
            logger.warning(
                "Shuffle %s initialized by task %r executed on worker %s",
                shuffle_id,
                key,
                worker,
            )
            return state.run_spec
    
        def get_or_create(
            self,
            shuffle_id: ShuffleId,
            key: Key,
            worker: str,
        ) -> RunSpecMessage | ErrorMessage:
            try:
>               run_spec = self._get(shuffle_id, worker)

distributed\shuffle\_scheduler_plugin.py:229: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
>           state = self.active_shuffles[id]
E           KeyError: 'c27782cfc676cfba0abb18c164e6cb85'

distributed\shuffle\_scheduler_plugin.py:190: KeyError

During handling of the above exception, another exception occurred:

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards
    
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)
    
            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and …n = await self._refresh(
distributed\shuffle\_worker_plugin.py:222: in _refresh
    result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed\shuffle\_worker_plugin.py:190: in _fetch
    response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed\core.py:1259: in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
distributed\core.py:1043: in send_recv
    raise exc.with_traceback(tb)
distributed\core.py:832: in _handle_comm
    result = handler(**msg)
distributed\shuffle\_scheduler_plugin.py:234: in get_or_create
    run_spec = self._create(shuffle_id, key, worker)
distributed\shuffle\_scheduler_plugin.py:205: in _create
    spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
>           assert isinstance(barrier_task_spec, P2PBarrierTask)
E           AssertionError

distributed\shuffle\_scheduler_plugin.py:196: AssertionError

The above exception was the direct cause of the following exception:

c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:64040', workers: 0, cores: 0, tasks: 0>
config_value = None, keyword = 'p2p'
ws = (<Worker 'tcp://127.0.0.1:64041', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:64044', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
a = array([[0.40426865, 0.81855948, 0.50957116, 0.44965013, 0.25501785,
        0.20329874, 0.75418044, 0.10925094, 0.2364..., 0.76864058, 0.805811  , 0.93629147, 0.94224933,
        0.3873398 , 0.336608  , 0.01348985, 0.7104567 , 0.8208412 ]])

    @pytest.mark.parametrize("config_value", ["tasks", "p2p", None])
    @pytest.mark.parametrize("keyword", ["tasks", "p2p", None])
    @gen_cluster(client=True)
    async def test_rechunk_configuration(c, s, *ws, config_value, keyword):
        """Try rechunking a random 1d matrix
    
        See Also
        --------
        dask.array.tests.test_rechunk.test_rechunk_1d
        """
        a = np.random.default_rng().uniform(0, 1, 100).reshape((10, 10))
        x = da.from_array(a, chunks=(10, 1))
        new = ((1,) * 10, (10,))
        config = {"array.rechunk.method": config_value} if config_value is not None else {}
        with dask.config.set(config):
            x2 = rechunk(x, chunks=new, method=keyword)
        expected_algorithm = keyword if keyword is not None else config_value
        if expected_algorithm == "p2p":
            assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
        elif expected_algorithm == "tasks":
            assert not any(
                key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__()
            )
        # Neither is specified, so we choose the best one (see test_rechunk_heuristic for a full test of the heuristic)
        else:
            assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
    
        assert x2.chunks == new
>       assert np.all(await c.compute(x2) == a)

distributed\shuffle\tests\test_rechunk.py:210: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed\client.py:410: in _result
    raise exc.with_traceback(tb)
distributed\shuffle\_rechunk.py:169: in rechunk_transfer
    with handle_transfer_errors(id):
C:\Users\runneradmin\miniconda3\envs\dask-distributed\Lib\contextlib.py:158: in __exit__
    self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            """Short human-readable form: ``Name<shuffle id[run id]> on address``."""
            cls_name = type(self).__name__
            return f"{cls_name}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            """Hash a run by its integer ``run_id``."""
            return hash(self.run_id)
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as

                {('p2p', <span id>, 'foreground|background...', label, unit): value}

            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under

                {('execute', <span id>, <task prefix>, label, unit): value}

            This is by design so that one can have a holistic view of the whole shuffle
            process.

            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.

            Labels are normalised before being recorded. A non-tuple label is wrapped
            in a 1-tuple. A leading string element has its ``p2p-`` prefix removed,
            e.g. ``"p2p-shard-partition-cpu"`` becomes ``"shard-partition-cpu"``. An
            empty tuple label is passed through unchanged rather than raising
            ``IndexError``.
            """

            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                # Guard the empty-tuple case: indexing label[0] unconditionally
                # would raise IndexError from inside the metering callback.
                if label and isinstance(label[0], str):
                    label = (label[0].removeprefix("p2p-"), *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)

                self.digest_metric(name, value)

            # NOTE(review): allow_offload is only enabled for "background..." contexts;
            # presumably this lets metrics recorded by offloaded work reach the callback.
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            """Report to the scheduler's shuffle barrier whether ``run_ids`` agree.

            The report is consistent only when every entry of ``run_ids`` equals this
            run's ``run_id``. Returns this run's ``run_id``.
            """
            self.raise_if_closed()
            mismatch = any(rid != self.run_id for rid in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=not mismatch
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Push ``shards`` (a list or a pickled blob) to the worker at ``address``."""
            self.raise_if_closed()
            shuffle_receive = self.rpc(address).shuffle_receive
            return await shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            """Send ``shards`` to the worker at ``address``, retrying on failure.

            Retry behaviour comes from ``RETRY_COUNT``, ``RETRY_DELAY_MIN`` and
            ``RETRY_DELAY_MAX``, which are read from ``distributed.p2p.comm.retry.*``.
            """
            payload: list | bytes
            if _mean_shard_size(shards) < 65536:
                # Small shards: rather than sending many tiny buffers over the TCP
                # comms, pickle the whole list into one opaque bytes blob, which
                # ``receive`` unpickles on the other side. The threshold comes from
                # https://github.com/dask/distributed/pull/8318
                payload = pickle.dumps(shards)
            else:
                payload = shards

            return await retry(
                lambda: self._send(address, payload),
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            """Run ``func(*args, **kwargs)`` on ``self.executor`` and await its result.

            The time spent is metered under the ``"offload"`` label.
            """
            self.raise_if_closed()
            with context_meter.meter("offload"):
                outcome = await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
                return outcome
    
        def heartbeat(self) -> dict[str, Any]:
            """Summarise the state of the run's buffers.

            Returns a dict with three entries:

            - ``"disk"``: the disk buffer heartbeat.
            - ``"comm"``: the comm buffer heartbeat, with ``"read"`` set to
              ``total_recvd``.
            - ``"start"``: the run's start time.
            """
            comm = self._comm_buffer.heartbeat()
            comm["read"] = self.total_recvd
            disk = self._disk_buffer.heartbeat()
            return dict(disk=disk, comm=comm, start=self.start_time)
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            """Queue shards on the comm buffer.

            ``data`` is keyed by destination worker, as produced by ``_shard_partition``.
            """
            self.raise_if_closed()
            comm_buffer = self._comm_buffer
            await comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            """Write received shards to the disk buffer.

            Each ``NDIndex`` key is flattened by joining its elements with ``"_"``
            (``(1, 2)`` -> ``"1_2"``), the same scheme ``_read_from_disk`` uses.
            """
            self.raise_if_closed()
            flattened = {}
            for index, value in data.items():
                flattened["_".join(map(str, index))] = value
            await self._disk_buffer.write(flattened)
    
        def raise_if_closed(self) -> None:
            """Raise if this run has been closed; do nothing while it is open.

            Re-raises the recorded ``_exception`` when one is set, otherwise raises
            ``ShuffleClosedError``.
            """
            if not self.closed:
                return
            if self._exception:
                raise self._exception
            raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            """Mark the transfer phase as complete and flush outgoing shards.

            If the comm buffer reports an exception, it is stored as
            ``self._exception`` and re-raised.
            """
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            comm = self._comm_buffer
            try:
                comm.raise_on_exception()
            except Exception as failure:
                self._exception = failure
                raise
    
        async def _flush_comm(self) -> None:
            """Flush all pending outgoing shards through the comm buffer."""
            self.raise_if_closed()
            comm_buffer = self._comm_buffer
            await comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            """Flush received shards held by ``_disk_buffer``.

            The buffer may be a ``MemoryShardsBuffer`` rather than a disk buffer.
            """
            self.raise_if_closed()
            receive_buffer = self._disk_buffer
            await receive_buffer.flush()
    
        async def close(self) -> None:
            """Close the run together with its comm and disk buffers.

            A second call waits until the first one has finished and then returns.
            The disk buffer is closed even if closing the comm buffer fails.
            ``_closed_event`` is always set, so later callers never block forever.
            Errors raised while closing still propagate to the caller.
            """
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return

            self.closed = True
            try:
                try:
                    await self._comm_buffer.close()
                finally:
                    # Release the disk buffer even if the comm buffer failed to close.
                    await self._disk_buffer.close()
            finally:
                # Always wake waiters: otherwise a failed close would leave every
                # subsequent close() awaiting _closed_event indefinitely.
                self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            """Record ``exception`` as this run's failure; ignored once closed."""
            if self.closed:
                return
            self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            """Read the shards stored under ``id``.

            ``id`` is flattened with ``"_"`` in the same way ``_write_to_disk`` does.
            """
            self.raise_if_closed()
            storage_key = "_".join(map(str, id))
            return self._disk_buffer.read(storage_key)
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Accept shards pushed by a peer's ``send``.

            ``bytes`` payloads are the pickled blobs ``send`` produces for small
            shards, and are unpickled first. A ``P2PConsistencyError`` is returned as
            an error message instead of being raised.
            """
            try:
                if isinstance(data, bytes):
                    # NOTE(review): pickle.loads on data received over the network;
                    # this relies on the sending peers being trusted cluster workers.
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
            except P2PConsistencyError as e:
                return error_message(e)
            else:
                return {"status": "OK"}
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            """Reschedule task ``key`` if output partition ``i`` belongs to another worker.

            In that case the scheduler is asked to restrict the task to the assigned
            worker and ``Reschedule`` is raised. A scheduler-side error is raised as
            ``RuntimeError``.
            """
            target = self._get_assigned_worker(i)
            if target == self.local_address:
                return

            response = await self.scheduler.shuffle_restrict_task(
                id=self.id, run_id=self.run_id, key=key, worker=target
            )
            if response["status"] == "error":
                raise RuntimeError(response["message"])
            assert response["status"] == "OK"
            raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Return the address of the worker that owns output partition ``i``."""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Store incoming shards that belong to this run's output partitions."""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            """Shard an input partition and queue the shards for sending.

            Uses ``sync`` to run ``_write_to_comm`` on ``self._loop``. Returns this
            run's ``run_id``. Raises ``RuntimeError`` once inputs have been marked
            as transferred.
            """
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            self.validate_data(data)
            # Metrics are logged in both the "execute" and the "p2p" contexts.
            with self._capture_metrics("foreground"):
                with (
                    context_meter.meter("p2p-shard-partition-noncpu"),
                    context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
                ):
                    sharded = self._shard_partition(data, partition_id)
                sync(self._loop, self._write_to_comm, sharded)
            return self.run_id
    
        @abc.abstractmethod
        def _shard_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> dict[str, tuple[_T_partition_id, Any]]:
            """Split an input partition into shards keyed by their destination worker."""
    
        def get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Assemble output partition ``partition_id`` for task ``key``.

            First ensures this worker owns the partition (otherwise the task is
            rescheduled). Then it flushes received shards and builds the output.
            Raises ``RuntimeError`` if inputs have not yet been marked as transferred.
            """
            self.raise_if_closed()
            sync(self._loop, self._ensure_output_worker, partition_id, key)
            if not self.transferred:
                raise RuntimeError("`get_output_partition` called before barrier task")
            sync(self._loop, self.flush_receive)
            # Metrics are logged in both the "execute" and the "p2p" contexts.
            with self._capture_metrics("foreground"):
                with context_meter.meter("p2p-get-output-noncpu"):
                    with context_meter.meter("p2p-get-output-cpu", func=thread_time):
                        return self._get_output_partition(partition_id, key, **kwargs)
    
        @abc.abstractmethod
        def _get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Build output partition ``partition_id`` from the shards this run holds."""
    
        @abc.abstractmethod
        def read(self, path: Path) -> tuple[Any, int]:
            """Load shards stored at ``path``.

            This is passed to ``DiskShardsBuffer`` as its ``read`` callback.
            NOTE(review): the int is presumably the number of bytes read; confirm.
            """
    
        @abc.abstractmethod
        def deserialize(self, buffer: Any) -> Any:
            """Turn a buffered shard payload back into shard objects."""
    
        def validate_data(self, data: Any) -> None:
            """Hook called by ``add_partition`` before sharding; accepts anything.

            Subclasses may override it to reject unsupported input by raising.
            """
            return None
    
    
    def get_worker_plugin() -> ShuffleWorkerPlugin:
        """Return the ``"shuffle"`` plugin of the worker running the current task.

        Raises ``RuntimeError`` in two cases: when not called from a task on a
        worker, or when that worker has no plugin registered as ``"shuffle"``.
        """
        from distributed import get_worker

        try:
            current = get_worker()
        except ValueError as e:
            raise RuntimeError(
                "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
                "please confirm that you've created a distributed Client and are submitting this computation through it."
            ) from e
        registered = current.plugins
        try:
            return registered["shuffle"]  # type: ignore
        except KeyError as e:
            raise RuntimeError(
                f"The worker {current.address} does not have a P2P shuffle plugin."
            ) from e
    
    
    # Prefix of every barrier task key; the shuffle ID follows it
    # (see ``barrier_key`` and ``id_from_key``).
    _BARRIER_PREFIX = "shuffle-barrier-"
    
    
    def barrier_key(shuffle_id: ShuffleId) -> str:
        """Return the task key of the barrier task of ``shuffle_id``."""
        key = _BARRIER_PREFIX + shuffle_id
        return key
    
    
    def id_from_key(key: Key) -> ShuffleId | None:
        """Extract the shuffle ID from a barrier task key.

        Returns ``None`` for keys that are not strings or that lack the barrier
        prefix.
        """
        if isinstance(key, str) and key.startswith(_BARRIER_PREFIX):
            return ShuffleId(key.removeprefix(_BARRIER_PREFIX))
        return None
    
    
    class ShuffleType(Enum):
        """Kinds of P2P shuffle: a DataFrame shuffle or an array rechunk."""
    
        DATAFRAME = "DataFrameShuffle"
        ARRAY_RECHUNK = "ArrayRechunk"
    
    
    @dataclass(frozen=True)
    class ShuffleRunSpec(Generic[_T_partition_id]):
        """Immutable description of one run of a shuffle.

        Every instance draws a fresh ``run_id`` from a counter that starts at 1
        and is shared by all instances. The counter is excluded from
        ``__init__``, so callers cannot choose it.
        """
    
        # Drawn from one shared itertools.count(1) each time an instance is created.
        run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
        # The shuffle definition this run executes.
        spec: ShuffleSpec
        # Output partition -> address of the worker assigned to it.
        worker_for: dict[_T_partition_id, str]
        # Span ID taken from the task group of the task that created the run, if any.
        span_id: str | None
    
        @property
        def id(self) -> ShuffleId:
            """ID of the underlying shuffle (the same for all of its runs)."""
            return self.spec.id
    
    
    @dataclass(frozen=True)
    class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
        """Immutable, abstract definition of a P2P shuffle.

        Subclasses enumerate the output partitions and decide which worker owns
        each one. They also build the worker-side ``ShuffleRun``.
        """

        # Identifier shared by every run created from this spec.
        id: ShuffleId
        # NOTE(review): presumably selects disk- vs memory-backed buffering; confirm.
        disk: bool

        @property
        @abc.abstractmethod
        def output_partitions(self) -> Generator[_T_partition_id]:
            """Yield every output partition ID of this shuffle."""

        @abc.abstractmethod
        def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
            """Choose which of ``workers`` should own ``partition``."""

        def create_new_run(
            self,
            worker_for: dict[_T_partition_id, str],
            span_id: str | None,
        ) -> SchedulerShuffleState:
            """Build scheduler-side state for a fresh run of this shuffle.

            The run starts with exactly the workers that appear as values of
            ``worker_for`` as its participants.
            """
            run_spec = ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id)
            initial_workers = set(worker_for.values())
            return SchedulerShuffleState(
                run_spec=run_spec,
                participating_workers=initial_workers,
            )

        @abc.abstractmethod
        def create_run_on_worker(
            self,
            run_id: int,
            span_id: str | None,
            worker_for: dict[_T_partition_id, str],
            plugin: ShuffleWorkerPlugin,
        ) -> ShuffleRun:
            """Instantiate the worker-side ``ShuffleRun`` for run ``run_id``."""
    
    
    @dataclass(eq=False)
    class SchedulerShuffleState(Generic[_T_partition_id]):
        """Scheduler-side state of one run of a shuffle.

        ``eq=False`` keeps identity-based equality; instances hash by ``run_id``.
        """

        run_spec: ShuffleRunSpec
        participating_workers: set[str]
        _archived_by: str | None = field(default=None, init=False)
        _failed: bool = False

        @property
        def id(self) -> ShuffleId:
            """ID of the shuffle this run belongs to."""
            spec = self.run_spec
            return spec.id

        @property
        def run_id(self) -> int:
            """ID of this particular run."""
            spec = self.run_spec
            return spec.run_id

        @property
        def archived(self) -> bool:
            """True once ``_archived_by`` has been set."""
            return not (self._archived_by is None)

        def __str__(self) -> str:
            cls_name = type(self).__name__
            return f"{cls_name}<{self.id}[{self.run_id}]>"

        def __hash__(self) -> int:
            current_run = self.run_id
            return hash(current_run)
    
    
    @contextlib.contextmanager
    def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
        """Translate exceptions raised inside the P2P transfer phase of shuffle ``id``.

        - ``ShuffleClosedError`` is re-raised as ``Reschedule``.
        - ``P2PConsistencyError`` and ``P2POutOfDiskError`` propagate unchanged.
        - Any other ``Exception`` is wrapped in a ``RuntimeError`` naming ``id``,
          with the original as its ``__cause__``.
        """
        try:
            yield
        except ShuffleClosedError:
            # The run was closed under the task: signal a reschedule instead of failing.
            raise Reschedule()
        except P2PConsistencyError:
            raise
        except P2POutOfDiskError:
            raise
        except Exception as e:
>           raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E           RuntimeError: P2P c27782cfc676cfba0abb18c164e6cb85 failed during transfer phase

distributed\shuffle\_core.py:531: RuntimeError

Check warning on line 0 in distributed.shuffle.tests.test_rechunk

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

All 11 runs failed: test_rechunk_configuration[None-p2p] (distributed.shuffle.tests.test_rechunk)

artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 35f3712b5c49a47f3324d11460012f87 failed during transfer phase
from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    # Module-level logger, named after this module.
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            """Set up plugin state, install the shuffle RPC handlers and register.

            All state is initialised before ``scheduler.add_plugin`` publishes the
            plugin, so the scheduler never holds a partially constructed plugin.
            """
            self.scheduler = scheduler
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
            # The handlers are bound methods that read the state initialised above.
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            # Publish last: registration exposes ``self`` to the scheduler.
            self.scheduler.add_plugin(self, name="shuffle")
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a pickled ``ShuffleWorkerPlugin`` under the name ``"shuffle"``.

            The plugin is registered non-idempotently. The ``scheduler`` argument is
            unused; registration goes through ``self.scheduler``.
            """
            payload = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None, payload, name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return a snapshot of the IDs of all currently active shuffles."""
            return {shuffle_id for shuffle_id in self.active_shuffles}
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Complete the barrier of run ``run_id`` of shuffle ``id``.

            If the inputs were inconsistent, the shuffle is restarted. Otherwise
            ``shuffle_inputs_done`` is broadcast to every participating worker, and
            workers that fail with ``OSError`` are retried.

            Raises
            ------
            ValueError
                If ``run_id`` is not the active run of shuffle ``id``.
            P2PIllegalStateError
                If a participating worker left while the shuffle is not archived.
            RuntimeError
                In any of these cases: a worker returns an unexpected error, a
                participating worker left the cluster, or three retry rounds fail
                to shrink the set of failing workers.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            # Counts rounds in which the number of failing workers did not shrink.
            # It is not reset when a later round makes progress.
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # A None reply means success. OSError replies are retried in the
                # next round, and any other error aborts the barrier. This loop
                # alone determines the retry list.
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            """Pin task ``key`` to ``worker`` for run ``run_id`` of shuffle ``id``.

            Returns ``{"status": "OK"}`` on success. If ``run_id`` is older
            ("stale") or newer ("invalid") than the active run, an error message is
            returned instead of raising.
            """
            try:
                shuffle = self.active_shuffles[id]
                # Report both the active run ID and the one in the request so the
                # mismatch is visible in the error.
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected run_id={shuffle.run_id} "
                        f"but got {run_id} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected run_id={shuffle.run_id} "
                        f"but got {run_id} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            """Merge per-shuffle heartbeat data reported by worker ``ws``.

            Entries for shuffles that are not currently active are ignored.
            """
            # Take the snapshot of active IDs once. Calling shuffle_ids() per entry
            # rebuilt the whole set every time.
            active = self.shuffle_ids()
            for shuffle_id, d in data.items():
                if shuffle_id in active:
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Reply with the active run spec of shuffle ``id``, enlisting ``worker``.

            On success replies ``{"status": "OK", "run_spec": ...}``. An unknown
            shuffle or an inconsistent request is returned as an error message.
            """
            try:
                try:
                    spec = self._get(id, worker)
                except KeyError as exc:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from exc
                return {"status": "OK", "run_spec": ToPickle(spec)}
            except P2PConsistencyError as err:
                return error_message(err)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            """Return the active run spec of ``id`` after adding ``worker`` as a participant.

            Raises ``KeyError`` when no shuffle ``id`` is active.
            """
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            active = self.active_shuffles[id]
            active.participating_workers.add(worker)
            return active.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            """Return the ``ShuffleSpec`` carried by the barrier task of ``shuffle_id``."""
            barrier_ts = self.scheduler.tasks[barrier_key(shuffle_id)]
            task_spec = barrier_ts.run_spec
            assert isinstance(task_spec, P2PBarrierTask)
            return task_spec.spec
    
        def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
            """Start a new run of ``shuffle_id``, triggered by task ``key`` on ``worker``.

            Validates the barrier and the triggering task, then computes
            ``worker_for``. It records the new run as active (with ``worker`` as a
            participant) and returns its run spec.
            """
            # FIXME: The current implementation relies on the barrier task to be
            # known by its name. If the name has been mangled, we cannot guarantee
            # that the shuffle works as intended and should fail instead.
            self._raise_if_barrier_unknown(shuffle_id)
            self._raise_if_task_not_processing(key)
            spec = self._retrieve_spec(shuffle_id)
            worker_for = self._calculate_worker_for(spec)
            self._ensure_output_tasks_are_non_rootish(spec)
            span_id = self.scheduler.tasks[key].group.span_id
            new_state = spec.create_new_run(worker_for=worker_for, span_id=span_id)
            self.active_shuffles[shuffle_id] = new_state
            self._shuffles[shuffle_id].add(new_state)
            new_state.participating_workers.add(worker)
            logger.warning(
                "Shuffle %s initialized by task %r executed on worker %s",
                shuffle_id,
                key,
                worker,
            )
            return new_state.run_spec
    
        def get_or_create(
            self,
            shuffle_id: ShuffleId,
            key: Key,
            worker: str,
        ) -> RunSpecMessage | ErrorMessage:
            try:
>               run_spec = self._get(shuffle_id, worker)

distributed\shuffle\_scheduler_plugin.py:229: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    # Module-level logger, named after this module.
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            """Set up plugin state, install the shuffle RPC handlers and register.

            All state is initialised before ``scheduler.add_plugin`` publishes the
            plugin, so the scheduler never holds a partially constructed plugin.
            """
            self.scheduler = scheduler
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
            # The handlers are bound methods that read the state initialised above.
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            # Publish last: registration exposes ``self`` to the scheduler.
            self.scheduler.add_plugin(self, name="shuffle")
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a pickled ``ShuffleWorkerPlugin`` under the name ``"shuffle"``.

            The plugin is registered non-idempotently. The ``scheduler`` argument is
            unused; registration goes through ``self.scheduler``.
            """
            payload = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None, payload, name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return a snapshot of the IDs of all currently active shuffles."""
            return {shuffle_id for shuffle_id in self.active_shuffles}
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Complete the barrier of run ``run_id`` of shuffle ``id``.

            If the inputs were inconsistent, the shuffle is restarted. Otherwise
            ``shuffle_inputs_done`` is broadcast to every participating worker, and
            workers that fail with ``OSError`` are retried.

            Raises
            ------
            ValueError
                If ``run_id`` is not the active run of shuffle ``id``.
            P2PIllegalStateError
                If a participating worker left while the shuffle is not archived.
            RuntimeError
                In any of these cases: a worker returns an unexpected error, a
                participating worker left the cluster, or three retry rounds fail
                to shrink the set of failing workers.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            # Counts rounds in which the number of failing workers did not shrink.
            # It is not reset when a later round makes progress.
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # A None reply means success. OSError replies are retried in the
                # next round, and any other error aborts the barrier. This loop
                # alone determines the retry list.
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            """Pin task ``key`` to ``worker`` for run ``run_id`` of shuffle ``id``.

            Returns ``{"status": "OK"}`` on success. If ``run_id`` is older
            ("stale") or newer ("invalid") than the active run, an error message is
            returned instead of raising.
            """
            try:
                shuffle = self.active_shuffles[id]
                # Report both the active run ID and the one in the request so the
                # mismatch is visible in the error.
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected run_id={shuffle.run_id} "
                        f"but got {run_id} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected run_id={shuffle.run_id} "
                        f"but got {run_id} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            """Merge per-shuffle heartbeat data reported by worker ``ws``.

            Entries for shuffles that are not currently active are ignored.
            """
            # Take the snapshot of active IDs once. Calling shuffle_ids() per entry
            # rebuilt the whole set every time.
            active = self.shuffle_ids()
            for shuffle_id, d in data.items():
                if shuffle_id in active:
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Reply with the active run spec of shuffle ``id``, enlisting ``worker``.

            On success replies ``{"status": "OK", "run_spec": ...}``. An unknown
            shuffle or an inconsistent request is returned as an error message.
            """
            try:
                try:
                    spec = self._get(id, worker)
                except KeyError as exc:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from exc
                return {"status": "OK", "run_spec": ToPickle(spec)}
            except P2PConsistencyError as err:
                return error_message(err)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
>           state = self.active_shuffles[id]
E           KeyError: '35f3712b5c49a47f3324d11460012f87'

distributed\shuffle\_scheduler_plugin.py:190: KeyError

During handling of the above exception, another exception occurred:

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards
    
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)
    
            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and …n = await self._refresh(
distributed\shuffle\_worker_plugin.py:222: in _refresh
    result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed\shuffle\_worker_plugin.py:190: in _fetch
    response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed\core.py:1259: in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
distributed\core.py:1043: in send_recv
    raise exc.with_traceback(tb)
distributed\core.py:832: in _handle_comm
    result = handler(**msg)
distributed\shuffle\_scheduler_plugin.py:234: in get_or_create
    run_spec = self._create(shuffle_id, key, worker)
distributed\shuffle\_scheduler_plugin.py:205: in _create
    spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
>           assert isinstance(barrier_task_spec, P2PBarrierTask)
E           AssertionError

distributed\shuffle\_scheduler_plugin.py:196: AssertionError

The above exception was the direct cause of the following exception:

c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:64069', workers: 0, cores: 0, tasks: 0>
config_value = 'p2p', keyword = None
ws = (<Worker 'tcp://127.0.0.1:64070', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:64073', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
a = array([[0.03641802, 0.79395381, 0.07845067, 0.75119799, 0.34884585,
        0.86607563, 0.02201857, 0.38369563, 0.9759..., 0.26077311, 0.62562275, 0.61415159, 0.22621115,
        0.02064385, 0.83236264, 0.008683  , 0.42213654, 0.03328439]])

    @pytest.mark.parametrize("config_value", ["tasks", "p2p", None])
    @pytest.mark.parametrize("keyword", ["tasks", "p2p", None])
    @gen_cluster(client=True)
    async def test_rechunk_configuration(c, s, *ws, config_value, keyword):
        """Try rechunking a random 1d matrix
    
        See Also
        --------
        dask.array.tests.test_rechunk.test_rechunk_1d
        """
        a = np.random.default_rng().uniform(0, 1, 100).reshape((10, 10))
        x = da.from_array(a, chunks=(10, 1))
        new = ((1,) * 10, (10,))
        config = {"array.rechunk.method": config_value} if config_value is not None else {}
        with dask.config.set(config):
            x2 = rechunk(x, chunks=new, method=keyword)
        expected_algorithm = keyword if keyword is not None else config_value
        if expected_algorithm == "p2p":
            assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
        elif expected_algorithm == "tasks":
            assert not any(
                key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__()
            )
        # Neither is specified, so we choose the best one (see test_rechunk_heuristic for a full test of the heuristic)
        else:
            assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
    
        assert x2.chunks == new
>       assert np.all(await c.compute(x2) == a)

distributed\shuffle\tests\test_rechunk.py:210: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed\client.py:410: in _result
    raise exc.with_traceback(tb)
distributed\shuffle\_rechunk.py:169: in rechunk_transfer
    with handle_transfer_errors(id):
C:\Users\runneradmin\miniconda3\envs\dask-distributed\Lib\contextlib.py:158: in __exit__
    self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            """Return a debugging representation with the run's identity and state flags."""
            cls_name = type(self).__name__
            fields = (
                f"id={self.id!r}, run_id={self.run_id!r}, "
                f"local_address={self.local_address!r}, "
                f"closed={self.closed!r}, transferred={self.transferred!r}"
            )
            return f"<{cls_name}: {fields}>"
    
        def __str__(self) -> str:
            """Return a short label of the form ``<Class><id[run_id]> on <local_address>``."""
            return "{}<{}[{}]> on {}".format(
                type(self).__name__, self.id, self.run_id, self.local_address
            )
    
        def __hash__(self) -> int:
            """Hash a run by its ``run_id``."""
            identity = self.run_id
            return identity
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as

                {('p2p', <span id>, 'foreground|background...', label, unit): value}

            ``where`` becomes the third element of the metric name. Offloading of the
            callback is allowed only when ``where`` contains ``"background"``. A
            leading ``"p2p-"`` on the first label element is stripped.

            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under

                {('execute', <span id>, <task prefix>, label, unit): value}

            This is by design so that one can have a holistic view of the whole shuffle
            process.

            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """

            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                # Check ``label`` is non-empty first: an empty tuple label used to
                # raise IndexError on ``label[0]``.
                if label and isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)

                self.digest_metric(name, value)

            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            """Notify the scheduler that this worker reached the shuffle barrier.

            The barrier is reported as consistent only when every entry of
            ``run_ids`` equals this run's ``run_id``. Returns this run's ``run_id``.
            """
            self.raise_if_closed()
            all_match = all(rid == self.run_id for rid in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id,
                run_id=self.run_id,
                consistent=all_match,
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Make a single ``shuffle_receive`` RPC to ``address`` carrying ``shards``."""
            self.raise_if_closed()
            remote = self.rpc(address)
            payload = to_serialize(shards)
            return await remote.shuffle_receive(
                data=payload,
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            """Send ``shards`` to the worker at ``address``, retrying failed attempts.

            If ``_mean_shard_size(shards)`` is below 65536, the shards are pickled
            into one opaque bytes blob so their buffers are not sent individually
            over the comms; ``receive`` unpickles it on the other side. Retries use
            ``RETRY_COUNT``/``RETRY_DELAY_MIN``/``RETRY_DELAY_MAX``.
            """
            # Performance tests informing the size threshold:
            # https://github.com/dask/distributed/pull/8318
            payload: list | bytes = (
                pickle.dumps(shards) if _mean_shard_size(shards) < 65536 else shards
            )

            def _attempt() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, payload)

            return await retry(
                _attempt,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            """Run ``func(*args, **kwargs)`` on ``self.executor`` and await its result.

            The call is metered by ``context_meter`` under the ``"offload"`` label.
            """
            self.raise_if_closed()
            with context_meter.meter("offload"):
                outcome = await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
            return outcome
    
        def heartbeat(self) -> dict[str, Any]:
            """Collect a progress snapshot of this run.

            Returns the disk buffer's heartbeat, the comm buffer's heartbeat extended
            with ``"read"`` (``total_recvd``), and the run's start time.
            """
            comm_info = self._comm_buffer.heartbeat()
            comm_info["read"] = self.total_recvd
            disk_info = self._disk_buffer.heartbeat()
            return {"disk": disk_info, "comm": comm_info, "start": self.start_time}
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            """Hand sharded ``data`` to the comm buffer; raises if the run is closed."""
            self.raise_if_closed()
            buffer = self._comm_buffer
            await buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            """Write ``data`` to the disk buffer, keying entries by their joined index.

            An index such as ``(1, 2)`` becomes the key ``"1_2"``.
            """
            self.raise_if_closed()
            keyed = {"_".join(map(str, index)): value for index, value in data.items()}
            await self._disk_buffer.write(keyed)
    
        def raise_if_closed(self) -> None:
            """Raise if this run has been closed; do nothing while it is open.

            A recorded failure (``self._exception``) is re-raised if present,
            otherwise ``ShuffleClosedError`` is raised.
            """
            if not self.closed:
                return
            if self._exception:
                raise self._exception
            raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            """Mark the input phase as finished and flush the comm buffer.

            If the comm buffer reports an exception, it is stored on
            ``self._exception`` and re-raised.
            """
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as failure:
                self._exception = failure
                raise
    
        async def _flush_comm(self) -> None:
            """Await a flush of the outgoing comm buffer; raises if the run is closed."""
            self.raise_if_closed()
            comm = self._comm_buffer
            await comm.flush()
    
        async def flush_receive(self) -> None:
            """Await a flush of the receive-side disk buffer; raises if the run is closed."""
            self.raise_if_closed()
            disk = self._disk_buffer
            await disk.flush()
    
        async def close(self) -> None:
            """Close the run's comm and disk buffers, exactly once.

            A repeated call waits until the first close has finished. The disk
            buffer is closed even if closing the comm buffer raises. The closed
            event is always set, so repeated callers never wait forever. If closing
            a buffer raises, that exception propagates to the first caller.
            """
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return

            self.closed = True
            try:
                try:
                    await self._comm_buffer.close()
                finally:
                    await self._disk_buffer.close()
            finally:
                # Release anyone waiting in a concurrent/later ``close`` call.
                self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            """Record ``exception`` as this run's failure unless the run is already closed."""
            if self.closed:
                return
            self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            """Read the shards stored under index ``id`` (``(1, 2)`` -> key ``"1_2"``)."""
            self.raise_if_closed()
            storage_key = "_".join(map(str, id))
            return self._disk_buffer.read(storage_key)
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Accept shards sent by a peer's ``send``.

            ``data`` is either a list of ``(partition_id, shard)`` pairs or a pickled
            blob of such a list. Returns ``{"status": "OK"}`` on success, or an error
            message if a ``P2PConsistencyError`` is raised while handling the shards.
            """
            try:
                # NOTE(review): unpickling bytes received over the network assumes
                # the sending peer is trusted.
                shards = (
                    cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                    if isinstance(data, bytes)
                    else data
                )
                await self._receive(shards)
            except P2PConsistencyError as e:
                return error_message(e)
            return {"status": "OK"}
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            """Ensure output partition ``i`` is produced on its assigned worker.

            If this worker is not the assigned one, ask the scheduler to restrict
            task ``key`` to the assigned worker and raise ``Reschedule``. Raises
            ``RuntimeError`` if the scheduler replies with an error.
            """
            target = self._get_assigned_worker(i)
            if target == self.local_address:
                return
            response = await self.scheduler.shuffle_restrict_task(
                id=self.id, run_id=self.run_id, key=key, worker=target
            )
            if response["status"] == "error":
                raise RuntimeError(response["message"])
            assert response["status"] == "OK"
            raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Return the address of the worker responsible for output partition ``i``."""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Handle incoming shards destined for this run's output partitions."""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            """Shard one input partition and hand the shards to the comm buffer.

            ``data`` is checked with ``validate_data``, split by ``_shard_partition``,
            and written via ``_write_to_comm`` run through ``sync`` on ``self._loop``.
            Returns this run's ``run_id``.

            Raises ``RuntimeError`` if inputs were already marked as transferred.
            """
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and in the "p2p" contexts
            self.validate_data(data)
            with self._capture_metrics("foreground"):
                with (
                    context_meter.meter("p2p-shard-partition-noncpu"),
                    context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
                ):
                    shards = self._shard_partition(data, partition_id)
                # Writing to comms is outside the CPU/non-CPU meters but still
                # inside the "foreground" p2p metrics capture.
                sync(self._loop, self._write_to_comm, shards)
            return self.run_id
    
        @abc.abstractmethod
        def _shard_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> dict[str, tuple[_T_partition_id, Any]]:
            """Split input partition ``data`` into shards grouped by destination worker."""
    
        def get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Return output partition ``partition_id``, computed by task ``key``.

            First makes sure this worker is the assigned one (``_ensure_output_worker``
            may raise ``Reschedule``), then flushes the receive-side disk buffer and
            delegates to ``_get_output_partition`` with ``kwargs``.

            Raises ``RuntimeError`` if called before inputs were marked transferred
            (i.e. before the barrier task).
            """
            self.raise_if_closed()
            sync(self._loop, self._ensure_output_worker, partition_id, key)
            if not self.transferred:
                raise RuntimeError("`get_output_partition` called before barrier task")
            sync(self._loop, self.flush_receive)
            with (
                # Log metrics both in the "execute" and in the "p2p" contexts
                self._capture_metrics("foreground"),
                context_meter.meter("p2p-get-output-noncpu"),
                context_meter.meter("p2p-get-output-cpu", func=thread_time),
            ):
                return self._get_output_partition(partition_id, key, **kwargs)
    
        @abc.abstractmethod
        def _get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Assemble and return output partition ``partition_id`` of this run."""
    
        @abc.abstractmethod
        def read(self, path: Path) -> tuple[Any, int]:
            """Load shards previously written to ``path``; format is subclass-defined."""
    
        @abc.abstractmethod
        def deserialize(self, buffer: Any) -> Any:
            """Turn a serialized shard ``buffer`` back into shard objects."""
    
        def validate_data(self, data: Any) -> None:
            """Hook to check ``data`` before it is sharded; the default accepts anything."""
    
    
    def get_worker_plugin() -> ShuffleWorkerPlugin:
        """Return the ``"shuffle"`` plugin of the worker executing the current task.

        Raises
        ------
        RuntimeError
            If not running on a worker, or if the worker has no ``"shuffle"`` plugin.
        """
        from distributed import get_worker

        try:
            current = get_worker()
        except ValueError as exc:
            raise RuntimeError(
                "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
                "please confirm that you've created a distributed Client and are submitting this computation through it."
            ) from exc
        try:
            return current.plugins["shuffle"]  # type: ignore
        except KeyError as exc:
            raise RuntimeError(
                f"The worker {current.address} does not have a P2P shuffle plugin."
            ) from exc
    
    
    # Key prefix of a shuffle's barrier task; ``barrier_key`` prepends it to a
    # shuffle ID and ``id_from_key`` strips it off again.
    _BARRIER_PREFIX = "shuffle-barrier-"
    
    
    def barrier_key(shuffle_id: ShuffleId) -> str:
        """Return the task key of the barrier task of shuffle ``shuffle_id``."""
        return "".join((_BARRIER_PREFIX, shuffle_id))
    
    
    def id_from_key(key: Key) -> ShuffleId | None:
        """Extract the shuffle ID from a barrier task key.

        Returns ``None`` for any key that is not a string starting with the
        barrier prefix.
        """
        if isinstance(key, str) and key.startswith(_BARRIER_PREFIX):
            return ShuffleId(key[len(_BARRIER_PREFIX) :])
        return None
    
    
    class ShuffleType(Enum):
        """Enumeration of the supported P2P shuffle kinds."""

        DATAFRAME = "DataFrameShuffle"
        ARRAY_RECHUNK = "ArrayRechunk"
    
    
    @dataclass(frozen=True)
    class ShuffleRunSpec(Generic[_T_partition_id]):
        """Immutable description of a single run of a shuffle.

        ``run_id`` is not an ``__init__`` argument. Every instance draws the next
        value from one ``itertools.count(1)`` created when the class is defined,
        so run IDs are distinct and increasing within a process.
        """

        # Assigned automatically from the shared counter (not passed to __init__).
        run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
        spec: ShuffleSpec
        # Output partition -> address of the worker assigned to it.
        worker_for: dict[_T_partition_id, str]
        span_id: str | None

        @property
        def id(self) -> ShuffleId:
            # The shuffle ID is owned by the spec; a run only references it.
            return self.spec.id
    
    
    @dataclass(frozen=True)
    class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
        """Abstract, immutable specification of a P2P shuffle.

        Subclasses define the output partitions, how partitions map to workers,
        and how a run is instantiated on a worker.
        """

        id: ShuffleId
        disk: bool

        @property
        @abc.abstractmethod
        def output_partitions(self) -> Generator[_T_partition_id]:
            """Yield the IDs of this shuffle's output partitions."""

        @abc.abstractmethod
        def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
            """Choose which of ``workers`` should own ``partition``."""

        def create_new_run(
            self,
            worker_for: dict[_T_partition_id, str],
            span_id: str | None,
        ) -> SchedulerShuffleState:
            """Build scheduler-side state for a new run of this shuffle.

            The new run spec receives a fresh ``run_id``; the participating
            workers start as the distinct workers named in ``worker_for``.
            """
            run_spec = ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id)
            workers = set(worker_for.values())
            return SchedulerShuffleState(run_spec=run_spec, participating_workers=workers)

        @abc.abstractmethod
        def create_run_on_worker(
            self,
            run_id: int,
            span_id: str | None,
            worker_for: dict[_T_partition_id, str],
            plugin: ShuffleWorkerPlugin,
        ) -> ShuffleRun:
            """Instantiate the worker-side ``ShuffleRun`` for run ``run_id``."""
    
    
    @dataclass(eq=False)
    class SchedulerShuffleState(Generic[_T_partition_id]):
        """Mutable scheduler-side state of one run of a shuffle.

        Equality is identity-based (``eq=False``); instances hash by ``run_id``.
        """

        run_spec: ShuffleRunSpec
        # Addresses of workers taking part in this run; grows as workers join.
        participating_workers: set[str]
        # Non-None once the run is archived (presumably the archiving stimulus ID).
        _archived_by: str | None = field(default=None, init=False)
        _failed: bool = False

        @property
        def id(self) -> ShuffleId:
            return self.run_spec.id

        @property
        def run_id(self) -> int:
            return self.run_spec.run_id

        @property
        def archived(self) -> bool:
            return self._archived_by is not None

        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"

        def __hash__(self) -> int:
            # Explicit hash is needed because eq=False leaves hashing to us.
            return hash(self.run_id)
    
    
    @contextlib.contextmanager
    def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
        """Translate exceptions raised during the transfer phase of shuffle ``id``.

        - ``ShuffleClosedError`` becomes ``Reschedule``.
        - ``P2PConsistencyError`` and ``P2POutOfDiskError`` propagate unchanged.
        - Any other exception is wrapped in ``RuntimeError`` chained to the cause.
        """
        try:
            yield
        except ShuffleClosedError:
            raise Reschedule()
        except P2PConsistencyError:
            raise
        except P2POutOfDiskError:
            raise
        except Exception as e:
>           raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E           RuntimeError: P2P 35f3712b5c49a47f3324d11460012f87 failed during transfer phase

distributed\shuffle\_core.py:531: RuntimeError

Check warning on line 0 in distributed.shuffle.tests.test_rechunk

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

All 11 runs failed: test_rechunk_configuration[None-None] (distributed.shuffle.tests.test_rechunk)

artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P e95e40ae5ccb7f71758613ef093cea03 failed during transfer phase
from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            """Register the shuffle RPC handlers and attach this plugin to ``scheduler``."""
            self.scheduler = scheduler
            rpc_handlers = {
                "shuffle_barrier": self.barrier,
                "shuffle_get": self.get,
                "shuffle_get_or_create": self.get_or_create,
                "shuffle_restrict_task": self.restrict_task,
            }
            self.scheduler.handlers.update(rpc_handlers)
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a pickled ``ShuffleWorkerPlugin`` named ``"shuffle"`` (non-idempotent)."""
            payload = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None, payload, name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return a fresh set holding the IDs of all active shuffles."""
            return {*self.active_shuffles}
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Complete the barrier of run ``run_id`` of shuffle ``id``.

            If ``consistent`` is False the shuffle is restarted instead. Otherwise a
            ``shuffle_inputs_done`` message is broadcast to all participating
            workers. Workers whose reply is an ``OSError`` are retried after a short
            sleep.

            Raises
            ------
            ValueError
                If ``run_id`` does not match the active run.
            RuntimeError
                If a worker replies with a non-``OSError`` error, if a retried worker
                left the cluster, or if three retry rounds fail to reduce the
                number of pending workers.
            P2PIllegalStateError
                If a retried worker left but the shuffle is not archived.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # Only workers that failed with an OSError are retried; any other
                # error aborts the barrier. (A second recomputation of ``workers``
                # that used to follow this loop was redundant and has been removed.)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            """Restrict task ``key`` to ``worker`` for the active run of shuffle ``id``.

            Returns an error message when ``run_id`` is older ("stale") or newer
            ("invalid") than the active run; otherwise ``{"status": "OK"}``.
            """
            try:
                shuffle = self.active_shuffles[id]
                current = shuffle.run_id
                if current > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                if current < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                self._set_restriction(self.scheduler.tasks[key], worker)
            except P2PConsistencyError as e:
                return error_message(e)
            return {"status": "OK"}
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            """Merge per-shuffle heartbeat ``data`` reported by worker ``ws``.

            Entries for shuffles that are not active are ignored. Membership is
            tested directly on ``active_shuffles`` instead of rebuilding the
            ``shuffle_ids()`` set for every entry.
            """
            for shuffle_id, d in data.items():
                if shuffle_id in self.active_shuffles:
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Reply with the run spec of active shuffle ``id`` for ``worker``.

            Returns ``{"status": "OK", "run_spec": ...}`` on success, or an error
            message if the shuffle is not active or ``worker`` is unknown.
            """
            try:
                try:
                    spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(spec)}
                except KeyError as missing:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from missing
            except P2PConsistencyError as err:
                return error_message(err)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            """Return the run spec of active shuffle ``id``, adding ``worker`` as participant.

            Raises ``KeyError`` if ``id`` is not active and ``P2PConsistencyError``
            if the scheduler does not know ``worker``.
            """
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            shuffle_state = self.active_shuffles[id]
            shuffle_state.participating_workers.add(worker)
            return shuffle_state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            """Return the ``ShuffleSpec`` carried by the barrier task of ``shuffle_id``."""
            barrier_task = self.scheduler.tasks[barrier_key(shuffle_id)]
            task_spec = barrier_task.run_spec
            assert isinstance(task_spec, P2PBarrierTask)
            return task_spec.spec
    
        def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
            """Create, register and return the run spec of a new run of ``shuffle_id``.

            Checks that the barrier task is known and task ``key`` is processing,
            computes the worker assignment from the barrier's spec, stores the new
            state as the active run (with ``worker`` participating) and logs it.
            """
            # FIXME: The current implementation relies on the barrier task to be
            # known by its name. If the name has been mangled, we cannot guarantee
            # that the shuffle works as intended and should fail instead.
            self._raise_if_barrier_unknown(shuffle_id)
            self._raise_if_task_not_processing(key)
            spec = self._retrieve_spec(shuffle_id)
            worker_for = self._calculate_worker_for(spec)
            self._ensure_output_tasks_are_non_rootish(spec)
            span_id = self.scheduler.tasks[key].group.span_id
            new_state = spec.create_new_run(worker_for=worker_for, span_id=span_id)
            self.active_shuffles[shuffle_id] = new_state
            self._shuffles[shuffle_id].add(new_state)
            new_state.participating_workers.add(worker)
            logger.warning(
                "Shuffle %s initialized by task %r executed on worker %s",
                shuffle_id,
                key,
                worker,
            )
            return new_state.run_spec
    
        def get_or_create(
            self,
            shuffle_id: ShuffleId,
            key: Key,
            worker: str,
        ) -> RunSpecMessage | ErrorMessage:
            try:
>               run_spec = self._get(shuffle_id, worker)

distributed\shuffle\_scheduler_plugin.py:229: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            """Register the shuffle RPC handlers and attach this plugin to ``scheduler``."""
            self.scheduler = scheduler
            rpc_handlers = {
                "shuffle_barrier": self.barrier,
                "shuffle_get": self.get,
                "shuffle_get_or_create": self.get_or_create,
                "shuffle_restrict_task": self.restrict_task,
            }
            self.scheduler.handlers.update(rpc_handlers)
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a pickled ``ShuffleWorkerPlugin`` named ``"shuffle"`` (non-idempotent)."""
            payload = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None, payload, name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return a fresh set holding the IDs of all active shuffles."""
            return {*self.active_shuffles}
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Complete the barrier of run ``run_id`` of shuffle ``id``.

            If ``consistent`` is False the shuffle is restarted instead. Otherwise a
            ``shuffle_inputs_done`` message is broadcast to all participating
            workers. Workers whose reply is an ``OSError`` are retried after a short
            sleep.

            Raises
            ------
            ValueError
                If ``run_id`` does not match the active run.
            RuntimeError
                If a worker replies with a non-``OSError`` error, if a retried worker
                left the cluster, or if three retry rounds fail to reduce the
                number of pending workers.
            P2PIllegalStateError
                If a retried worker left but the shuffle is not archived.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # Only workers that failed with an OSError are retried; any other
                # error aborts the barrier. (A second recomputation of ``workers``
                # that used to follow this loop was redundant and has been removed.)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            """Restrict task ``key`` to ``worker`` for the active run of shuffle ``id``.

            Returns an error message when ``run_id`` is older ("stale") or newer
            ("invalid") than the active run; otherwise ``{"status": "OK"}``.
            """
            try:
                shuffle = self.active_shuffles[id]
                current = shuffle.run_id
                if current > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                if current < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                self._set_restriction(self.scheduler.tasks[key], worker)
            except P2PConsistencyError as e:
                return error_message(e)
            return {"status": "OK"}
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
>           state = self.active_shuffles[id]
E           KeyError: 'e95e40ae5ccb7f71758613ef093cea03'

distributed\shuffle\_scheduler_plugin.py:190: KeyError

During handling of the above exception, another exception occurred:

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards
    
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)
    
            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and …un = await self._refresh(
distributed\shuffle\_worker_plugin.py:222: in _refresh
    result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed\shuffle\_worker_plugin.py:190: in _fetch
    response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed\core.py:1259: in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
distributed\core.py:1043: in send_recv
    raise exc.with_traceback(tb)
distributed\core.py:832: in _handle_comm
    result = handler(**msg)
distributed\shuffle\_scheduler_plugin.py:234: in get_or_create
    run_spec = self._create(shuffle_id, key, worker)
distributed\shuffle\_scheduler_plugin.py:205: in _create
    spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
>           assert isinstance(barrier_task_spec, P2PBarrierTask)
E           AssertionError

distributed\shuffle\_scheduler_plugin.py:196: AssertionError

The above exception was the direct cause of the following exception:

c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:64083', workers: 0, cores: 0, tasks: 0>
config_value = None, keyword = None
ws = (<Worker 'tcp://127.0.0.1:64084', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:64087', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
a = array([[0.3038855 , 0.72204286, 0.59516751, 0.62391027, 0.57077081,
        0.18230923, 0.15963029, 0.88508962, 0.5943..., 0.7712572 , 0.77847078, 0.52208647, 0.62879176,
        0.45921446, 0.22814207, 0.31759614, 0.60266988, 0.17790265]])

    @pytest.mark.parametrize("config_value", ["tasks", "p2p", None])
    @pytest.mark.parametrize("keyword", ["tasks", "p2p", None])
    @gen_cluster(client=True)
    async def test_rechunk_configuration(c, s, *ws, config_value, keyword):
        """Try rechunking a random 1d matrix
    
        See Also
        --------
        dask.array.tests.test_rechunk.test_rechunk_1d
        """
        a = np.random.default_rng().uniform(0, 1, 100).reshape((10, 10))
        x = da.from_array(a, chunks=(10, 1))
        new = ((1,) * 10, (10,))
        config = {"array.rechunk.method": config_value} if config_value is not None else {}
        with dask.config.set(config):
            x2 = rechunk(x, chunks=new, method=keyword)
        expected_algorithm = keyword if keyword is not None else config_value
        if expected_algorithm == "p2p":
            assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
        elif expected_algorithm == "tasks":
            assert not any(
                key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__()
            )
        # Neither is specified, so we choose the best one (see test_rechunk_heuristic for a full test of the heuristic)
        else:
            assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
    
        assert x2.chunks == new
>       assert np.all(await c.compute(x2) == a)

distributed\shuffle\tests\test_rechunk.py:210: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed\client.py:410: in _result
    raise exc.with_traceback(tb)
distributed\shuffle\_rechunk.py:169: in rechunk_transfer
    with handle_transfer_errors(id):
C:\Users\runneradmin\miniconda3\envs\dask-distributed\Lib\contextlib.py:158: in __exit__
    self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        """Worker-side state and logic of a single run of a P2P shuffle.

        Outgoing shards are buffered in a ``CommShardsBuffer`` that ships them to
        peer workers through :meth:`send`; incoming shards are buffered in a
        ``DiskShardsBuffer`` (``disk=True``) or a ``MemoryShardsBuffer``
        (``disk=False``). Subclasses implement the data-type specific steps:
        sharding input partitions, receiving shards, reading them back and
        assembling output partitions.
        """

        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        # Set by ``inputs_done``; afterwards ``add_partition`` is rejected and
        # ``get_output_partition`` is allowed.
        transferred: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        # Reported as "read" in ``heartbeat``; not updated in this class, so
        # presumably maintained by subclasses — verify.
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop

        # Retry policy for sending shards to peers; read from the
        # ``distributed.p2p.comm.retry.*`` config in ``__init__``.
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float

        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False

            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)

                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )

            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)

            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop

            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )

        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"

        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"

        def __hash__(self) -> int:
            # ``__eq__`` is identity-based (inherited from ``object``), so hashing
            # by run_id stays consistent with equality.
            return self.run_id

        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as

                {('p2p', <span id>, 'foreground|background...', label, unit): value}

            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under

                {('execute', <span id>, <task prefix>, label, unit): value}

            This is by design so that one can have a holistic view of the whole shuffle
            process.

            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """

            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                # The metric is already filed under the "p2p" context, so drop a
                # redundant "p2p-" prefix from the leading label component. The
                # ``label and`` guard keeps an empty tuple label from raising
                # IndexError on ``label[0]``.
                if label and isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0].removeprefix("p2p-"), *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)

                self.digest_metric(name, value)

            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield

        async def barrier(self, run_ids: Sequence[int]) -> int:
            """Report arrival at the barrier to the scheduler and return ``run_id``.

            ``consistent`` is sent as True only if every entry of ``run_ids``
            equals this run's ID.
            """
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id

        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Perform one transfer of ``shards`` to the ``shuffle_receive`` handler at ``address``."""
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )

        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            """Send ``shards`` to the worker at ``address``, retrying on failure.

            When ``_mean_shard_size(shards)`` is below 65536 (presumably bytes,
            i.e. 64 KiB) the whole list is pickled into one bytes blob, which
            :meth:`receive` unpickles; otherwise the list is sent as is. The
            transfer is wrapped in ``retry`` with the ``RETRY_*`` settings.
            """
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards

            def _attempt() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                # Coroutine factory handed to ``retry``; presumably invoked once
                # per attempt.
                return self._send(address, shards_or_bytes)

            return await retry(
                _attempt,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )

        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            """Run ``func(*args, **kwargs)`` in ``self.executor`` and await the result.

            The call is metered under the ``"offload"`` label.
            """
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )

        def heartbeat(self) -> dict[str, Any]:
            """Return progress information for the disk buffer, comm buffer and start time."""
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }

        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            """Hand sharded data to the comm buffer for sending to peers."""
            self.raise_if_closed()
            await self._comm_buffer.write(data)

        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            """Write received data to the disk buffer.

            Each index tuple is turned into a ``"_"``-joined string key; the same
            encoding is used by ``_read_from_disk``.
            """
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )

        def raise_if_closed(self) -> None:
            """Raise if this run is closed.

            Raises the exception recorded by ``fail``/``inputs_done`` if there is
            one, otherwise ``ShuffleClosedError``.
            """
            if self.closed:
                if self._exception is not None:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")

        async def inputs_done(self) -> None:
            """Mark all inputs as added and flush the comm buffer.

            Any exception the comm buffer reports is recorded on this run and
            re-raised.
            """
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise

        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()

        async def flush_receive(self) -> None:
            """Flush the buffer holding received shards."""
            self.raise_if_closed()
            await self._disk_buffer.flush()

        async def close(self) -> None:
            """Close the comm and disk buffers.

            A second call waits until the first call has finished closing.
            """
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return

            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()

        def fail(self, exception: Exception) -> None:
            """Record ``exception`` unless the run is already closed.

            This does not close the run; once closed, ``raise_if_closed`` raises
            the recorded exception.
            """
            if not self.closed:
                self._exception = exception

        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            """Read the data stored under index ``id`` (same key encoding as ``_write_to_disk``)."""
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))

        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Accept shards sent by a peer's :meth:`send`.

            ``P2PConsistencyError`` is reported back as an error message instead
            of being raised.
            """
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    # NOTE(review): pickle.loads trusts the sending peer worker.
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)

        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            """Ensure output partition ``i`` is produced on its assigned worker.

            If partition ``i`` is assigned elsewhere, ask the scheduler to restrict
            task ``key`` to that worker and raise ``Reschedule``.

            Raises
            ------
            RuntimeError
                If the scheduler answers with an error message.
            """
            assigned_worker = self._get_assigned_worker(i)

            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()

        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""

        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""

        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            """Shard input partition ``data`` and queue the shards for sending.

            Blocks on the event loop until the shards are handed to the comm
            buffer. Returns this run's ID.

            Raises
            ------
            RuntimeError
                If called after ``inputs_done``.
            """
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and in the "p2p" contexts
            self.validate_data(data)
            with self._capture_metrics("foreground"):
                with (
                    context_meter.meter("p2p-shard-partition-noncpu"),
                    context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
                ):
                    shards = self._shard_partition(data, partition_id)
                sync(self._loop, self._write_to_comm, shards)
            return self.run_id

        @abc.abstractmethod
        def _shard_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> dict[str, tuple[_T_partition_id, Any]]:
            """Shard an input partition by the assigned output workers"""

        def get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Assemble and return output partition ``partition_id``.

            First makes sure this worker is the assigned one (which may raise
            ``Reschedule``), then flushes received shards and delegates to
            ``_get_output_partition``.

            Raises
            ------
            RuntimeError
                If called before ``inputs_done`` (the barrier).
            """
            self.raise_if_closed()
            sync(self._loop, self._ensure_output_worker, partition_id, key)
            if not self.transferred:
                raise RuntimeError("`get_output_partition` called before barrier task")
            sync(self._loop, self.flush_receive)
            with (
                # Log metrics both in the "execute" and in the "p2p" contexts
                self._capture_metrics("foreground"),
                context_meter.meter("p2p-get-output-noncpu"),
                context_meter.meter("p2p-get-output-cpu", func=thread_time),
            ):
                return self._get_output_partition(partition_id, key, **kwargs)

        @abc.abstractmethod
        def _get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Get an output partition to the shuffle run"""

        @abc.abstractmethod
        def read(self, path: Path) -> tuple[Any, int]:
            """Read shards from disk"""

        @abc.abstractmethod
        def deserialize(self, buffer: Any) -> Any:
            """Deserialize shards"""

        def validate_data(self, data: Any) -> None:
            """Validate payload data before shuffling (no-op unless overridden)."""
    
    
    def get_worker_plugin() -> ShuffleWorkerPlugin:
        """Return the plugin registered as ``"shuffle"`` on the current worker.

        Raises
        ------
        RuntimeError
            If the calling task is not running on a distributed Worker, or if that
            worker has no plugin registered under ``"shuffle"``.
        """
        from distributed import get_worker

        try:
            current = get_worker()
        except ValueError as exc:
            not_on_worker = (
                "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
                "please confirm that you've created a distributed Client and are submitting this computation through it."
            )
            raise RuntimeError(not_on_worker) from exc

        registered = current.plugins
        try:
            return registered["shuffle"]  # type: ignore
        except KeyError as exc:
            raise RuntimeError(
                f"The worker {current.address} does not have a P2P shuffle plugin."
            ) from exc
    
    
    # Key prefix shared by all P2P barrier tasks: ``barrier_key`` prepends it to a
    # shuffle ID and ``id_from_key`` recognises and strips it.
    _BARRIER_PREFIX = "shuffle-barrier-"
    
    
    def barrier_key(shuffle_id: ShuffleId) -> str:
        """Return the task key of the barrier task for shuffle ``shuffle_id``."""
        return "".join((_BARRIER_PREFIX, shuffle_id))
    
    
    def id_from_key(key: Key) -> ShuffleId | None:
        """Extract the shuffle ID from a barrier task key.

        Returns ``None`` for any key that is not a string starting with the
        barrier prefix, i.e. for keys that are not barrier keys.
        """
        is_barrier = isinstance(key, str) and key.startswith(_BARRIER_PREFIX)
        if is_barrier:
            return ShuffleId(key.removeprefix(_BARRIER_PREFIX))
        return None
    
    
    class ShuffleType(Enum):
        """Kinds of P2P shuffle: dataframe shuffle and array rechunk."""

        DATAFRAME = "DataFrameShuffle"
        ARRAY_RECHUNK = "ArrayRechunk"
    
    
    @dataclass(frozen=True)
    class ShuffleRunSpec(Generic[_T_partition_id]):
        """Immutable description of one run of a shuffle.

        Pairs a ``ShuffleSpec`` with the assignment of output partitions to
        workers chosen for this run.
        """

        # Not an __init__ argument: every instance draws the next value from one
        # counter (starting at 1) created when this class is defined.
        run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
        spec: ShuffleSpec
        # Output partition -> address of the worker assigned to it.
        worker_for: dict[_T_partition_id, str]
        span_id: str | None

        @property
        def id(self) -> ShuffleId:
            """ID of the shuffle this run belongs to (``spec.id``)."""
            return self.spec.id
    
    
    @dataclass(frozen=True)
    class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
        """Abstract, immutable description of a shuffle, independent of any run.

        Subclasses define the output partitions, how they are assigned to
        workers, and how the worker-side ``ShuffleRun`` is built.
        """

        id: ShuffleId
        disk: bool

        @property
        @abc.abstractmethod
        def output_partitions(self) -> Generator[_T_partition_id]:
            """Output partitions"""

        @abc.abstractmethod
        def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
            """Pick a worker for a partition"""

        def create_new_run(
            self,
            worker_for: dict[_T_partition_id, str],
            span_id: str | None,
        ) -> SchedulerShuffleState:
            """Build scheduler-side state for a fresh run of this shuffle.

            Every worker that appears as a value in ``worker_for`` starts out as a
            participant. The ``ShuffleRunSpec`` draws a new ``run_id`` on each call.
            """
            run_spec = ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id)
            owners = {address for address in worker_for.values()}
            return SchedulerShuffleState(run_spec=run_spec, participating_workers=owners)

        @abc.abstractmethod
        def create_run_on_worker(
            self,
            run_id: int,
            span_id: str | None,
            worker_for: dict[_T_partition_id, str],
            plugin: ShuffleWorkerPlugin,
        ) -> ShuffleRun:
            """Create the new shuffle run on the worker."""
    
    
    @dataclass(eq=False)
    class SchedulerShuffleState(Generic[_T_partition_id]):
        """Scheduler-side bookkeeping for one run of a shuffle.

        ``eq=False`` keeps identity-based equality; ``__hash__`` hashes by
        ``run_id``, which is consistent with that.
        """

        run_spec: ShuffleRunSpec
        participating_workers: set[str]
        # Whatever archived this run (presumably a stimulus ID); None while the
        # run has not been archived.
        _archived_by: str | None = field(default=None, init=False)
        _failed: bool = False

        @property
        def id(self) -> ShuffleId:
            """Shuffle ID, taken from the run spec."""
            spec = self.run_spec
            return spec.id

        @property
        def run_id(self) -> int:
            """Run ID, taken from the run spec."""
            spec = self.run_spec
            return spec.run_id

        @property
        def archived(self) -> bool:
            """Whether ``_archived_by`` has been set."""
            return not (self._archived_by is None)

        def __str__(self) -> str:
            cls_name = type(self).__name__
            return f"{cls_name}<{self.id}[{self.run_id}]>"

        def __hash__(self) -> int:
            rid = self.run_id
            return hash(rid)
    
    
    @contextlib.contextmanager
    def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
        """Translate exceptions raised inside the block during shuffle ``id``'s transfer phase.

        - ``ShuffleClosedError`` is replaced by ``Reschedule`` (presumably so the
          task gets rescheduled rather than failing).
        - ``P2PConsistencyError`` and ``P2POutOfDiskError`` propagate unchanged.
        - Any other ``Exception`` is wrapped in a ``RuntimeError`` naming the
          shuffle, with the original chained as ``__cause__``.
        """
        try:
            yield
        except ShuffleClosedError:
            raise Reschedule()
        except P2PConsistencyError:
            raise
        except P2POutOfDiskError:
            raise
        except Exception as e:
>           raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E           RuntimeError: P2P e95e40ae5ccb7f71758613ef093cea03 failed during transfer phase

distributed\shuffle\_core.py:531: RuntimeError

Check warning on line 0 in distributed.shuffle.tests.test_rechunk

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

All 11 runs failed: test_rechunk_heuristic[new0-p2p] (distributed.shuffle.tests.test_rechunk)

artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 031e1a73eb9bc8ea0efa30aa81c3dd68 failed during transfer phase
from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    # Module-level logger of the scheduler-side shuffle plugin.
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            """Install the shuffle RPC handlers and register this plugin on ``scheduler``."""
            # Scheduler RPC handlers; ShuffleRun calls e.g. ``shuffle_barrier`` and
            # ``shuffle_restrict_task`` on the scheduler.
            handlers = {
                "shuffle_barrier": self.barrier,
                "shuffle_get": self.get,
                "shuffle_get_or_create": self.get_or_create,
                "shuffle_restrict_task": self.restrict_task,
            }
            self.scheduler = scheduler
            self.scheduler.handlers.update(handlers)
            # shuffle ID -> worker address -> merged heartbeat data
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a pickled ``ShuffleWorkerPlugin`` under the name ``"shuffle"``.

            Uses ``self.scheduler``; the ``scheduler`` argument is not read.
            """
            payload = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None, payload, name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return a new set holding the IDs of all active shuffles."""
            return {*self.active_shuffles}
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Complete the barrier of run ``run_id`` of shuffle ``id``.

            If the workers reported inconsistent run IDs the shuffle is restarted
            instead. Otherwise ``shuffle_inputs_done`` is broadcast to every
            participating worker; workers whose response is an ``OSError`` are
            retried after a short sleep.

            Raises
            ------
            ValueError
                If ``run_id`` is not the active run of the shuffle.
            P2PIllegalStateError
                If a failing worker is unknown to the scheduler although the
                shuffle has not been archived.
            RuntimeError
                If a worker returns a non-``OSError`` error, a failing worker has
                left the scheduler, or three retry rounds made no progress.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            # Counts retry rounds (not necessarily consecutive) in which the set of
            # failing workers did not shrink.
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # Each response is None (treated as success), an OSError (treated
                # as transient and retried) or anything else (fatal). This loop
                # alone builds the retry list.
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            """Restrict task ``key`` to ``worker`` for the active run of shuffle ``id``.

            A request for an older or newer run is answered with an error message
            instead of being applied.
            """
            # NOTE(review): both messages print the *request's* run_id as the
            # "expected" one; the active run is only visible via ``{shuffle}``.
            try:
                shuffle = self.active_shuffles[id]
                if run_id < shuffle.run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                if run_id > shuffle.run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                self._set_restriction(self.scheduler.tasks[key], worker)
            except P2PConsistencyError as e:
                return error_message(e)
            return {"status": "OK"}
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            """Merge ``ws``'s per-shuffle heartbeat data into ``self.heartbeats``.

            Entries for shuffles that are not active are ignored.
            """
            # Build the set of active IDs once rather than once per entry.
            active = self.shuffle_ids()
            for shuffle_id, d in data.items():
                if shuffle_id in active:
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the pickled run spec of active shuffle ``id`` for ``worker``.

            ``worker`` is recorded as a participant. An unknown shuffle or worker is
            reported as an error message rather than raised.
            """
            try:
                try:
                    spec_for_worker = self._get(id, worker)
                except KeyError as missing:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from missing
                return {"status": "OK", "run_spec": ToPickle(spec_for_worker)}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            """Return the active run spec of shuffle ``id`` and add ``worker`` as participant.

            Raises
            ------
            P2PConsistencyError
                If the scheduler does not know ``worker``.
            KeyError
                If no shuffle with this ``id`` is active.
            """
            if worker in self.scheduler.workers:
                state = self.active_shuffles[id]
                state.participating_workers.add(worker)
                return state.run_spec
            # This should never happen
            raise P2PConsistencyError(
                f"Scheduler is unaware of this worker {worker!r}"
            )  # pragma: nocover
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            """Return the ``ShuffleSpec`` carried by the barrier task of ``shuffle_id``."""
            barrier_ts = self.scheduler.tasks[barrier_key(shuffle_id)]
            task_spec = barrier_ts.run_spec
            # NOTE(review): this type check is an ``assert`` and disappears under -O.
            assert isinstance(task_spec, P2PBarrierTask)
            return task_spec.spec
    
        def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
            """Start a new run of shuffle ``shuffle_id`` on behalf of task ``key``.

            Runs the ``_raise_if_*`` checks on the barrier and on task ``key``,
            assigns output partitions to workers, registers the new run as the
            active one and adds ``worker`` to its participants.
            """
            # FIXME: The current implementation relies on the barrier task to be
            # known by its name. If the name has been mangled, we cannot guarantee
            # that the shuffle works as intended and should fail instead.
            self._raise_if_barrier_unknown(shuffle_id)
            self._raise_if_task_not_processing(key)
            shuffle_spec = self._retrieve_spec(shuffle_id)
            assignment = self._calculate_worker_for(shuffle_spec)
            self._ensure_output_tasks_are_non_rootish(shuffle_spec)
            span = self.scheduler.tasks[key].group.span_id
            new_state = shuffle_spec.create_new_run(worker_for=assignment, span_id=span)

            self.active_shuffles[shuffle_id] = new_state
            self._shuffles[shuffle_id].add(new_state)
            new_state.participating_workers.add(worker)
            # NOTE(review): a routine initialization is logged at WARNING level;
            # confirm the level is intended.
            logger.warning(
                "Shuffle %s initialized by task %r executed on worker %s",
                shuffle_id,
                key,
                worker,
            )
            return new_state.run_spec
    
        def get_or_create(
            self,
            shuffle_id: ShuffleId,
            key: Key,
            worker: str,
        ) -> RunSpecMessage | ErrorMessage:
            try:
>               run_spec = self._get(shuffle_id, worker)

distributed\shuffle\_scheduler_plugin.py:229: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    # Module-level logger of the scheduler-side shuffle plugin.
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            """Install the shuffle RPC handlers and register this plugin on ``scheduler``."""
            # Scheduler RPC handlers; ShuffleRun calls e.g. ``shuffle_barrier`` and
            # ``shuffle_restrict_task`` on the scheduler.
            handlers = {
                "shuffle_barrier": self.barrier,
                "shuffle_get": self.get,
                "shuffle_get_or_create": self.get_or_create,
                "shuffle_restrict_task": self.restrict_task,
            }
            self.scheduler = scheduler
            self.scheduler.handlers.update(handlers)
            # shuffle ID -> worker address -> merged heartbeat data
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a pickled ``ShuffleWorkerPlugin`` under the name ``"shuffle"``.

            Uses ``self.scheduler``; the ``scheduler`` argument is not read.
            """
            payload = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None, payload, name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return a new set holding the IDs of all active shuffles."""
            return {*self.active_shuffles}
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Complete the barrier of run ``run_id`` of shuffle ``id``.

            If the workers reported inconsistent run IDs the shuffle is restarted
            instead. Otherwise ``shuffle_inputs_done`` is broadcast to every
            participating worker; workers whose response is an ``OSError`` are
            retried after a short sleep.

            Raises
            ------
            ValueError
                If ``run_id`` is not the active run of the shuffle.
            P2PIllegalStateError
                If a failing worker is unknown to the scheduler although the
                shuffle has not been archived.
            RuntimeError
                If a worker returns a non-``OSError`` error, a failing worker has
                left the scheduler, or three retry rounds made no progress.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            # Counts retry rounds (not necessarily consecutive) in which the set of
            # failing workers did not shrink.
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # Each response is None (treated as success), an OSError (treated
                # as transient and retried) or anything else (fatal). This loop
                # alone builds the retry list.
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
>           state = self.active_shuffles[id]
E           KeyError: '031e1a73eb9bc8ea0efa30aa81c3dd68'

distributed\shuffle\_scheduler_plugin.py:190: KeyError

During handling of the above exception, another exception occurred:

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards
    
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)
    
            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and …
distributed\shuffle\_worker_plugin.py:411: in get_or_create_shuffle
    return sync(
distributed\utils.py:439: in sync
    raise error
distributed\utils.py:413: in f
    result = yield future
C:\Users\runneradmin\miniconda3\envs\dask-distributed\Lib\site-packages\tornado\gen.py:766: in run
    value = future.result()
distributed\shuffle\_worker_plugin.py:145: in get_or_create
    shuffle_run = await self._refresh(
distributed\shuffle\_worker_plugin.py:222: in _refresh
    result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed\shuffle\_worker_plugin.py:190: in _fetch
    response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed\core.py:1259: in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
distributed\core.py:1043: in send_recv
    raise exc.with_traceback(tb)
distributed\core.py:832: in _handle_comm
    result = handler(**msg)
distributed\shuffle\_scheduler_plugin.py:234: in get_or_create
    run_spec = self._create(shuffle_id, key, worker)
distributed\shuffle\_scheduler_plugin.py:205: in _create
    spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
>           assert isinstance(barrier_task_spec, P2PBarrierTask)
E           AssertionError

distributed\shuffle\_scheduler_plugin.py:196: AssertionError

The above exception was the direct cause of the following exception:

c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:64097', workers: 0, cores: 0, tasks: 0>
a = array([[0.47181678, 0.47249156, 0.78796815, ..., 0.85820206, 0.9811494 ,
        0.55305409],
       [0.21465893, 0.37...21,
        0.94585984],
       [0.36259846, 0.45029292, 0.94920195, ..., 0.3275337 , 0.03679922,
        0.90015213]])
b = <Worker 'tcp://127.0.0.1:64101', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>
new = ((1, 1, 1, 1, 1, 1, ...), (100,)), expected_algorithm = 'p2p'

    @pytest.mark.parametrize(
        ["new", "expected_algorithm"],
        [
            # All-to-all rechunking defaults to P2P
            (((1,) * 100, (100,)), "p2p"),
            # Localized rechunking defaults to tasks
            (((50, 50), (2,) * 50), "tasks"),
            # Less local rechunking first defaults to tasks,
            (((25, 25, 25, 25), (4,) * 25), "tasks"),
            # then switches to p2p
            (((10,) * 10, (10,) * 10), "p2p"),
        ],
    )
    @gen_cluster(client=True)
    async def test_rechunk_heuristic(c, s, a, b, new, expected_algorithm):
        a = np.random.default_rng().uniform(0, 1, 10000).reshape((100, 100))
        x = da.from_array(a, chunks=(100, 1))
        x2 = rechunk(x, chunks=new)
        if expected_algorithm == "p2p":
            assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
        else:
            assert not any(
                key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__()
            )
    
        assert x2.chunks == new
>       assert np.all(await c.compute(x2) == a)

distributed\shuffle\tests\test_rechunk.py:239: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed\client.py:410: in _result
    raise exc.with_traceback(tb)
distributed\shuffle\_rechunk.py:169: in rechunk_transfer
    with handle_transfer_errors(id):
C:\Users\runneradmin\miniconda3\envs\dask-distributed\Lib\contextlib.py:158: in __exit__
    self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards
    
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)
    
            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and in the "p2p" contexts
            self.validate_data(data)
            with self._capture_metrics("foreground"):
                with (
                    context_meter.meter("p2p-shard-partition-noncpu"),
                    context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
                ):
                    shards = self._shard_partition(data, partition_id)
                sync(self._loop, self._write_to_comm, shards)
            return self.run_id
    
        @abc.abstractmethod
        def _shard_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> dict[str, tuple[_T_partition_id, Any]]:
            """Shard an input partition by the assigned output workers"""
    
        def get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            self.raise_if_closed()
            sync(self._loop, self._ensure_output_worker, partition_id, key)
            if not self.transferred:
                raise RuntimeError("`get_output_partition` called before barrier task")
            sync(self._loop, self.flush_receive)
            with (
                # Log metrics both in the "execute" and in the "p2p" contexts
                self._capture_metrics("foreground"),
                context_meter.meter("p2p-get-output-noncpu"),
                context_meter.meter("p2p-get-output-cpu", func=thread_time),
            ):
                return self._get_output_partition(partition_id, key, **kwargs)
    
        @abc.abstractmethod
        def _get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Get an output partition to the shuffle run"""
    
        @abc.abstractmethod
        def read(self, path: Path) -> tuple[Any, int]:
            """Read shards from disk"""
    
        @abc.abstractmethod
        def deserialize(self, buffer: Any) -> Any:
            """Deserialize shards"""
    
        def validate_data(self, data: Any) -> None:
            """Validate payload data before shuffling"""
    
    
    def get_worker_plugin() -> ShuffleWorkerPlugin:
        from distributed import get_worker
    
        try:
            worker = get_worker()
        except ValueError as e:
            raise RuntimeError(
                "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
                "please confirm that you've created a distributed Client and are submitting this computation through it."
            ) from e
        try:
            return worker.plugins["shuffle"]  # type: ignore
        except KeyError as e:
            raise RuntimeError(
                f"The worker {worker.address} does not have a P2P shuffle plugin."
            ) from e
    
    
    _BARRIER_PREFIX = "shuffle-barrier-"
    
    
    def barrier_key(shuffle_id: ShuffleId) -> str:
        return _BARRIER_PREFIX + shuffle_id
    
    
    def id_from_key(key: Key) -> ShuffleId | None:
        if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
            return None
        return ShuffleId(key[len(_BARRIER_PREFIX) :])
    
    
    class ShuffleType(Enum):
        DATAFRAME = "DataFrameShuffle"
        ARRAY_RECHUNK = "ArrayRechunk"
    
    
    @dataclass(frozen=True)
    class ShuffleRunSpec(Generic[_T_partition_id]):
        run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
        spec: ShuffleSpec
        worker_for: dict[_T_partition_id, str]
        span_id: str | None
    
        @property
        def id(self) -> ShuffleId:
            return self.spec.id
    
    
    @dataclass(frozen=True)
    class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
        id: ShuffleId
        disk: bool
    
        @property
        @abc.abstractmethod
        def output_partitions(self) -> Generator[_T_partition_id]:
            """Output partitions"""
    
        @abc.abstractmethod
        def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
            """Pick a worker for a partition"""
    
        def create_new_run(
            self,
            worker_for: dict[_T_partition_id, str],
            span_id: str | None,
        ) -> SchedulerShuffleState:
            return SchedulerShuffleState(
                run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
                participating_workers=set(worker_for.values()),
            )
    
        @abc.abstractmethod
        def create_run_on_worker(
            self,
            run_id: int,
            span_id: str | None,
            worker_for: dict[_T_partition_id, str],
            plugin: ShuffleWorkerPlugin,
        ) -> ShuffleRun:
            """Create the new shuffle run on the worker."""
    
    
    @dataclass(eq=False)
    class SchedulerShuffleState(Generic[_T_partition_id]):
        run_spec: ShuffleRunSpec
        participating_workers: set[str]
        _archived_by: str | None = field(default=None, init=False)
        _failed: bool = False
    
        @property
        def id(self) -> ShuffleId:
            return self.run_spec.id
    
        @property
        def run_id(self) -> int:
            return self.run_spec.run_id
    
        @property
        def archived(self) -> bool:
            return self._archived_by is not None
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
    
        def __hash__(self) -> int:
            return hash(self.run_id)
    
    
    @contextlib.contextmanager
    def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
        try:
            yield
        except ShuffleClosedError:
            raise Reschedule()
        except P2PConsistencyError:
            raise
        except P2POutOfDiskError:
            raise
        except Exception as e:
>           raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E           RuntimeError: P2P 031e1a73eb9bc8ea0efa30aa81c3dd68 failed during transfer phase

distributed\shuffle\_core.py:531: RuntimeError

Check warning on line 0 in distributed.shuffle.tests.test_rechunk

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

All 11 runs failed: test_rechunk_heuristic[new3-p2p] (distributed.shuffle.tests.test_rechunk)

artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 5b9d4830bd1782a20359d383dac64ddc failed during transfer phase
from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
            assert isinstance(barrier_task_spec, P2PBarrierTask)
            return barrier_task_spec.spec
    
        def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
            # FIXME: The current implementation relies on the barrier task to be
            # known by its name. If the name has been mangled, we cannot guarantee
            # that the shuffle works as intended and should fail instead.
            self._raise_if_barrier_unknown(shuffle_id)
            self._raise_if_task_not_processing(key)
            spec = self._retrieve_spec(shuffle_id)
            worker_for = self._calculate_worker_for(spec)
            self._ensure_output_tasks_are_non_rootish(spec)
            state = spec.create_new_run(
                worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
            )
            self.active_shuffles[shuffle_id] = state
            self._shuffles[shuffle_id].add(state)
            state.participating_workers.add(worker)
            logger.warning(
                "Shuffle %s initialized by task %r executed on worker %s",
                shuffle_id,
                key,
                worker,
            )
            return state.run_spec
    
        def get_or_create(
            self,
            shuffle_id: ShuffleId,
            key: Key,
            worker: str,
        ) -> RunSpecMessage | ErrorMessage:
            try:
>               run_spec = self._get(shuffle_id, worker)

distributed\shuffle\_scheduler_plugin.py:229: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
>           state = self.active_shuffles[id]
E           KeyError: '5b9d4830bd1782a20359d383dac64ddc'

distributed\shuffle\_scheduler_plugin.py:190: KeyError

During handling of the above exception, another exception occurred:

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards
    
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)
    
            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and …_plugin.py:411: in get_or_create_shuffle
    return sync(
distributed\utils.py:439: in sync
    raise error
distributed\utils.py:413: in f
    result = yield future
C:\Users\runneradmin\miniconda3\envs\dask-distributed\Lib\site-packages\tornado\gen.py:766: in run
    value = future.result()
distributed\shuffle\_worker_plugin.py:145: in get_or_create
    shuffle_run = await self._refresh(
distributed\shuffle\_worker_plugin.py:222: in _refresh
    result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed\shuffle\_worker_plugin.py:190: in _fetch
    response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed\core.py:1259: in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
distributed\core.py:1043: in send_recv
    raise exc.with_traceback(tb)
distributed\core.py:832: in _handle_comm
    result = handler(**msg)
distributed\shuffle\_scheduler_plugin.py:234: in get_or_create
    run_spec = self._create(shuffle_id, key, worker)
distributed\shuffle\_scheduler_plugin.py:205: in _create
    spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
>           assert isinstance(barrier_task_spec, P2PBarrierTask)
E           AssertionError

distributed\shuffle\_scheduler_plugin.py:196: AssertionError

The above exception was the direct cause of the following exception:

c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:64151', workers: 0, cores: 0, tasks: 0>
a = array([[0.02218022, 0.54394215, 0.75067916, ..., 0.52673533, 0.10159338,
        0.32048149],
       [0.64108086, 0.21...08,
        0.18661381],
       [0.77971001, 0.88125599, 0.40857319, ..., 0.48985381, 0.50061773,
        0.82015301]])
b = <Worker 'tcp://127.0.0.1:64155', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>
new = ((10, 10, 10, 10, 10, 10, ...), (10, 10, 10, 10, 10, 10, ...))
expected_algorithm = 'p2p'

    @pytest.mark.parametrize(
        ["new", "expected_algorithm"],
        [
            # All-to-all rechunking defaults to P2P
            (((1,) * 100, (100,)), "p2p"),
            # Localized rechunking defaults to tasks
            (((50, 50), (2,) * 50), "tasks"),
            # Less local rechunking first defaults to tasks,
            (((25, 25, 25, 25), (4,) * 25), "tasks"),
            # then switches to p2p
            (((10,) * 10, (10,) * 10), "p2p"),
        ],
    )
    @gen_cluster(client=True)
    async def test_rechunk_heuristic(c, s, a, b, new, expected_algorithm):
        a = np.random.default_rng().uniform(0, 1, 10000).reshape((100, 100))
        x = da.from_array(a, chunks=(100, 1))
        x2 = rechunk(x, chunks=new)
        if expected_algorithm == "p2p":
            assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
        else:
            assert not any(
                key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__()
            )
    
        assert x2.chunks == new
>       assert np.all(await c.compute(x2) == a)

distributed\shuffle\tests\test_rechunk.py:239: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed\client.py:410: in _result
    raise exc.with_traceback(tb)
distributed\shuffle\_rechunk.py:169: in rechunk_transfer
    with handle_transfer_errors(id):
C:\Users\runneradmin\miniconda3\envs\dask-distributed\Lib\contextlib.py:158: in __exit__
    self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        """Worker-side state of a single run of a P2P shuffle.

        Shards produced by ``add_partition`` are written to a comm buffer that ships
        them to other workers through ``send``. ``_write_to_disk`` and
        ``_read_from_disk`` use a disk buffer, or an in-memory buffer when ``disk``
        is false. Subclasses implement the abstract partitioning,
        (de)serialization and output-assembly hooks.
        """

        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        # True once inputs_done() has been called; no more partitions may be added
        transferred: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop

        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float

        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            """Set up the buffers and retry settings for run ``run_id`` of shuffle ``id``.

            ``disk`` selects a DiskShardsBuffer under ``directory`` (True) or a
            MemoryShardsBuffer (False). Comm limits and retry settings are read
            from the ``distributed.p2p.comm.*`` config keys.
            """
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False

            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)

                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )

            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)

            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop

            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )

        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"

        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"

        def __hash__(self) -> int:
            return self.run_id

        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as

                {('p2p', <span id>, 'foreground|background...', label, unit): value}

            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under

                {('execute', <span id>, <task prefix>, label, unit): value}

            This is by design so that one can have a holistic view of the whole shuffle
            process.

            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """

            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                # Drop the redundant "p2p-" prefix; the metric is already filed under
                # "p2p". An empty tuple label is left as-is instead of raising.
                if label and isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)

                self.digest_metric(name, value)

            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield

        async def barrier(self, run_ids: Sequence[int]) -> int:
            """Report to the scheduler whether all ``run_ids`` equal this run's id.

            Awaits the scheduler's ``shuffle_barrier`` handler and returns this run's id.
            """
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id

        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Send ``shards`` (or a pickled blob of them) to the run on ``address``."""
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )

        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            """Send ``shards`` to ``address`` with retries per the RETRY_* settings."""
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards

            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)

            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )

        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            """Run ``func(*args, **kwargs)`` in the executor, metered as "offload"."""
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )

        def heartbeat(self) -> dict[str, Any]:
            """Return disk and comm buffer statistics plus the run's start time."""
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }

        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            """Write shards, keyed by destination, into the comm buffer."""
            self.raise_if_closed()
            await self._comm_buffer.write(data)

        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            """Write shards into the disk buffer; index tuples are joined with "_"."""
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )

        def raise_if_closed(self) -> None:
            """Raise the recorded failure, or ShuffleClosedError, if the run is closed."""
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")

        async def inputs_done(self) -> None:
            """Mark all inputs as added and flush the comm buffer.

            An error raised by the comm buffer is recorded on the run and re-raised.
            """
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise

        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()

        async def flush_receive(self) -> None:
            """Flush the disk (or memory) buffer holding received shards."""
            self.raise_if_closed()
            await self._disk_buffer.flush()

        async def close(self) -> None:
            """Close the run and both buffers.

            A second caller waits until the first close has finished. The closed
            event is set even if closing a buffer raises, so waiters never hang.
            The disk buffer is closed even if closing the comm buffer fails. Any
            error still propagates to the first caller.
            """
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return

            self.closed = True
            try:
                await self._comm_buffer.close()
            finally:
                try:
                    await self._disk_buffer.close()
                finally:
                    self._closed_event.set()

        def fail(self, exception: Exception) -> None:
            """Record ``exception`` as the run's failure unless it is already closed."""
            if not self.closed:
                self._exception = exception

        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            """Read the buffered shards for partition index ``id``."""
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))

        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Handle shards sent by a peer; a P2PConsistencyError becomes an error reply."""
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)

        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            """Reschedule task ``key`` if output partition ``i`` belongs to another worker.

            The scheduler is first asked to restrict the task to the assigned worker.
            """
            assigned_worker = self._get_assigned_worker(i)

            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()

        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""

        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""

        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            """Shard input partition ``partition_id`` and queue the shards for sending.

            Returns this run's id. Raises RuntimeError after inputs_done().
            """
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and in the "p2p" contexts
            self.validate_data(data)
            with self._capture_metrics("foreground"):
                with (
                    context_meter.meter("p2p-shard-partition-noncpu"),
                    context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
                ):
                    shards = self._shard_partition(data, partition_id)
                sync(self._loop, self._write_to_comm, shards)
            return self.run_id

        @abc.abstractmethod
        def _shard_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> dict[str, tuple[_T_partition_id, Any]]:
            """Shard an input partition by the assigned output workers"""

        def get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Assemble output partition ``partition_id`` on its assigned worker.

            May raise Reschedule (wrong worker) or RuntimeError (before the barrier).
            """
            self.raise_if_closed()
            sync(self._loop, self._ensure_output_worker, partition_id, key)
            if not self.transferred:
                raise RuntimeError("`get_output_partition` called before barrier task")
            sync(self._loop, self.flush_receive)
            with (
                # Log metrics both in the "execute" and in the "p2p" contexts
                self._capture_metrics("foreground"),
                context_meter.meter("p2p-get-output-noncpu"),
                context_meter.meter("p2p-get-output-cpu", func=thread_time),
            ):
                return self._get_output_partition(partition_id, key, **kwargs)

        @abc.abstractmethod
        def _get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Get an output partition to the shuffle run"""

        @abc.abstractmethod
        def read(self, path: Path) -> tuple[Any, int]:
            """Read shards from disk"""

        @abc.abstractmethod
        def deserialize(self, buffer: Any) -> Any:
            """Deserialize shards"""

        def validate_data(self, data: Any) -> None:
            """Validate payload data before shuffling"""
    
    
    def get_worker_plugin() -> ShuffleWorkerPlugin:
        """Return the ``"shuffle"`` plugin of the worker executing the current task.

        Raises
        ------
        RuntimeError
            If called outside a worker, or if the worker has no ``"shuffle"`` plugin.
        """
        from distributed import get_worker

        try:
            worker = get_worker()
        except ValueError as err:
            raise RuntimeError(
                "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
                "please confirm that you've created a distributed Client and are submitting this computation through it."
            ) from err

        registered = worker.plugins
        try:
            return registered["shuffle"]  # type: ignore
        except KeyError as err:
            raise RuntimeError(
                f"The worker {worker.address} does not have a P2P shuffle plugin."
            ) from err
    
    
    # Key prefix of shuffle barrier tasks; the rest of the key is the shuffle id
    # (see barrier_key / id_from_key).
    _BARRIER_PREFIX = "shuffle-barrier-"
    
    
    def barrier_key(shuffle_id: ShuffleId) -> str:
        """Return the task key of the barrier belonging to ``shuffle_id``."""
        return "".join((_BARRIER_PREFIX, shuffle_id))
    
    
    def id_from_key(key: Key) -> ShuffleId | None:
        """Return the shuffle id encoded in a barrier task key.

        Returns None for any key that is not a string starting with the barrier prefix.
        """
        if isinstance(key, str) and key.startswith(_BARRIER_PREFIX):
            return ShuffleId(key.removeprefix(_BARRIER_PREFIX))
        return None
    
    
    class ShuffleType(Enum):
        """Kinds of P2P shuffle operations."""

        DATAFRAME = "DataFrameShuffle"
        ARRAY_RECHUNK = "ArrayRechunk"
    
    
    @dataclass(frozen=True)
    class ShuffleRunSpec(Generic[_T_partition_id]):
        """Immutable description of one run of a shuffle.

        ``run_id`` is not an ``__init__`` argument; each instance draws the next
        value from a single ``itertools.count(1)`` shared by all instances, so
        later specs get larger ids.
        """

        # The counter is created once, when the class is defined.
        run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
        spec: ShuffleSpec
        # Output partition id -> address of the worker assigned to it
        worker_for: dict[_T_partition_id, str]
        span_id: str | None

        @property
        def id(self) -> ShuffleId:
            # The shuffle id belongs to the underlying ShuffleSpec.
            return self.spec.id
    
    
    @dataclass(frozen=True)
    class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
        """Abstract, immutable description of a shuffle, independent of any run.

        Subclasses define the output partitions, how partitions map to workers,
        and how a run is instantiated on a worker.
        """

        id: ShuffleId
        disk: bool

        @property
        @abc.abstractmethod
        def output_partitions(self) -> Generator[_T_partition_id]:
            """Output partitions"""

        @abc.abstractmethod
        def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
            """Pick a worker for a partition"""

        def create_new_run(
            self,
            worker_for: dict[_T_partition_id, str],
            span_id: str | None,
        ) -> SchedulerShuffleState:
            """Build the scheduler-side state for a fresh run of this shuffle.

            Every worker that appears in ``worker_for`` starts out as a participant.
            """
            fresh_spec = ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id)
            participants = set(worker_for.values())
            return SchedulerShuffleState(
                run_spec=fresh_spec,
                participating_workers=participants,
            )

        @abc.abstractmethod
        def create_run_on_worker(
            self,
            run_id: int,
            span_id: str | None,
            worker_for: dict[_T_partition_id, str],
            plugin: ShuffleWorkerPlugin,
        ) -> ShuffleRun:
            """Create the new shuffle run on the worker."""
    
    
    @dataclass(eq=False)
    class SchedulerShuffleState(Generic[_T_partition_id]):
        """Scheduler-side bookkeeping for one run of a shuffle.

        Equality is identity (``eq=False``); the hash is derived from the run id.
        """

        run_spec: ShuffleRunSpec
        participating_workers: set[str]
        _archived_by: str | None = field(default=None, init=False)
        _failed: bool = False

        @property
        def id(self) -> ShuffleId:
            """Id of the shuffle, as recorded in the run spec."""
            spec = self.run_spec
            return spec.id

        @property
        def run_id(self) -> int:
            """Id of this run, as recorded in the run spec."""
            spec = self.run_spec
            return spec.run_id

        @property
        def archived(self) -> bool:
            """Whether ``_archived_by`` has been set."""
            return not (self._archived_by is None)

        def __str__(self) -> str:
            cls_name = type(self).__name__
            return f"{cls_name}<{self.id}[{self.run_id}]>"

        def __hash__(self) -> int:
            return hash(self.run_id)
    
    
    @contextlib.contextmanager
    def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
        """Translate exceptions raised during the transfer phase of shuffle ``id``.

        ShuffleClosedError becomes Reschedule. P2PConsistencyError and
        P2POutOfDiskError propagate unchanged. Any other exception is wrapped in a
        RuntimeError chained to the original.
        """
        try:
            yield
        except ShuffleClosedError:
            raise Reschedule()
        except P2PConsistencyError:
            raise
        except P2POutOfDiskError:
            raise
        except Exception as e:
>           raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E           RuntimeError: P2P 5b9d4830bd1782a20359d383dac64ddc failed during transfer phase

distributed\shuffle\_core.py:531: RuntimeError

Check warning on line 0 in distributed.shuffle.tests.test_rechunk

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

All 11 runs failed: test_cull_p2p_rechunk_independent_partitions (distributed.shuffle.tests.test_rechunk)

artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
assert 57 < (228 / 4)
 +  where 57 = len(HighLevelGraph with 1 layers.\n<dask.highlevelgraph.HighLevelGraph object at 0x1aa05ebaf50>\n 0. getitem-3805afa155d01ec6f19a98fd58a271e4\n)
 +    where HighLevelGraph with 1 layers.\n<dask.highlevelgraph.HighLevelGraph object at 0x1aa05ebaf50>\n 0. getitem-3805afa155d01ec6f19a98fd58a271e4\n = dask.array<getitem, shape=(5, 2, 10), dtype=float64, chunksize=(5, 1, 10), chunktype=numpy.ndarray>.dask
 +  and   228 = len(HighLevelGraph with 1 layers.\n<dask.highlevelgraph.HighLevelGraph object at 0x1aa065b1050>\n 0. rechunk-p2p-c47ad159278aa77d47afe1fdcb337370\n)
 +    where HighLevelGraph with 1 layers.\n<dask.highlevelgraph.HighLevelGraph object at 0x1aa065b1050>\n 0. rechunk-p2p-c47ad159278aa77d47afe1fdcb337370\n = dask.array<rechunk-p2p, shape=(10, 10, 10), dtype=float64, chunksize=(5, 1, 10), chunktype=numpy.ndarray>.dask
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:64166', workers: 0, cores: 0, tasks: 0>
ws = (<Worker 'tcp://127.0.0.1:64167', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:64170', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
a = array([[[4.01332956e-01, 5.05156722e-01, 2.80394807e-02, 5.91861613e-01,
         6.20283724e-01, 3.06173020e-02, 5.19...1,
         3.34170490e-01, 3.41256630e-01, 6.56954053e-01, 5.49743569e-01,
         4.71146730e-01, 8.40362257e-01]]])
x = dask.array<array, shape=(10, 10, 10), dtype=float64, chunksize=(1, 5, 1), chunktype=numpy.ndarray>
new = (5, 1, -1)

    @gen_cluster(client=True)
    async def test_cull_p2p_rechunk_independent_partitions(c, s, *ws):
        a = np.random.default_rng().uniform(0, 1, 1000).reshape((10, 10, 10))
        x = da.from_array(a, chunks=(1, 5, 1))
        new = (5, 1, -1)
        rechunked = rechunk(x, chunks=new, method="p2p")
        (dsk,) = dask.optimize(rechunked)
        culled = rechunked[:5, :2]
        (dsk_culled,) = dask.optimize(culled)
    
        # The culled graph requires only 1/2 of the input tasks
        n_inputs = len(
            [1 for key in dsk.dask.get_all_dependencies() if key[0].startswith("array-")]
        )
        n_culled_inputs = len(
            [
                1
                for key in dsk_culled.dask.get_all_dependencies()
                if key[0].startswith("array-")
            ]
        )
        assert n_culled_inputs == n_inputs / 4
        # The culled graph should also have less than 1/4 the tasks
>       assert len(dsk_culled.dask) < len(dsk.dask) / 4
E       assert 57 < (228 / 4)
E        +  where 57 = len(HighLevelGraph with 1 layers.\n<dask.highlevelgraph.HighLevelGraph object at 0x1aa05ebaf50>\n 0. getitem-3805afa155d01ec6f19a98fd58a271e4\n)
E        +    where HighLevelGraph with 1 layers.\n<dask.highlevelgraph.HighLevelGraph object at 0x1aa05ebaf50>\n 0. getitem-3805afa155d01ec6f19a98fd58a271e4\n = dask.array<getitem, shape=(5, 2, 10), dtype=float64, chunksize=(5, 1, 10), chunktype=numpy.ndarray>.dask
E        +  and   228 = len(HighLevelGraph with 1 layers.\n<dask.highlevelgraph.HighLevelGraph object at 0x1aa065b1050>\n 0. rechunk-p2p-c47ad159278aa77d47afe1fdcb337370\n)
E        +    where HighLevelGraph with 1 layers.\n<dask.highlevelgraph.HighLevelGraph object at 0x1aa065b1050>\n 0. rechunk-p2p-c47ad159278aa77d47afe1fdcb337370\n = dask.array<rechunk-p2p, shape=(10, 10, 10), dtype=float64, chunksize=(5, 1, 10), chunktype=numpy.ndarray>.dask

distributed\shuffle\tests\test_rechunk.py:265: AssertionError

Check warning on line 0 in distributed.shuffle.tests.test_rechunk

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

All 11 runs failed: test_rechunk_expand (distributed.shuffle.tests.test_rechunk)

artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 7df54f410e79f8489793c71b35ec9a5d failed during transfer phase
from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    # Module logger; used by ShuffleSchedulerPlugin for barrier and run-creation messages.
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            """Register the shuffle RPC handlers and this plugin on ``scheduler``."""
            self.scheduler = scheduler
            rpc_handlers = {
                "shuffle_barrier": self.barrier,
                "shuffle_get": self.get,
                "shuffle_get_or_create": self.get_or_create,
                "shuffle_restrict_task": self.restrict_task,
            }
            self.scheduler.handlers.update(rpc_handlers)
            # shuffle id -> worker address -> latest heartbeat fields
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            # Remaining bookkeeping containers; see the class-level annotations.
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a pickled, fresh ShuffleWorkerPlugin under the name ``"shuffle"``."""
            pickled_plugin = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None,
                pickled_plugin,
                name="shuffle",
                idempotent=False,
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return a new set containing the ids of all active shuffles."""
            return {shuffle_id for shuffle_id in self.active_shuffles}
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Complete the barrier of run ``run_id`` of shuffle ``id``.

            If the workers reported inconsistent run ids, the shuffle is restarted
            instead. Otherwise ``shuffle_inputs_done`` is broadcast to all
            participating workers, and workers that fail with ``OSError`` are retried.

            Raises
            ------
            ValueError
                If ``run_id`` is not the active run of the shuffle.
            P2PIllegalStateError
                If a participating worker is unknown to the scheduler although the
                shuffle has not been archived.
            RuntimeError
                If a worker returns a non-OSError error, a participating worker
                left, or three consecutive retries make no progress.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            # Number of consecutive retry rounds in which no worker succeeded
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # Only OSError failures are retried; anything else aborts.
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if not isinstance(r, OSError):
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                    workers.append(w)
                if not workers:
                    break
                logger.warning(
                    "Failure during broadcast of %s, retrying.",
                    shuffle.id,
                )
                if any(w not in self.scheduler.workers for w in workers):
                    if not shuffle.archived:
                        # If the shuffle is not yet archived, this could mean that the barrier task fails
                        # before the P2P restarting mechanism can kick in.
                        raise P2PIllegalStateError(
                            "Expected shuffle to be archived if participating worker is not known by scheduler"
                        )
                    raise RuntimeError(
                        f"Worker {workers} left during shuffle {shuffle}"
                    )
                await asyncio.sleep(0.1)
                if len(workers) == before:
                    no_progress += 1
                    if no_progress >= 3:
                        raise RuntimeError(
                            f"""Broadcast not making progress for {shuffle}.
                            Aborting. This is possibly due to overloaded
                            workers. Increasing config
                            `distributed.comm.timeouts.connect` timeout may
                            help."""
                        )
                else:
                    # Some workers succeeded this round: only consecutive stalls
                    # count towards aborting the broadcast.
                    no_progress = 0
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            """Restrict task ``key`` to ``worker`` if ``run_id`` is the active run of ``id``.

            A run-id mismatch is reported as an error message rather than raised.
            """
            try:
                shuffle = self.active_shuffles[id]
                active_run = shuffle.run_id
                if active_run > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                if active_run < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                self._set_restriction(self.scheduler.tasks[key], worker)
            except P2PConsistencyError as e:
                return error_message(e)
            return {"status": "OK"}
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            """Merge per-shuffle heartbeat data reported by worker ``ws``.

            Entries for shuffles that are not currently active are ignored.
            """
            active = self.active_shuffles
            for shuffle_id, d in data.items():
                # Direct dict membership; calling self.shuffle_ids() here rebuilt a
                # set of every active shuffle id for each reported entry.
                if shuffle_id in active:
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the active run spec of shuffle ``id`` (pickled) or an error message."""
            try:
                try:
                    spec = self._get(id, worker)
                    reply: RunSpecMessage = {"status": "OK", "run_spec": ToPickle(spec)}
                except KeyError as missing:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from missing
            except P2PConsistencyError as e:
                return error_message(e)
            return reply
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            """Return the active run spec of shuffle ``id`` and record ``worker`` as a participant.

            Raises KeyError if shuffle ``id`` is not active.
            """
            if worker in self.scheduler.workers:
                current = self.active_shuffles[id]
                current.participating_workers.add(worker)
                return current.run_spec
            # This should never happen
            raise P2PConsistencyError(
                f"Scheduler is unaware of this worker {worker!r}"
            )  # pragma: nocover
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            """Return the ShuffleSpec stored on the barrier task of ``shuffle_id``."""
            barrier_ts = self.scheduler.tasks[barrier_key(shuffle_id)]
            task_spec = barrier_ts.run_spec
            assert isinstance(task_spec, P2PBarrierTask)
            return task_spec.spec
    
        def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
            """Start a new run of ``shuffle_id``, triggered by task ``key`` on ``worker``.

            The new state becomes the active run, is added to the shuffle's run
            history, and has ``worker`` registered as a participant.
            """
            # FIXME: The current implementation relies on the barrier task to be
            # known by its name. If the name has been mangled, we cannot guarantee
            # that the shuffle works as intended and should fail instead.
            self._raise_if_barrier_unknown(shuffle_id)
            self._raise_if_task_not_processing(key)
            shuffle_spec = self._retrieve_spec(shuffle_id)
            assignment = self._calculate_worker_for(shuffle_spec)
            self._ensure_output_tasks_are_non_rootish(shuffle_spec)
            span = self.scheduler.tasks[key].group.span_id
            new_state = shuffle_spec.create_new_run(worker_for=assignment, span_id=span)
            self.active_shuffles[shuffle_id] = new_state
            self._shuffles[shuffle_id].add(new_state)
            new_state.participating_workers.add(worker)
            logger.warning(
                "Shuffle %s initialized by task %r executed on worker %s",
                shuffle_id,
                key,
                worker,
            )
            return new_state.run_spec
    
        def get_or_create(
            self,
            shuffle_id: ShuffleId,
            key: Key,
            worker: str,
        ) -> RunSpecMessage | ErrorMessage:
            try:
>               run_spec = self._get(shuffle_id, worker)

distributed\shuffle\_scheduler_plugin.py:229: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
>           state = self.active_shuffles[id]
E           KeyError: '7df54f410e79f8489793c71b35ec9a5d'

distributed\shuffle\_scheduler_plugin.py:190: KeyError

During handling of the above exception, another exception occurred:

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards
    
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)
    
            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and …ger
    def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
        try:
>           yield

distributed\shuffle\_core.py:523: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed\shuffle\_rechunk.py:170: in rechunk_transfer
    return get_worker_plugin().add_partition(
distributed\shuffle\_worker_plugin.py:348: in add_partition
    shuffle_run = self.get_or_create_shuffle(id)
distributed\shuffle\_worker_plugin.py:411: in get_or_create_shuffle
    return sync(
distributed\utils.py:439: in sync
    raise error
distributed\utils.py:413: in f
    result = yield future
C:\Users\runneradmin\miniconda3\envs\dask-distributed\Lib\site-packages\tornado\gen.py:766: in run
    value = future.result()
distributed\shuffle\_worker_plugin.py:145: in get_or_create
    shuffle_run = await self._refresh(
distributed\shuffle\_worker_plugin.py:222: in _refresh
    result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed\shuffle\_worker_plugin.py:190: in _fetch
    response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed\core.py:1259: in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
distributed\core.py:1043: in send_recv
    raise exc.with_traceback(tb)
distributed\core.py:832: in _handle_comm
    result = handler(**msg)
distributed\shuffle\_scheduler_plugin.py:234: in get_or_create
    run_spec = self._create(shuffle_id, key, worker)
distributed\shuffle\_scheduler_plugin.py:205: in _create
    spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
>           assert isinstance(barrier_task_spec, P2PBarrierTask)
E           AssertionError

distributed\shuffle\_scheduler_plugin.py:196: AssertionError

The above exception was the direct cause of the following exception:

c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:64264', workers: 0, cores: 0, tasks: 0>
ws = (<Worker 'tcp://127.0.0.1:64265', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:64268', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
a = array([[0.4588262 , 0.29882582, 0.71928598, 0.22034295, 0.98812948,
        0.67941824, 0.25921056, 0.62241972, 0.3004..., 0.28284222, 0.16402205, 0.55920114, 0.6132267 ,
        0.17587811, 0.09622778, 0.36399375, 0.04524015, 0.74855889]])
x = dask.array<array, shape=(10, 10), dtype=float64, chunksize=(5, 5), chunktype=numpy.ndarray>
y = dask.array<rechunk-p2p, shape=(10, 10), dtype=float64, chunksize=(3, 3), chunktype=numpy.ndarray>

    @gen_cluster(client=True)
    async def test_rechunk_expand(c, s, *ws):
        """
        See Also
        --------
        dask.array.tests.test_rechunk.test_rechunk_expand
        """
        a = np.random.default_rng().uniform(0, 1, 100).reshape((10, 10))
        x = da.from_array(a, chunks=(5, 5))
        y = x.rechunk(chunks=((3, 3, 3, 1), (3, 3, 3, 1)), method="p2p")
>       assert np.all(await c.compute(y) == a)

distributed\shuffle\tests\test_rechunk.py:377: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed\client.py:410: in _result
    raise exc.with_traceback(tb)
distributed\shuffle\_rechunk.py:169: in rechunk_transfer
    with handle_transfer_errors(id):
C:\Users\runneradmin\miniconda3\envs\dask-distributed\Lib\contextlib.py:158: in __exit__
    self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards
    
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)
    
            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and in the "p2p" contexts
            self.validate_data(data)
            with self._capture_metrics("foreground"):
                with (
                    context_meter.meter("p2p-shard-partition-noncpu"),
                    context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
                ):
                    shards = self._shard_partition(data, partition_id)
                sync(self._loop, self._write_to_comm, shards)
            return self.run_id
    
        @abc.abstractmethod
        def _shard_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> dict[str, tuple[_T_partition_id, Any]]:
            """Shard an input partition by the assigned output workers"""
    
        def get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            self.raise_if_closed()
            sync(self._loop, self._ensure_output_worker, partition_id, key)
            if not self.transferred:
                raise RuntimeError("`get_output_partition` called before barrier task")
            sync(self._loop, self.flush_receive)
            with (
                # Log metrics both in the "execute" and in the "p2p" contexts
                self._capture_metrics("foreground"),
                context_meter.meter("p2p-get-output-noncpu"),
                context_meter.meter("p2p-get-output-cpu", func=thread_time),
            ):
                return self._get_output_partition(partition_id, key, **kwargs)
    
        @abc.abstractmethod
        def _get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Get an output partition to the shuffle run"""
    
        @abc.abstractmethod
        def read(self, path: Path) -> tuple[Any, int]:
            """Read shards from disk"""
    
        @abc.abstractmethod
        def deserialize(self, buffer: Any) -> Any:
            """Deserialize shards"""
    
        def validate_data(self, data: Any) -> None:
            """Validate payload data before shuffling"""
    
    
    def get_worker_plugin() -> ShuffleWorkerPlugin:
        from distributed import get_worker
    
        try:
            worker = get_worker()
        except ValueError as e:
            raise RuntimeError(
                "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
                "please confirm that you've created a distributed Client and are submitting this computation through it."
            ) from e
        try:
            return worker.plugins["shuffle"]  # type: ignore
        except KeyError as e:
            raise RuntimeError(
                f"The worker {worker.address} does not have a P2P shuffle plugin."
            ) from e
    
    
    _BARRIER_PREFIX = "shuffle-barrier-"
    
    
    def barrier_key(shuffle_id: ShuffleId) -> str:
        return _BARRIER_PREFIX + shuffle_id
    
    
    def id_from_key(key: Key) -> ShuffleId | None:
        if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
            return None
        return ShuffleId(key[len(_BARRIER_PREFIX) :])
    
    
    class ShuffleType(Enum):
        DATAFRAME = "DataFrameShuffle"
        ARRAY_RECHUNK = "ArrayRechunk"
    
    
    @dataclass(frozen=True)
    class ShuffleRunSpec(Generic[_T_partition_id]):
        run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
        spec: ShuffleSpec
        worker_for: dict[_T_partition_id, str]
        span_id: str | None
    
        @property
        def id(self) -> ShuffleId:
            return self.spec.id
    
    
    @dataclass(frozen=True)
    class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
        id: ShuffleId
        disk: bool
    
        @property
        @abc.abstractmethod
        def output_partitions(self) -> Generator[_T_partition_id]:
            """Output partitions"""
    
        @abc.abstractmethod
        def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
            """Pick a worker for a partition"""
    
        def create_new_run(
            self,
            worker_for: dict[_T_partition_id, str],
            span_id: str | None,
        ) -> SchedulerShuffleState:
            return SchedulerShuffleState(
                run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
                participating_workers=set(worker_for.values()),
            )
    
        @abc.abstractmethod
        def create_run_on_worker(
            self,
            run_id: int,
            span_id: str | None,
            worker_for: dict[_T_partition_id, str],
            plugin: ShuffleWorkerPlugin,
        ) -> ShuffleRun:
            """Create the new shuffle run on the worker."""
    
    
    @dataclass(eq=False)
    class SchedulerShuffleState(Generic[_T_partition_id]):
        run_spec: ShuffleRunSpec
        participating_workers: set[str]
        _archived_by: str | None = field(default=None, init=False)
        _failed: bool = False
    
        @property
        def id(self) -> ShuffleId:
            return self.run_spec.id
    
        @property
        def run_id(self) -> int:
            return self.run_spec.run_id
    
        @property
        def archived(self) -> bool:
            return self._archived_by is not None
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
    
        def __hash__(self) -> int:
            return hash(self.run_id)
    
    
    @contextlib.contextmanager
    def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
        try:
            yield
        except ShuffleClosedError:
            raise Reschedule()
        except P2PConsistencyError:
            raise
        except P2POutOfDiskError:
            raise
        except Exception as e:
>           raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E           RuntimeError: P2P 7df54f410e79f8489793c71b35ec9a5d failed during transfer phase

distributed\shuffle\_core.py:531: RuntimeError

Check warning on line 0 in distributed.shuffle.tests.test_rechunk

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

All 11 runs failed: test_rechunk_expand2 (distributed.shuffle.tests.test_rechunk)

artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 6e5cf1da37a31a6c56bc50f1842218f5 failed during transfer phase
from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
            assert isinstance(barrier_task_spec, P2PBarrierTask)
            return barrier_task_spec.spec
    
        def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
            # FIXME: The current implementation relies on the barrier task to be
            # known by its name. If the name has been mangled, we cannot guarantee
            # that the shuffle works as intended and should fail instead.
            self._raise_if_barrier_unknown(shuffle_id)
            self._raise_if_task_not_processing(key)
            spec = self._retrieve_spec(shuffle_id)
            worker_for = self._calculate_worker_for(spec)
            self._ensure_output_tasks_are_non_rootish(spec)
            state = spec.create_new_run(
                worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
            )
            self.active_shuffles[shuffle_id] = state
            self._shuffles[shuffle_id].add(state)
            state.participating_workers.add(worker)
            logger.warning(
                "Shuffle %s initialized by task %r executed on worker %s",
                shuffle_id,
                key,
                worker,
            )
            return state.run_spec
    
        def get_or_create(
            self,
            shuffle_id: ShuffleId,
            key: Key,
            worker: str,
        ) -> RunSpecMessage | ErrorMessage:
            try:
>               run_spec = self._get(shuffle_id, worker)

distributed\shuffle\_scheduler_plugin.py:229: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
>           state = self.active_shuffles[id]
E           KeyError: '6e5cf1da37a31a6c56bc50f1842218f5'

distributed\shuffle\_scheduler_plugin.py:190: KeyError

During handling of the above exception, another exception occurred:

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards
    
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)
    
            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and …rors(id: ShuffleId) -> Iterator[None]:
        try:
>           yield

distributed\shuffle\_core.py:523: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed\shuffle\_rechunk.py:170: in rechunk_transfer
    return get_worker_plugin().add_partition(
distributed\shuffle\_worker_plugin.py:348: in add_partition
    shuffle_run = self.get_or_create_shuffle(id)
distributed\shuffle\_worker_plugin.py:411: in get_or_create_shuffle
    return sync(
distributed\utils.py:439: in sync
    raise error
distributed\utils.py:413: in f
    result = yield future
C:\Users\runneradmin\miniconda3\envs\dask-distributed\Lib\site-packages\tornado\gen.py:766: in run
    value = future.result()
distributed\shuffle\_worker_plugin.py:145: in get_or_create
    shuffle_run = await self._refresh(
distributed\shuffle\_worker_plugin.py:222: in _refresh
    result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed\shuffle\_worker_plugin.py:190: in _fetch
    response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed\core.py:1259: in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
distributed\core.py:1043: in send_recv
    raise exc.with_traceback(tb)
distributed\core.py:832: in _handle_comm
    result = handler(**msg)
distributed\shuffle\_scheduler_plugin.py:234: in get_or_create
    run_spec = self._create(shuffle_id, key, worker)
distributed\shuffle\_scheduler_plugin.py:205: in _create
    spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
>           assert isinstance(barrier_task_spec, P2PBarrierTask)
E           AssertionError

distributed\shuffle\_scheduler_plugin.py:196: AssertionError

The above exception was the direct cause of the following exception:

c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:64279', workers: 0, cores: 0, tasks: 0>
ws = (<Worker 'tcp://127.0.0.1:64280', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:64283', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
a = 3, b = 2
orig = array([[0.57459479, 0.24261869, 0.6206557 ],
       [0.97627947, 0.65759   , 0.41902488],
       [0.08888248, 0.69644868, 0.30588096]])

    @gen_cluster(client=True)
    async def test_rechunk_expand2(c, s, *ws):
        """
        See Also
        --------
        dask.array.tests.test_rechunk.test_rechunk_expand2
        """
        (a, b) = (3, 2)
        orig = np.random.default_rng().uniform(0, 1, a**b).reshape((a,) * b)
        for off, off2 in product(range(1, a - 1), range(1, a - 1)):
            old = ((a - off, off),) * b
            x = da.from_array(orig, chunks=old)
            new = ((a - off2, off2),) * b
            assert np.all(await c.compute(x.rechunk(chunks=new, method="p2p")) == orig)
            if a - off - off2 > 0:
                new = ((off, a - off2 - off, off2),) * b
>               y = await c.compute(x.rechunk(chunks=new, method="p2p"))

distributed\shuffle\tests\test_rechunk.py:396: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed\client.py:410: in _result
    raise exc.with_traceback(tb)
distributed\shuffle\_rechunk.py:169: in rechunk_transfer
    with handle_transfer_errors(id):
C:\Users\runneradmin\miniconda3\envs\dask-distributed\Lib\contextlib.py:158: in __exit__
    self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards
    
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)
    
            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and in the "p2p" contexts
            self.validate_data(data)
            with self._capture_metrics("foreground"):
                with (
                    context_meter.meter("p2p-shard-partition-noncpu"),
                    context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
                ):
                    shards = self._shard_partition(data, partition_id)
                sync(self._loop, self._write_to_comm, shards)
            return self.run_id
    
        @abc.abstractmethod
        def _shard_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> dict[str, tuple[_T_partition_id, Any]]:
            """Shard an input partition by the assigned output workers"""
    
        def get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            self.raise_if_closed()
            sync(self._loop, self._ensure_output_worker, partition_id, key)
            if not self.transferred:
                raise RuntimeError("`get_output_partition` called before barrier task")
            sync(self._loop, self.flush_receive)
            with (
                # Log metrics both in the "execute" and in the "p2p" contexts
                self._capture_metrics("foreground"),
                context_meter.meter("p2p-get-output-noncpu"),
                context_meter.meter("p2p-get-output-cpu", func=thread_time),
            ):
                return self._get_output_partition(partition_id, key, **kwargs)
    
        @abc.abstractmethod
        def _get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Get an output partition to the shuffle run"""
    
        @abc.abstractmethod
        def read(self, path: Path) -> tuple[Any, int]:
            """Read shards from disk"""
    
        @abc.abstractmethod
        def deserialize(self, buffer: Any) -> Any:
            """Deserialize shards"""
    
        def validate_data(self, data: Any) -> None:
            """Validate payload data before shuffling"""
    
    
    def get_worker_plugin() -> ShuffleWorkerPlugin:
        from distributed import get_worker
    
        try:
            worker = get_worker()
        except ValueError as e:
            raise RuntimeError(
                "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
                "please confirm that you've created a distributed Client and are submitting this computation through it."
            ) from e
        try:
            return worker.plugins["shuffle"]  # type: ignore
        except KeyError as e:
            raise RuntimeError(
                f"The worker {worker.address} does not have a P2P shuffle plugin."
            ) from e
    
    
    _BARRIER_PREFIX = "shuffle-barrier-"
    
    
    def barrier_key(shuffle_id: ShuffleId) -> str:
        return _BARRIER_PREFIX + shuffle_id
    
    
    def id_from_key(key: Key) -> ShuffleId | None:
        if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
            return None
        return ShuffleId(key[len(_BARRIER_PREFIX) :])
    
    
    class ShuffleType(Enum):
        DATAFRAME = "DataFrameShuffle"
        ARRAY_RECHUNK = "ArrayRechunk"
    
    
    @dataclass(frozen=True)
    class ShuffleRunSpec(Generic[_T_partition_id]):
        run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
        spec: ShuffleSpec
        worker_for: dict[_T_partition_id, str]
        span_id: str | None
    
        @property
        def id(self) -> ShuffleId:
            return self.spec.id
    
    
    @dataclass(frozen=True)
    class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
        id: ShuffleId
        disk: bool
    
        @property
        @abc.abstractmethod
        def output_partitions(self) -> Generator[_T_partition_id]:
            """Output partitions"""
    
        @abc.abstractmethod
        def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
            """Pick a worker for a partition"""
    
        def create_new_run(
            self,
            worker_for: dict[_T_partition_id, str],
            span_id: str | None,
        ) -> SchedulerShuffleState:
            return SchedulerShuffleState(
                run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
                participating_workers=set(worker_for.values()),
            )
    
        @abc.abstractmethod
        def create_run_on_worker(
            self,
            run_id: int,
            span_id: str | None,
            worker_for: dict[_T_partition_id, str],
            plugin: ShuffleWorkerPlugin,
        ) -> ShuffleRun:
            """Create the new shuffle run on the worker."""
    
    
    @dataclass(eq=False)
    class SchedulerShuffleState(Generic[_T_partition_id]):
        run_spec: ShuffleRunSpec
        participating_workers: set[str]
        _archived_by: str | None = field(default=None, init=False)
        _failed: bool = False
    
        @property
        def id(self) -> ShuffleId:
            return self.run_spec.id
    
        @property
        def run_id(self) -> int:
            return self.run_spec.run_id
    
        @property
        def archived(self) -> bool:
            return self._archived_by is not None
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
    
        def __hash__(self) -> int:
            return hash(self.run_id)
    
    
    @contextlib.contextmanager
    def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
        try:
            yield
        except ShuffleClosedError:
            raise Reschedule()
        except P2PConsistencyError:
            raise
        except P2POutOfDiskError:
            raise
        except Exception as e:
>           raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E           RuntimeError: P2P 6e5cf1da37a31a6c56bc50f1842218f5 failed during transfer phase

distributed\shuffle\_core.py:531: RuntimeError

Check warning on line 0 in distributed.shuffle.tests.test_rechunk

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

10 out of 11 runs failed: test_rechunk_unknown_from_pandas (distributed.shuffle.tests.test_rechunk)

artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P db9c09cc2dbfe69e4bcd9524cb9be9d7 failed during transfer phase
from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
            assert isinstance(barrier_task_spec, P2PBarrierTask)
            return barrier_task_spec.spec
    
        def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
            # FIXME: The current implementation relies on the barrier task to be
            # known by its name. If the name has been mangled, we cannot guarantee
            # that the shuffle works as intended and should fail instead.
            self._raise_if_barrier_unknown(shuffle_id)
            self._raise_if_task_not_processing(key)
            spec = self._retrieve_spec(shuffle_id)
            worker_for = self._calculate_worker_for(spec)
            self._ensure_output_tasks_are_non_rootish(spec)
            state = spec.create_new_run(
                worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
            )
            self.active_shuffles[shuffle_id] = state
            self._shuffles[shuffle_id].add(state)
            state.participating_workers.add(worker)
            logger.warning(
                "Shuffle %s initialized by task %r executed on worker %s",
                shuffle_id,
                key,
                worker,
            )
            return state.run_spec
    
        def get_or_create(
            self,
            shuffle_id: ShuffleId,
            key: Key,
            worker: str,
        ) -> RunSpecMessage | ErrorMessage:
            try:
>               run_spec = self._get(shuffle_id, worker)

distributed\shuffle\_scheduler_plugin.py:229: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
>           state = self.active_shuffles[id]
E           KeyError: 'db9c09cc2dbfe69e4bcd9524cb9be9d7'

distributed\shuffle\_scheduler_plugin.py:190: KeyError

During handling of the above exception, another exception occurred:

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards
    
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)
    
            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and … get_or_create_shuffle
    return sync(
distributed\utils.py:439: in sync
    raise error
distributed\utils.py:413: in f
    result = yield future
C:\Users\runneradmin\miniconda3\envs\dask-distributed\Lib\site-packages\tornado\gen.py:766: in run
    value = future.result()
distributed\shuffle\_worker_plugin.py:145: in get_or_create
    shuffle_run = await self._refresh(
distributed\shuffle\_worker_plugin.py:222: in _refresh
    result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed\shuffle\_worker_plugin.py:190: in _fetch
    response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed\core.py:1259: in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
distributed\core.py:1043: in send_recv
    raise exc.with_traceback(tb)
distributed\core.py:832: in _handle_comm
    result = handler(**msg)
distributed\shuffle\_scheduler_plugin.py:234: in get_or_create
    run_spec = self._create(shuffle_id, key, worker)
distributed\shuffle\_scheduler_plugin.py:205: in _create
    spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            """Restrict task ``key`` to ``worker`` on behalf of shuffle ``id``.

            The restriction is only applied when ``run_id`` equals the run ID of the
            currently active shuffle; a mismatch in either direction is reported
            back to the caller as an error message instead of being raised.

            Returns
            -------
            ``{"status": "OK"}`` on success, otherwise ``error_message(...)`` of the
            ``P2PConsistencyError`` describing the mismatch.

            Raises
            ------
            KeyError
                If ``id`` is not an active shuffle or ``key`` is not a known task;
                these lookups are not converted into error messages.
            """
            try:
                current = self.active_shuffles[id]
                if current.run_id > run_id:
                    # The request refers to an older run than the active one.
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {current}"
                    )
                if current.run_id < run_id:
                    # The request refers to a newer run than the active one.
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {current}"
                    )
                self._set_restriction(self.scheduler.tasks[key], worker)
            except P2PConsistencyError as exc:
                return error_message(exc)
            return {"status": "OK"}
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            """Merge heartbeat payloads reported by worker ``ws`` into ``self.heartbeats``.

            ``data`` maps shuffle IDs to per-shuffle payload dicts. Entries whose
            shuffle ID is not in ``self.shuffle_ids()`` are ignored.
            """
            for sid, payload in data.items():
                if sid not in self.shuffle_ids():
                    continue
                self.heartbeats[sid][ws.address].update(payload)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the run spec of shuffle ``id`` and register ``worker`` with it.

            Returns
            -------
            ``{"status": "OK", "run_spec": ToPickle(spec)}`` on success. An unknown
            shuffle ID (``KeyError`` from ``_get``) is translated into a
            ``P2PConsistencyError``. That error, and any ``P2PConsistencyError``
            raised by ``_get`` itself, is returned as ``error_message(...)``.
            """
            try:
                try:
                    spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(spec)}
                except KeyError as missing:
                    # Re-raise as the error type handled by the outer ``except``.
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from missing
            except P2PConsistencyError as exc:
                return error_message(exc)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            """Add ``worker`` to the participants of shuffle ``id`` and return its run spec.

            Raises
            ------
            P2PConsistencyError
                If the scheduler does not know ``worker`` (not expected to happen).
            KeyError
                If ``id`` is not an active shuffle.
            """
            known_worker = worker in self.scheduler.workers
            if not known_worker:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            shuffle_state = self.active_shuffles[id]
            shuffle_state.participating_workers.add(worker)
            return shuffle_state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
>           assert isinstance(barrier_task_spec, P2PBarrierTask)
E           AssertionError

distributed\shuffle\_scheduler_plugin.py:196: AssertionError

The above exception was the direct cause of the following exception:

c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:64576', workers: 0, cores: 0, tasks: 0>
ws = (<Worker 'tcp://127.0.0.1:64577', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:64580', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
pd = <module 'pandas' from 'C:\\Users\\runneradmin\\miniconda3\\envs\\dask-distributed\\Lib\\site-packages\\pandas\\__init__.py'>
dd = <module 'dask.dataframe' from 'C:\\Users\\runneradmin\\miniconda3\\envs\\dask-distributed\\Lib\\site-packages\\dask\\dataframe\\__init__.py'>
arr = array([[-1.21493258e+00,  8.70655313e-01,  8.33382075e-01,
         8.51173412e-01,  1.41451004e+00, -1.98203709e-01,
... 1.52552622e+00, -7.93934769e-01,
        -3.30697106e+00,  6.96784483e-01, -1.59570089e-02,
        -1.61997549e+00]])

    @gen_cluster(client=True)
    async def test_rechunk_unknown_from_pandas(c, s, *ws):
        """
        See Also
        --------
        dask.array.tests.test_rechunk.test_rechunk_unknown_from_pandas
        """
        pd = pytest.importorskip("pandas")
        dd = pytest.importorskip("dask.dataframe")
    
        arr = np.random.default_rng().standard_normal((50, 10))
        x = dd.from_pandas(pd.DataFrame(arr), 2).values
        result = x.rechunk((None, (5, 5)), method="p2p")
        assert np.isnan(x.chunks[0]).all()
        assert np.isnan(result.chunks[0]).all()
        assert result.chunks[1] == (5, 5)
        expected = da.from_array(arr, chunks=((25, 25), (10,))).rechunk(
            (None, (5, 5)), method="p2p"
        )
>       assert_eq(await c.compute(result), await c.compute(expected))

distributed\shuffle\tests\test_rechunk.py:706: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed\client.py:410: in _result
    raise exc.with_traceback(tb)
distributed\shuffle\_rechunk.py:169: in rechunk_transfer
    with handle_transfer_errors(id):
C:\Users\runneradmin\miniconda3\envs\dask-distributed\Lib\contextlib.py:158: in __exit__
    self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    # Identifier of a shuffle; a plain ``str`` at runtime.
    ShuffleId = NewType("ShuffleId", str)
    # Tuple-of-ints partition index (joined with "_" for disk-buffer keys).
    NDIndex: TypeAlias = tuple[int, ...]


    # Type variables parametrizing shuffle runs/specs and generic helpers.
    _T = TypeVar("_T")
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    
    
    class RunSpecMessage(OKMessage):
        """Successful reply carrying a shuffle run spec, possibly wrapped in ``ToPickle``."""

        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        """Worker-side state of one run of a P2P shuffle.

        A run shards input partitions (``add_partition``) and sends the shards to
        other workers through a comm buffer. It stores shards received from peers
        (``receive``) in a disk or in-memory buffer. Once inputs are done, it serves
        output partitions (``get_output_partition``).

        Subclasses provide the partition-type-specific pieces:
        ``_shard_partition``, ``_receive``, ``_get_output_partition``,
        ``_get_assigned_worker``, ``read`` and ``deserialize``.

        NOTE(review): this class derives from ``Generic`` only, not ``abc.ABC``,
        so the ``@abc.abstractmethod`` markers below are not enforced when a
        subclass is instantiated. Confirm whether that is intentional.
        """

        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        # Receives every metric captured by ``_capture_metrics``.
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        # First failure recorded by ``fail``/``inputs_done``; re-raised by
        # ``raise_if_closed`` once the run is closed.
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
        # NOTE(review): ``transferred`` is assigned in ``__init__`` but is not
        # declared among these class-level annotations.

        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float

        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            """Create the buffers, retry settings and bookkeeping for this run.

            Parameters
            ----------
            id, run_id
                Shuffle identifier and the run number of this attempt.
            span_id
                Second element of the ``("p2p", span_id, ...)`` metric names.
            local_address
                Address of the worker hosting this run.
            directory
                Directory for ``DiskShardsBuffer`` (only used when ``disk`` is true).
            executor
                Thread pool used by ``offload``.
            rpc
                Returns an RPC handle for a worker address; used by ``_send``.
            digest_metric
                Callback receiving each captured metric.
            scheduler
                RPC handle to the scheduler (barrier and task restriction calls).
            memory_limiter_disk
                Passed to ``DiskShardsBuffer`` (only when ``disk`` is true).
            memory_limiter_comms
                Passed to ``CommShardsBuffer``.
            disk
                Store received shards in a ``DiskShardsBuffer`` if true, otherwise
                in a ``MemoryShardsBuffer``.
            loop
                Event loop that ``sync`` drives from the synchronous entry points
                (``add_partition``, ``get_output_partition``).
            """
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False

            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)

                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )

            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)

            # Set to True by ``inputs_done``; ``add_partition`` refuses input after that.
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop

            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )

        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"

        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"

        def __hash__(self) -> int:
            # Hash by run ID; no ``__eq__`` is defined here, so equality stays
            # identity-based.
            return self.run_id

        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as

                {('p2p', <span id>, 'foreground|background...', label, unit): value}

            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under

                {('execute', <span id>, <task prefix>, label, unit): value}

            This is by design so that one can have a holistic view of the whole shuffle
            process.

            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """

            def callback(label: Hashable, value: float, unit: str) -> None:
                # Normalise the label to a tuple and drop a leading "p2p-" prefix,
                # since the metric name already starts with "p2p".
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)

                self.digest_metric(name, value)

            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield

        async def barrier(self, run_ids: Sequence[int]) -> int:
            """Notify the scheduler's shuffle barrier and return ``self.run_id``.

            ``consistent`` is sent as true only if every entry of ``run_ids``
            equals this run's ``run_id``.
            """
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id

        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Deliver ``shards`` (a list, or a pickled blob) to ``address`` via its
            ``shuffle_receive`` handler, tagged with this run's ID."""
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )

        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            """Send ``shards`` to ``address``, retrying per ``distributed.p2p.comm.retry``.

            Small shards (mean size below 64 KiB) are pickled into one blob first;
            ``receive`` on the other side unpacks it.
            """
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards

            # Zero-argument factory handed to ``retry``; presumably invoked once per
            # attempt so each retry gets a fresh coroutine.
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)

            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )

        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            """Run ``func(*args, **kwargs)`` in ``self.executor``, metered as "offload"."""
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )

        def heartbeat(self) -> dict[str, Any]:
            """Return buffer heartbeats plus this run's start time.

            The comm buffer's heartbeat dict is extended with ``"read"``, the
            value of ``self.total_recvd``.
            """
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }

        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            """Queue sharded data on the comm buffer for sending."""
            self.raise_if_closed()
            await self._comm_buffer.write(data)

        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            """Write to the receive buffer, keying each entry by its index joined with "_"."""
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )

        def raise_if_closed(self) -> None:
            """Raise if the run is closed: the recorded failure if there is one,
            otherwise ``ShuffleClosedError``."""
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")

        async def inputs_done(self) -> None:
            """Mark the input transfer as finished and flush outgoing shards.

            An error reported by the comm buffer is recorded in ``_exception`` and
            re-raised.
            """
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise

        async def _flush_comm(self) -> None:
            """Flush the outgoing comm buffer."""
            self.raise_if_closed()
            await self._comm_buffer.flush()

        async def flush_receive(self) -> None:
            """Flush the buffer holding received shards."""
            self.raise_if_closed()
            await self._disk_buffer.flush()

        async def close(self) -> None:
            """Close both buffers.

            ``closed`` is set before awaiting the buffers, so a concurrent or
            repeated call takes the first branch and waits for the first close to
            finish.
            """
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return

            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()

        def fail(self, exception: Exception) -> None:
            """Record ``exception`` to be re-raised by ``raise_if_closed``.

            Ignored if the run is already closed. This does not close the run.
            """
            if not self.closed:
                self._exception = exception

        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            """Read the shards stored under index ``id`` (joined with "_")."""
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))

        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Handle a batch of shards sent by a peer's ``send``.

            A ``bytes`` payload is the pickled blob produced by ``send`` and is
            unpickled first. ``P2PConsistencyError`` is returned as an error message.

            NOTE(review): ``pickle.loads`` runs on bytes received over the network;
            this presumably relies on only trusted cluster peers reaching this
            handler. Confirm.
            """
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)

        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            """Raise ``Reschedule`` if output partition ``i`` belongs to another worker.

            Before raising, the scheduler is asked to restrict task ``key`` to the
            assigned worker. An error reply from the scheduler becomes a
            ``RuntimeError``.
            """
            assigned_worker = self._get_assigned_worker(i)

            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()

        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""

        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""

        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            """Shard input partition ``data`` and queue the shards for sending.

            Blocks on ``self._loop`` (via ``sync``) until the shards are written to
            the comm buffer. Returns ``self.run_id``.

            Raises
            ------
            RuntimeError
                If called after ``inputs_done`` has marked the run as transferred.
            """
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and in the "p2p" contexts
            self.validate_data(data)
            with self._capture_metrics("foreground"):
                with (
                    context_meter.meter("p2p-shard-partition-noncpu"),
                    context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
                ):
                    shards = self._shard_partition(data, partition_id)
                sync(self._loop, self._write_to_comm, shards)
            return self.run_id

        @abc.abstractmethod
        def _shard_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> dict[str, tuple[_T_partition_id, Any]]:
            """Shard an input partition by the assigned output workers"""

        def get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Return output partition ``partition_id`` assembled from received shards.

            The worker-assignment check runs first and may raise ``Reschedule``.
            Only then is ``transferred`` checked and the receive buffer flushed.

            Raises
            ------
            RuntimeError
                If called before ``inputs_done`` has marked the run as transferred.
            """
            self.raise_if_closed()
            sync(self._loop, self._ensure_output_worker, partition_id, key)
            if not self.transferred:
                raise RuntimeError("`get_output_partition` called before barrier task")
            sync(self._loop, self.flush_receive)
            with (
                # Log metrics both in the "execute" and in the "p2p" contexts
                self._capture_metrics("foreground"),
                context_meter.meter("p2p-get-output-noncpu"),
                context_meter.meter("p2p-get-output-cpu", func=thread_time),
            ):
                return self._get_output_partition(partition_id, key, **kwargs)

        @abc.abstractmethod
        def _get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Get an output partition to the shuffle run"""

        @abc.abstractmethod
        def read(self, path: Path) -> tuple[Any, int]:
            """Read shards from disk"""

        @abc.abstractmethod
        def deserialize(self, buffer: Any) -> Any:
            """Deserialize shards"""

        def validate_data(self, data: Any) -> None:
            """Validate payload data before shuffling"""
    
    
    def get_worker_plugin() -> ShuffleWorkerPlugin:
        """Return the plugin registered as ``"shuffle"`` on the current worker.

        Raises
        ------
        RuntimeError
            If the caller is not running on a worker (``get_worker`` raises
            ``ValueError``), or if the worker has no ``"shuffle"`` plugin.
        """
        from distributed import get_worker

        try:
            current = get_worker()
        except ValueError as no_worker:
            raise RuntimeError(
                "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
                "please confirm that you've created a distributed Client and are submitting this computation through it."
            ) from no_worker

        try:
            plugin = current.plugins["shuffle"]
        except KeyError as missing:
            raise RuntimeError(
                f"The worker {current.address} does not have a P2P shuffle plugin."
            ) from missing
        return plugin  # type: ignore
    
    
    # Prefix shared by the keys of all P2P barrier tasks.
    _BARRIER_PREFIX = "shuffle-barrier-"


    def barrier_key(shuffle_id: ShuffleId) -> str:
        """Return the key of the barrier task belonging to ``shuffle_id``."""
        return _BARRIER_PREFIX + shuffle_id


    def id_from_key(key: Key) -> ShuffleId | None:
        """Inverse of ``barrier_key``: recover the shuffle ID from a barrier key.

        Returns ``None`` for any key that is not a ``str`` starting with the
        barrier prefix (tuple keys included).
        """
        if isinstance(key, str) and key.startswith(_BARRIER_PREFIX):
            return ShuffleId(key.removeprefix(_BARRIER_PREFIX))
        return None
    
    
    class ShuffleType(Enum):
        """Kinds of P2P shuffle; each member's value is a descriptive string label."""

        DATAFRAME = "DataFrameShuffle"
        ARRAY_RECHUNK = "ArrayRechunk"
    
    
    @dataclass(frozen=True)
    class ShuffleRunSpec(Generic[_T_partition_id]):
        """Immutable description of one run of a shuffle.

        ``run_id`` is not an ``__init__`` argument. It is drawn from a single
        ``itertools.count(1)`` created when this class is defined. Every instance
        in the process therefore gets the next integer, starting at 1.
        """

        # Auto-assigned from the class-level counter; field order matters for the
        # generated ``__init__``, so it stays first.
        run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
        spec: ShuffleSpec
        # Maps each output partition ID to the address of its assigned worker.
        worker_for: dict[_T_partition_id, str]
        # Optional span identifier; presumably forwarded to ``ShuffleRun.span_id``.
        span_id: str | None

        @property
        def id(self) -> ShuffleId:
            """The shuffle ID, taken from the underlying ``spec``."""
            return self.spec.id
    
    
    @dataclass(frozen=True)
    class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
        """Abstract, immutable specification of a P2P shuffle.

        Concrete subclasses enumerate the output partitions, choose a worker for
        each partition, and know how to instantiate a run on a worker.
        """

        id: ShuffleId
        disk: bool

        @property
        @abc.abstractmethod
        def output_partitions(self) -> Generator[_T_partition_id]:
            """Yield the IDs of all output partitions."""

        @abc.abstractmethod
        def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
            """Choose, from ``workers``, the worker that owns ``partition``."""

        def create_new_run(
            self,
            worker_for: dict[_T_partition_id, str],
            span_id: str | None,
        ) -> SchedulerShuffleState:
            """Build the scheduler-side state for a fresh run of this shuffle.

            The workers that ``worker_for`` assigns to at least one output
            partition become the initial participating workers.
            """
            run_spec = ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id)
            participants = set(worker_for.values())
            return SchedulerShuffleState(
                run_spec=run_spec,
                participating_workers=participants,
            )

        @abc.abstractmethod
        def create_run_on_worker(
            self,
            run_id: int,
            span_id: str | None,
            worker_for: dict[_T_partition_id, str],
            plugin: ShuffleWorkerPlugin,
        ) -> ShuffleRun:
            """Instantiate the worker-side ``ShuffleRun`` for run ``run_id``."""
    
    
    @dataclass(eq=False)
    class SchedulerShuffleState(Generic[_T_partition_id]):
        """Scheduler-side bookkeeping for a single shuffle run.

        Equality is identity-based (``eq=False``), while hashing uses the run ID.
        A run counts as archived once ``_archived_by`` has been set.
        """

        run_spec: ShuffleRunSpec
        participating_workers: set[str]
        # Not an ``__init__`` argument; starts as ``None`` (not archived).
        _archived_by: str | None = field(default=None, init=False)
        _failed: bool = False

        @property
        def id(self) -> ShuffleId:
            """Shuffle ID of the underlying run spec."""
            spec = self.run_spec
            return spec.id

        @property
        def run_id(self) -> int:
            """Run ID of the underlying run spec."""
            spec = self.run_spec
            return spec.run_id

        @property
        def archived(self) -> bool:
            """Whether ``_archived_by`` has been recorded."""
            archiver = self._archived_by
            return archiver is not None

        def __str__(self) -> str:
            cls_name = type(self).__name__
            return f"{cls_name}<{self.id}[{self.run_id}]>"

        def __hash__(self) -> int:
            return hash(self.run_id)
    
    
    @contextlib.contextmanager
    def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
        try:
            yield
        except ShuffleClosedError:
            raise Reschedule()
        except P2PConsistencyError:
            raise
        except P2POutOfDiskError:
            raise
        except Exception as e:
>           raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E           RuntimeError: P2P db9c09cc2dbfe69e4bcd9524cb9be9d7 failed during transfer phase

distributed\shuffle\_core.py:531: RuntimeError

Check warning on line 0 in distributed.shuffle.tests.test_rechunk

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

10 out of 11 runs failed: test_rechunk_with_fully_unknown_dimension[x0-chunks0] (distributed.shuffle.tests.test_rechunk)

artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 53a5056f80ef44da75580c3c1aff19ce failed during transfer phase
from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
            assert isinstance(barrier_task_spec, P2PBarrierTask)
            return barrier_task_spec.spec
    
        def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
            # FIXME: The current implementation relies on the barrier task to be
            # known by its name. If the name has been mangled, we cannot guarantee
            # that the shuffle works as intended and should fail instead.
            self._raise_if_barrier_unknown(shuffle_id)
            self._raise_if_task_not_processing(key)
            spec = self._retrieve_spec(shuffle_id)
            worker_for = self._calculate_worker_for(spec)
            self._ensure_output_tasks_are_non_rootish(spec)
            state = spec.create_new_run(
                worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
            )
            self.active_shuffles[shuffle_id] = state
            self._shuffles[shuffle_id].add(state)
            state.participating_workers.add(worker)
            logger.warning(
                "Shuffle %s initialized by task %r executed on worker %s",
                shuffle_id,
                key,
                worker,
            )
            return state.run_spec
    
        def get_or_create(
            self,
            shuffle_id: ShuffleId,
            key: Key,
            worker: str,
        ) -> RunSpecMessage | ErrorMessage:
            try:
>               run_spec = self._get(shuffle_id, worker)

distributed\shuffle\_scheduler_plugin.py:229: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
>           state = self.active_shuffles[id]
E           KeyError: '53a5056f80ef44da75580c3c1aff19ce'

distributed\shuffle\_scheduler_plugin.py:190: KeyError

During handling of the above exception, another exception occurred:

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards
    
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)
    
            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and …n _refresh
    result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed\shuffle\_worker_plugin.py:190: in _fetch
    response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed\core.py:1259: in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
distributed\core.py:1043: in send_recv
    raise exc.with_traceback(tb)
distributed\core.py:832: in _handle_comm
    result = handler(**msg)
distributed\shuffle\_scheduler_plugin.py:234: in get_or_create
    run_spec = self._create(shuffle_id, key, worker)
distributed\shuffle\_scheduler_plugin.py:205: in _create
    spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
>           assert isinstance(barrier_task_spec, P2PBarrierTask)
E           AssertionError

distributed\shuffle\_scheduler_plugin.py:196: AssertionError

The above exception was the direct cause of the following exception:

c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:64605', workers: 0, cores: 0, tasks: 0>
x = dask.array<ones_like, shape=(50, 10), dtype=float64, chunksize=(25, 10), chunktype=numpy.ndarray>
chunks = (None, 5)
ws = (<Worker 'tcp://127.0.0.1:64606', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:64609', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
dd = <module 'dask.dataframe' from 'C:\\Users\\runneradmin\\miniconda3\\envs\\dask-distributed\\Lib\\site-packages\\dask\\dataframe\\__init__.py'>

    @pytest.mark.parametrize(
        "x, chunks",
        [
            (da.ones(shape=(50, 10), chunks=(25, 10)), (None, 5)),
            (da.ones(shape=(50, 10), chunks=(25, 10)), {1: 5}),
            (da.ones(shape=(50, 10), chunks=(25, 10)), (None, (5, 5))),
            (da.ones(shape=(100, 10), chunks=(5, 10)), (None, 5)),
            (da.ones(shape=(100, 10), chunks=(5, 10)), {1: 5}),
            (da.ones(shape=(100, 10), chunks=(5, 10)), (None, (5, 5))),
            (da.ones(shape=(10, 10), chunks=(10, 10)), (None, 5)),
            (da.ones(shape=(10, 10), chunks=(10, 10)), {1: 5}),
            (da.ones(shape=(10, 10), chunks=(10, 10)), (None, (5, 5))),
            (da.ones(shape=(10, 10), chunks=(10, 2)), (None, 5)),
            (da.ones(shape=(10, 10), chunks=(10, 2)), {1: 5}),
            (da.ones(shape=(10, 10), chunks=(10, 2)), (None, (5, 5))),
        ],
    )
    @gen_cluster(client=True)
    async def test_rechunk_with_fully_unknown_dimension(c, s, *ws, x, chunks):
        """
        See Also
        --------
        dask.array.tests.test_rechunk.test_rechunk_with_fully_unknown_dimension
        """
        pytest.importorskip("pandas")
        dd = pytest.importorskip("dask.dataframe")
        y = dd.from_array(x).values
        result = y.rechunk(chunks, method="p2p")
        expected = x.rechunk(chunks, method="p2p")
    
        assert_chunks_match(result.chunks, expected.chunks)
>       assert_eq(await c.compute(result), await c.compute(expected))

distributed\shuffle\tests\test_rechunk.py:757: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed\client.py:410: in _result
    raise exc.with_traceback(tb)
distributed\shuffle\_rechunk.py:169: in rechunk_transfer
    with handle_transfer_errors(id):
C:\Users\runneradmin\miniconda3\envs\dask-distributed\Lib\contextlib.py:158: in __exit__
    self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards
    
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)
    
            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and in the "p2p" contexts
            self.validate_data(data)
            with self._capture_metrics("foreground"):
                with (
                    context_meter.meter("p2p-shard-partition-noncpu"),
                    context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
                ):
                    shards = self._shard_partition(data, partition_id)
                sync(self._loop, self._write_to_comm, shards)
            return self.run_id
    
        @abc.abstractmethod
        def _shard_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> dict[str, tuple[_T_partition_id, Any]]:
            """Shard an input partition by the assigned output workers"""
    
        def get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            self.raise_if_closed()
            sync(self._loop, self._ensure_output_worker, partition_id, key)
            if not self.transferred:
                raise RuntimeError("`get_output_partition` called before barrier task")
            sync(self._loop, self.flush_receive)
            with (
                # Log metrics both in the "execute" and in the "p2p" contexts
                self._capture_metrics("foreground"),
                context_meter.meter("p2p-get-output-noncpu"),
                context_meter.meter("p2p-get-output-cpu", func=thread_time),
            ):
                return self._get_output_partition(partition_id, key, **kwargs)
    
        @abc.abstractmethod
        def _get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Get an output partition to the shuffle run"""
    
        @abc.abstractmethod
        def read(self, path: Path) -> tuple[Any, int]:
            """Read shards from disk"""
    
        @abc.abstractmethod
        def deserialize(self, buffer: Any) -> Any:
            """Deserialize shards"""
    
        def validate_data(self, data: Any) -> None:
            """Validate payload data before shuffling"""
    
    
    def get_worker_plugin() -> ShuffleWorkerPlugin:
        from distributed import get_worker
    
        try:
            worker = get_worker()
        except ValueError as e:
            raise RuntimeError(
                "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
                "please confirm that you've created a distributed Client and are submitting this computation through it."
            ) from e
        try:
            return worker.plugins["shuffle"]  # type: ignore
        except KeyError as e:
            raise RuntimeError(
                f"The worker {worker.address} does not have a P2P shuffle plugin."
            ) from e
    
    
    _BARRIER_PREFIX = "shuffle-barrier-"
    
    
    def barrier_key(shuffle_id: ShuffleId) -> str:
        return _BARRIER_PREFIX + shuffle_id
    
    
    def id_from_key(key: Key) -> ShuffleId | None:
        if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
            return None
        return ShuffleId(key[len(_BARRIER_PREFIX) :])
    
    
    class ShuffleType(Enum):
        DATAFRAME = "DataFrameShuffle"
        ARRAY_RECHUNK = "ArrayRechunk"
    
    
    @dataclass(frozen=True)
    class ShuffleRunSpec(Generic[_T_partition_id]):
        run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
        spec: ShuffleSpec
        worker_for: dict[_T_partition_id, str]
        span_id: str | None
    
        @property
        def id(self) -> ShuffleId:
            return self.spec.id
    
    
    @dataclass(frozen=True)
    class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
        id: ShuffleId
        disk: bool
    
        @property
        @abc.abstractmethod
        def output_partitions(self) -> Generator[_T_partition_id]:
            """Output partitions"""
    
        @abc.abstractmethod
        def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
            """Pick a worker for a partition"""
    
        def create_new_run(
            self,
            worker_for: dict[_T_partition_id, str],
            span_id: str | None,
        ) -> SchedulerShuffleState:
            return SchedulerShuffleState(
                run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
                participating_workers=set(worker_for.values()),
            )
    
        @abc.abstractmethod
        def create_run_on_worker(
            self,
            run_id: int,
            span_id: str | None,
            worker_for: dict[_T_partition_id, str],
            plugin: ShuffleWorkerPlugin,
        ) -> ShuffleRun:
            """Create the new shuffle run on the worker."""
    
    
    @dataclass(eq=False)
    class SchedulerShuffleState(Generic[_T_partition_id]):
        run_spec: ShuffleRunSpec
        participating_workers: set[str]
        _archived_by: str | None = field(default=None, init=False)
        _failed: bool = False
    
        @property
        def id(self) -> ShuffleId:
            return self.run_spec.id
    
        @property
        def run_id(self) -> int:
            return self.run_spec.run_id
    
        @property
        def archived(self) -> bool:
            return self._archived_by is not None
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
    
        def __hash__(self) -> int:
            return hash(self.run_id)
    
    
    @contextlib.contextmanager
    def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
        try:
            yield
        except ShuffleClosedError:
            raise Reschedule()
        except P2PConsistencyError:
            raise
        except P2POutOfDiskError:
            raise
        except Exception as e:
>           raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E           RuntimeError: P2P 53a5056f80ef44da75580c3c1aff19ce failed during transfer phase

distributed\shuffle\_core.py:531: RuntimeError

Check warning on line 0 in distributed.shuffle.tests.test_rechunk

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

10 out of 11 runs failed: test_rechunk_with_fully_unknown_dimension[x1-chunks1] (distributed.shuffle.tests.test_rechunk)

artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 53a5056f80ef44da75580c3c1aff19ce failed during transfer phase
from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
            assert isinstance(barrier_task_spec, P2PBarrierTask)
            return barrier_task_spec.spec
    
        def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
            # FIXME: The current implementation relies on the barrier task to be
            # known by its name. If the name has been mangled, we cannot guarantee
            # that the shuffle works as intended and should fail instead.
            self._raise_if_barrier_unknown(shuffle_id)
            self._raise_if_task_not_processing(key)
            spec = self._retrieve_spec(shuffle_id)
            worker_for = self._calculate_worker_for(spec)
            self._ensure_output_tasks_are_non_rootish(spec)
            state = spec.create_new_run(
                worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
            )
            self.active_shuffles[shuffle_id] = state
            self._shuffles[shuffle_id].add(state)
            state.participating_workers.add(worker)
            logger.warning(
                "Shuffle %s initialized by task %r executed on worker %s",
                shuffle_id,
                key,
                worker,
            )
            return state.run_spec
    
        def get_or_create(
            self,
            shuffle_id: ShuffleId,
            key: Key,
            worker: str,
        ) -> RunSpecMessage | ErrorMessage:
            try:
>               run_spec = self._get(shuffle_id, worker)

distributed\shuffle\_scheduler_plugin.py:229: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
>           state = self.active_shuffles[id]
E           KeyError: '53a5056f80ef44da75580c3c1aff19ce'

distributed\shuffle\_scheduler_plugin.py:190: KeyError

During handling of the above exception, another exception occurred:

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards
    
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)
    
            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and …: in _refresh
    result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed\shuffle\_worker_plugin.py:190: in _fetch
    response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed\core.py:1259: in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
distributed\core.py:1043: in send_recv
    raise exc.with_traceback(tb)
distributed\core.py:832: in _handle_comm
    result = handler(**msg)
distributed\shuffle\_scheduler_plugin.py:234: in get_or_create
    run_spec = self._create(shuffle_id, key, worker)
distributed\shuffle\_scheduler_plugin.py:205: in _create
    spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
>           assert isinstance(barrier_task_spec, P2PBarrierTask)
E           AssertionError

distributed\shuffle\_scheduler_plugin.py:196: AssertionError

The above exception was the direct cause of the following exception:

c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:64619', workers: 0, cores: 0, tasks: 0>
x = dask.array<ones_like, shape=(50, 10), dtype=float64, chunksize=(25, 10), chunktype=numpy.ndarray>
chunks = {1: 5}
ws = (<Worker 'tcp://127.0.0.1:64620', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:64623', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
dd = <module 'dask.dataframe' from 'C:\\Users\\runneradmin\\miniconda3\\envs\\dask-distributed\\Lib\\site-packages\\dask\\dataframe\\__init__.py'>

    @pytest.mark.parametrize(
        "x, chunks",
        [
            (da.ones(shape=(50, 10), chunks=(25, 10)), (None, 5)),
            (da.ones(shape=(50, 10), chunks=(25, 10)), {1: 5}),
            (da.ones(shape=(50, 10), chunks=(25, 10)), (None, (5, 5))),
            (da.ones(shape=(100, 10), chunks=(5, 10)), (None, 5)),
            (da.ones(shape=(100, 10), chunks=(5, 10)), {1: 5}),
            (da.ones(shape=(100, 10), chunks=(5, 10)), (None, (5, 5))),
            (da.ones(shape=(10, 10), chunks=(10, 10)), (None, 5)),
            (da.ones(shape=(10, 10), chunks=(10, 10)), {1: 5}),
            (da.ones(shape=(10, 10), chunks=(10, 10)), (None, (5, 5))),
            (da.ones(shape=(10, 10), chunks=(10, 2)), (None, 5)),
            (da.ones(shape=(10, 10), chunks=(10, 2)), {1: 5}),
            (da.ones(shape=(10, 10), chunks=(10, 2)), (None, (5, 5))),
        ],
    )
    @gen_cluster(client=True)
    async def test_rechunk_with_fully_unknown_dimension(c, s, *ws, x, chunks):
        """
        See Also
        --------
        dask.array.tests.test_rechunk.test_rechunk_with_fully_unknown_dimension
        """
        pytest.importorskip("pandas")
        dd = pytest.importorskip("dask.dataframe")
        y = dd.from_array(x).values
        result = y.rechunk(chunks, method="p2p")
        expected = x.rechunk(chunks, method="p2p")
    
        assert_chunks_match(result.chunks, expected.chunks)
>       assert_eq(await c.compute(result), await c.compute(expected))

distributed\shuffle\tests\test_rechunk.py:757: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed\client.py:410: in _result
    raise exc.with_traceback(tb)
distributed\shuffle\_rechunk.py:169: in rechunk_transfer
    with handle_transfer_errors(id):
C:\Users\runneradmin\miniconda3\envs\dask-distributed\Lib\contextlib.py:158: in __exit__
    self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards
    
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)
    
            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and in the "p2p" contexts
            self.validate_data(data)
            with self._capture_metrics("foreground"):
                with (
                    context_meter.meter("p2p-shard-partition-noncpu"),
                    context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
                ):
                    shards = self._shard_partition(data, partition_id)
                sync(self._loop, self._write_to_comm, shards)
            return self.run_id
    
        @abc.abstractmethod
        def _shard_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> dict[str, tuple[_T_partition_id, Any]]:
            """Shard an input partition by the assigned output workers"""
    
        def get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            self.raise_if_closed()
            sync(self._loop, self._ensure_output_worker, partition_id, key)
            if not self.transferred:
                raise RuntimeError("`get_output_partition` called before barrier task")
            sync(self._loop, self.flush_receive)
            with (
                # Log metrics both in the "execute" and in the "p2p" contexts
                self._capture_metrics("foreground"),
                context_meter.meter("p2p-get-output-noncpu"),
                context_meter.meter("p2p-get-output-cpu", func=thread_time),
            ):
                return self._get_output_partition(partition_id, key, **kwargs)
    
        @abc.abstractmethod
        def _get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Get an output partition to the shuffle run"""
    
        @abc.abstractmethod
        def read(self, path: Path) -> tuple[Any, int]:
            """Read shards from disk"""
    
        @abc.abstractmethod
        def deserialize(self, buffer: Any) -> Any:
            """Deserialize shards"""
    
        def validate_data(self, data: Any) -> None:
            """Validate payload data before shuffling"""
    
    
    def get_worker_plugin() -> ShuffleWorkerPlugin:
        from distributed import get_worker
    
        try:
            worker = get_worker()
        except ValueError as e:
            raise RuntimeError(
                "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
                "please confirm that you've created a distributed Client and are submitting this computation through it."
            ) from e
        try:
            return worker.plugins["shuffle"]  # type: ignore
        except KeyError as e:
            raise RuntimeError(
                f"The worker {worker.address} does not have a P2P shuffle plugin."
            ) from e
    
    
    _BARRIER_PREFIX = "shuffle-barrier-"
    
    
    def barrier_key(shuffle_id: ShuffleId) -> str:
        return _BARRIER_PREFIX + shuffle_id
    
    
    def id_from_key(key: Key) -> ShuffleId | None:
        if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
            return None
        return ShuffleId(key[len(_BARRIER_PREFIX) :])
    
    
    class ShuffleType(Enum):
        DATAFRAME = "DataFrameShuffle"
        ARRAY_RECHUNK = "ArrayRechunk"
    
    
    @dataclass(frozen=True)
    class ShuffleRunSpec(Generic[_T_partition_id]):
        run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
        spec: ShuffleSpec
        worker_for: dict[_T_partition_id, str]
        span_id: str | None
    
        @property
        def id(self) -> ShuffleId:
            return self.spec.id
    
    
    @dataclass(frozen=True)
    class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
        id: ShuffleId
        disk: bool
    
        @property
        @abc.abstractmethod
        def output_partitions(self) -> Generator[_T_partition_id]:
            """Output partitions"""
    
        @abc.abstractmethod
        def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
            """Pick a worker for a partition"""
    
        def create_new_run(
            self,
            worker_for: dict[_T_partition_id, str],
            span_id: str | None,
        ) -> SchedulerShuffleState:
            return SchedulerShuffleState(
                run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
                participating_workers=set(worker_for.values()),
            )
    
        @abc.abstractmethod
        def create_run_on_worker(
            self,
            run_id: int,
            span_id: str | None,
            worker_for: dict[_T_partition_id, str],
            plugin: ShuffleWorkerPlugin,
        ) -> ShuffleRun:
            """Create the new shuffle run on the worker."""
    
    
    @dataclass(eq=False)
    class SchedulerShuffleState(Generic[_T_partition_id]):
        run_spec: ShuffleRunSpec
        participating_workers: set[str]
        _archived_by: str | None = field(default=None, init=False)
        _failed: bool = False
    
        @property
        def id(self) -> ShuffleId:
            return self.run_spec.id
    
        @property
        def run_id(self) -> int:
            return self.run_spec.run_id
    
        @property
        def archived(self) -> bool:
            return self._archived_by is not None
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
    
        def __hash__(self) -> int:
            return hash(self.run_id)
    
    
    @contextlib.contextmanager
    def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
        try:
            yield
        except ShuffleClosedError:
            raise Reschedule()
        except P2PConsistencyError:
            raise
        except P2POutOfDiskError:
            raise
        except Exception as e:
>           raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E           RuntimeError: P2P 53a5056f80ef44da75580c3c1aff19ce failed during transfer phase

distributed\shuffle\_core.py:531: RuntimeError

Check warning on line 0 in distributed.shuffle.tests.test_rechunk

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

10 out of 11 runs failed: test_rechunk_with_fully_unknown_dimension[x2-chunks2] (distributed.shuffle.tests.test_rechunk)

artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 53a5056f80ef44da75580c3c1aff19ce failed during transfer phase
from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
            assert isinstance(barrier_task_spec, P2PBarrierTask)
            return barrier_task_spec.spec
    
        def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
            # FIXME: The current implementation relies on the barrier task to be
            # known by its name. If the name has been mangled, we cannot guarantee
            # that the shuffle works as intended and should fail instead.
            self._raise_if_barrier_unknown(shuffle_id)
            self._raise_if_task_not_processing(key)
            spec = self._retrieve_spec(shuffle_id)
            worker_for = self._calculate_worker_for(spec)
            self._ensure_output_tasks_are_non_rootish(spec)
            state = spec.create_new_run(
                worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
            )
            self.active_shuffles[shuffle_id] = state
            self._shuffles[shuffle_id].add(state)
            state.participating_workers.add(worker)
            logger.warning(
                "Shuffle %s initialized by task %r executed on worker %s",
                shuffle_id,
                key,
                worker,
            )
            return state.run_spec
    
        def get_or_create(
            self,
            shuffle_id: ShuffleId,
            key: Key,
            worker: str,
        ) -> RunSpecMessage | ErrorMessage:
            try:
>               run_spec = self._get(shuffle_id, worker)

distributed\shuffle\_scheduler_plugin.py:229: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
>           state = self.active_shuffles[id]
E           KeyError: '53a5056f80ef44da75580c3c1aff19ce'

distributed\shuffle\_scheduler_plugin.py:190: KeyError

During handling of the above exception, another exception occurred:

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards
    
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)
    
            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and …fresh
    result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed\shuffle\_worker_plugin.py:190: in _fetch
    response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed\core.py:1259: in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
distributed\core.py:1043: in send_recv
    raise exc.with_traceback(tb)
distributed\core.py:832: in _handle_comm
    result = handler(**msg)
distributed\shuffle\_scheduler_plugin.py:234: in get_or_create
    run_spec = self._create(shuffle_id, key, worker)
distributed\shuffle\_scheduler_plugin.py:205: in _create
    spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    # Module-level logger; barrier() uses it to warn about restarts and
    # broadcast retries.
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            """Attach the plugin to *scheduler*.

            Registers the ``shuffle_*`` RPC handlers, initialises the bookkeeping
            containers and adds itself to the scheduler under the name
            ``"shuffle"``.
            """
            self.scheduler = scheduler
            rpc_handlers = {
                "shuffle_barrier": self.barrier,
                "shuffle_get": self.get,
                "shuffle_get_or_create": self.get_or_create,
                "shuffle_restrict_task": self.restrict_task,
            }
            self.scheduler.handlers.update(rpc_handlers)

            # shuffle id -> worker address -> merged heartbeat metrics
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            # Registration happens before the remaining containers are created,
            # matching the original initialisation order.
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a fresh ``ShuffleWorkerPlugin`` named ``"shuffle"`` on workers.

            The plugin is pickled with ``dumps`` and registered with
            ``idempotent=False``.
            """
            serialized_plugin = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None, serialized_plugin, name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return a new set with the ids of all currently active shuffles."""
            return {*self.active_shuffles}
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Complete the P2P barrier for run *run_id* of shuffle *id*.

            If ``consistent`` is false (some worker reported a different run id),
            the shuffle is restarted instead. Otherwise a ``shuffle_inputs_done``
            message is broadcast to every participating worker. Workers whose
            broadcast failed with an ``OSError`` are retried after 0.1 s.

            Raises
            ------
            ValueError
                If *run_id* is not the active run of the shuffle.
            P2PIllegalStateError
                If a failing worker is unknown to the scheduler although the
                shuffle has not been archived.
            RuntimeError
                On a non-``OSError`` broadcast failure, if a failing worker has
                left the cluster, or once three retry rounds (counted
                cumulatively) saw every retried worker fail again.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            # Cumulative number of rounds in which no retried worker succeeded;
            # it is intentionally never reset.
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # Retry only workers that failed with an OSError (presumably a
                # transient comm problem); any other error aborts the barrier.
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            """Pin task *key* to *worker* for the active run of shuffle *id*.

            Returns ``{"status": "OK"}`` on success. If *run_id* differs from the
            active run id, the ``P2PConsistencyError`` is returned as an error
            message.
            """
            try:
                shuffle = self.active_shuffles[id]
                active_run = shuffle.run_id
                if active_run > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                if active_run < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                self._set_restriction(self.scheduler.tasks[key], worker)
            except P2PConsistencyError as e:
                return error_message(e)
            return {"status": "OK"}
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            """Merge per-shuffle heartbeat metrics reported by worker *ws*.

            *data* maps shuffle ids to metric dicts. Entries for shuffles that
            are not active are ignored.
            """
            active = self.active_shuffles
            for shuffle_id, d in data.items():
                # Check the dict directly instead of rebuilding a set through
                # shuffle_ids() for every reported shuffle.
                if shuffle_id in active:
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the run spec of active shuffle *id* for *worker*.

            A missing shuffle is reported as a ``P2PConsistencyError``. That error,
            like any other consistency error, is returned as an error message
            rather than raised.
            """

            def _lookup() -> RunSpecMessage:
                # Translate a missing shuffle into a consistency error.
                try:
                    return {"status": "OK", "run_spec": ToPickle(self._get(id, worker))}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e

            try:
                return _lookup()
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            """Add *worker* to shuffle *id*'s participants and return its run spec.

            Raises ``KeyError`` if the shuffle is not active, and
            ``P2PConsistencyError`` if the scheduler does not know *worker*.
            """
            known_workers = self.scheduler.workers
            if worker not in known_workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            shuffle_state = self.active_shuffles[id]
            shuffle_state.participating_workers.add(worker)
            return shuffle_state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
>           assert isinstance(barrier_task_spec, P2PBarrierTask)
E           AssertionError

distributed\shuffle\_scheduler_plugin.py:196: AssertionError

The above exception was the direct cause of the following exception:

c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:64633', workers: 0, cores: 0, tasks: 0>
x = dask.array<ones_like, shape=(50, 10), dtype=float64, chunksize=(25, 10), chunktype=numpy.ndarray>
chunks = (None, (5, 5))
ws = (<Worker 'tcp://127.0.0.1:64634', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:64637', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
dd = <module 'dask.dataframe' from 'C:\\Users\\runneradmin\\miniconda3\\envs\\dask-distributed\\Lib\\site-packages\\dask\\dataframe\\__init__.py'>

    @pytest.mark.parametrize(
        "x, chunks",
        [
            (da.ones(shape=(50, 10), chunks=(25, 10)), (None, 5)),
            (da.ones(shape=(50, 10), chunks=(25, 10)), {1: 5}),
            (da.ones(shape=(50, 10), chunks=(25, 10)), (None, (5, 5))),
            (da.ones(shape=(100, 10), chunks=(5, 10)), (None, 5)),
            (da.ones(shape=(100, 10), chunks=(5, 10)), {1: 5}),
            (da.ones(shape=(100, 10), chunks=(5, 10)), (None, (5, 5))),
            (da.ones(shape=(10, 10), chunks=(10, 10)), (None, 5)),
            (da.ones(shape=(10, 10), chunks=(10, 10)), {1: 5}),
            (da.ones(shape=(10, 10), chunks=(10, 10)), (None, (5, 5))),
            (da.ones(shape=(10, 10), chunks=(10, 2)), (None, 5)),
            (da.ones(shape=(10, 10), chunks=(10, 2)), {1: 5}),
            (da.ones(shape=(10, 10), chunks=(10, 2)), (None, (5, 5))),
        ],
    )
    @gen_cluster(client=True)
    async def test_rechunk_with_fully_unknown_dimension(c, s, *ws, x, chunks):
        """
        See Also
        --------
        dask.array.tests.test_rechunk.test_rechunk_with_fully_unknown_dimension
        """
        pytest.importorskip("pandas")
        dd = pytest.importorskip("dask.dataframe")
        y = dd.from_array(x).values
        result = y.rechunk(chunks, method="p2p")
        expected = x.rechunk(chunks, method="p2p")
    
        assert_chunks_match(result.chunks, expected.chunks)
>       assert_eq(await c.compute(result), await c.compute(expected))

distributed\shuffle\tests\test_rechunk.py:757: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed\client.py:410: in _result
    raise exc.with_traceback(tb)
distributed\shuffle\_rechunk.py:169: in rechunk_transfer
    with handle_transfer_errors(id):
C:\Users\runneradmin\miniconda3\envs\dask-distributed\Lib\contextlib.py:158: in __exit__
    self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    # Distinct str type identifying a shuffle (for static type checking only).
    ShuffleId = NewType("ShuffleId", str)
    # Multi-dimensional partition index; _read_from_disk/_write_to_disk join its
    # elements with "_" to form buffer keys.
    NDIndex: TypeAlias = tuple[int, ...]


    # Type of a partition identifier (presumably int or NDIndex depending on the
    # shuffle kind -- confirm against subclasses).
    _T_partition_id = TypeVar("_T_partition_id")
    # Type of a partition's payload.
    _T_partition_type = TypeVar("_T_partition_type")
    # Generic return type used by ShuffleRun.offload.
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        """Successful reply carrying a shuffle run spec (possibly wrapped in ToPickle)."""

        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        """Worker-side state and logic of a single P2P shuffle run.

        ``add_partition`` splits input partitions into shards and queues them in
        a ``CommShardsBuffer``, which ships them to other workers via ``send``.
        Shards arriving from peers pass through ``receive`` to the subclass's
        ``_receive``. Once ``inputs_done`` has marked the run as transferred,
        ``get_output_partition`` assembles output partitions.

        Methods decorated with ``abc.abstractmethod`` must be provided by
        subclasses. This class does not derive from ``abc.ABC``, so the
        decorator documents the contract but is not enforced at instantiation.
        """

        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        # Set by inputs_done(); output partitions may only be read afterwards.
        transferred: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop

        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float

        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            """Set up the buffers and retry settings of this run.

            Received shards are buffered on disk under *directory* when *disk* is
            true, and in memory otherwise. Retry settings for ``send`` come from
            the ``distributed.p2p.comm.retry.*`` config keys.
            """
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False

            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)

                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )

            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)

            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop

            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )

        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"

        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"

        def __hash__(self) -> int:
            return self.run_id

        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as

                {('p2p', <span id>, 'foreground|background...', label, unit): value}

            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under

                {('execute', <span id>, <task prefix>, label, unit): value}

            This is by design so that one can have a holistic view of the whole shuffle
            process.

            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """

            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                # Strip the redundant "p2p-" prefix from the first label element,
                # since the metric is already namespaced under "p2p". The emptiness
                # check avoids an IndexError for an empty tuple label.
                if label and isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)

                self.digest_metric(name, value)

            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield

        async def barrier(self, run_ids: Sequence[int]) -> int:
            """Report reaching the barrier to the scheduler; return this run's id.

            The report is flagged consistent only if every id in *run_ids* equals
            this run's id.
            """
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id

        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Send *shards* to the worker at *address* (single attempt)."""
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )

        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            """Send *shards* to *address*, retrying per the RETRY_* settings.

            When ``_mean_shard_size`` reports less than 65536 bytes, the shards
            are pickled into one bytes blob first (see ``receive``).
            """
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards

            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)

            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )

        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            """Run ``func(*args, **kwargs)`` in this run's thread pool executor."""
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )

        def heartbeat(self) -> dict[str, Any]:
            """Return disk, comm and start-time statistics for the scheduler."""
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }

        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            """Queue shards, keyed by destination, in the outgoing comm buffer."""
            self.raise_if_closed()
            await self._comm_buffer.write(data)

        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            """Write received shards; index keys are joined with "_" (e.g. "1_2")."""
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )

        def raise_if_closed(self) -> None:
            """Raise the recorded failure, or ShuffleClosedError, if closed."""
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")

        async def inputs_done(self) -> None:
            """Mark all inputs as transferred and flush the outgoing comm buffer.

            Any exception recorded by the comm buffer is stored on the run and
            re-raised.
            """
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise

        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()

        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()

        async def close(self) -> None:
            """Close the buffers once; later calls wait for the first to finish."""
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return

            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()

        def fail(self, exception: Exception) -> None:
            """Record *exception* as the failure cause unless already closed."""
            if not self.closed:
                self._exception = exception

        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))

        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Handle shards sent by a peer; consistency errors become messages."""
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    # NOTE(review): this unpickles bytes received over the network
                    # from a peer worker; it is only safe when every cluster
                    # member is trusted.
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)

        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            """Reschedule *key* onto the assigned worker if it is not this one."""
            assigned_worker = self._get_assigned_worker(i)

            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()

        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""

        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""

        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            """Shard one input partition and queue the shards for sending.

            Returns this run's id. Raises ``RuntimeError`` if inputs were
            already marked as transferred.
            """
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and in the "p2p" contexts
            self.validate_data(data)
            with self._capture_metrics("foreground"):
                with (
                    context_meter.meter("p2p-shard-partition-noncpu"),
                    context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
                ):
                    shards = self._shard_partition(data, partition_id)
                sync(self._loop, self._write_to_comm, shards)
            return self.run_id

        @abc.abstractmethod
        def _shard_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> dict[str, tuple[_T_partition_id, Any]]:
            """Shard an input partition by the assigned output workers"""

        def get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Return output partition *partition_id* once inputs are transferred.

            May raise ``Reschedule`` (via ``_ensure_output_worker``) when another
            worker is assigned the partition. Raises ``RuntimeError`` if called
            before the barrier.
            """
            self.raise_if_closed()
            sync(self._loop, self._ensure_output_worker, partition_id, key)
            if not self.transferred:
                raise RuntimeError("`get_output_partition` called before barrier task")
            sync(self._loop, self.flush_receive)
            with (
                # Log metrics both in the "execute" and in the "p2p" contexts
                self._capture_metrics("foreground"),
                context_meter.meter("p2p-get-output-noncpu"),
                context_meter.meter("p2p-get-output-cpu", func=thread_time),
            ):
                return self._get_output_partition(partition_id, key, **kwargs)

        @abc.abstractmethod
        def _get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Get an output partition to the shuffle run"""

        @abc.abstractmethod
        def read(self, path: Path) -> tuple[Any, int]:
            """Read shards from disk"""

        @abc.abstractmethod
        def deserialize(self, buffer: Any) -> Any:
            """Deserialize shards"""

        def validate_data(self, data: Any) -> None:
            """Validate payload data before shuffling"""
    
    
    def get_worker_plugin() -> ShuffleWorkerPlugin:
        """Return the ``"shuffle"`` plugin of the worker running the current task.

        Raises
        ------
        RuntimeError
            If the caller is not running on a distributed Worker, or if that
            worker has no ``"shuffle"`` plugin.
        """
        from distributed import get_worker

        try:
            worker = get_worker()
        except ValueError as e:
            raise RuntimeError(
                "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
                "please confirm that you've created a distributed Client and are submitting this computation through it."
            ) from e

        try:
            plugin = worker.plugins["shuffle"]
        except KeyError as e:
            raise RuntimeError(
                f"The worker {worker.address} does not have a P2P shuffle plugin."
            ) from e
        return plugin  # type: ignore
    
    
    # Key prefix that marks the barrier task of a shuffle.
    _BARRIER_PREFIX = "shuffle-barrier-"


    def barrier_key(shuffle_id: ShuffleId) -> str:
        """Return the key of the barrier task of shuffle *shuffle_id*."""
        return _BARRIER_PREFIX + shuffle_id


    def id_from_key(key: Key) -> ShuffleId | None:
        """Invert ``barrier_key``.

        Returns the shuffle id encoded in a barrier task key, or ``None`` if
        *key* is not a string that starts with the barrier prefix.
        """
        if isinstance(key, str) and key.startswith(_BARRIER_PREFIX):
            return ShuffleId(key.removeprefix(_BARRIER_PREFIX))
        return None
    
    
    class ShuffleType(Enum):
        """Kinds of P2P shuffle operations."""

        DATAFRAME = "DataFrameShuffle"
        ARRAY_RECHUNK = "ArrayRechunk"
    
    
    @dataclass(frozen=True)
    class ShuffleRunSpec(Generic[_T_partition_id]):
        """Immutable description of one run of a shuffle.

        ``run_id`` is not an ``__init__`` argument: it is drawn from a single
        ``itertools.count(1)`` shared by all instances, so each new spec gets
        the next integer.
        """

        # Auto-assigned; the counter is created once, when the class is defined.
        run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
        spec: ShuffleSpec
        # Output partition id -> address of the worker that produces it.
        worker_for: dict[_T_partition_id, str]
        span_id: str | None

        @property
        def id(self) -> ShuffleId:
            """Id of the shuffle this run belongs to."""
            return self.spec.id
    
    
    @dataclass(frozen=True)
    class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
        """Abstract, immutable description of a shuffle independent of any run.

        Subclasses define the output partitions, how partitions are assigned to
        workers, and how a run is instantiated on a worker.
        """

        id: ShuffleId
        # Whether workers buffer received shards on disk (True) or in memory.
        disk: bool

        @property
        @abc.abstractmethod
        def output_partitions(self) -> Generator[_T_partition_id]:
            """Output partitions"""

        @abc.abstractmethod
        def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
            """Pick a worker for a partition"""

        def create_new_run(
            self,
            worker_for: dict[_T_partition_id, str],
            span_id: str | None,
        ) -> SchedulerShuffleState:
            """Create scheduler-side state for a new run of this shuffle.

            The run spec gets a fresh run id, and the initially participating
            workers are the distinct values of *worker_for*.
            """
            return SchedulerShuffleState(
                run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
                participating_workers=set(worker_for.values()),
            )

        @abc.abstractmethod
        def create_run_on_worker(
            self,
            run_id: int,
            span_id: str | None,
            worker_for: dict[_T_partition_id, str],
            plugin: ShuffleWorkerPlugin,
        ) -> ShuffleRun:
            """Create the new shuffle run on the worker."""
    
    
    @dataclass(eq=False)
    class SchedulerShuffleState(Generic[_T_partition_id]):
        """Scheduler-side bookkeeping for one shuffle run.

        ``eq=False`` keeps identity-based equality, while ``__hash__`` uses the
        run id.
        """

        run_spec: ShuffleRunSpec
        # Workers that take part in this run; _get() adds requesting workers.
        participating_workers: set[str]
        # None while the run is active; otherwise presumably the stimulus that
        # archived it -- confirm against the scheduler plugin.
        _archived_by: str | None = field(default=None, init=False)
        _failed: bool = False

        @property
        def id(self) -> ShuffleId:
            return self.run_spec.id

        @property
        def run_id(self) -> int:
            return self.run_spec.run_id

        @property
        def archived(self) -> bool:
            """Whether this run has been archived (``_archived_by`` is set)."""
            return self._archived_by is not None

        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"

        def __hash__(self) -> int:
            return hash(self.run_id)
    
    
    @contextlib.contextmanager
    def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
        try:
            yield
        except ShuffleClosedError:
            raise Reschedule()
        except P2PConsistencyError:
            raise
        except P2POutOfDiskError:
            raise
        except Exception as e:
>           raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E           RuntimeError: P2P 53a5056f80ef44da75580c3c1aff19ce failed during transfer phase

distributed\shuffle\_core.py:531: RuntimeError

Check warning on line 0 in distributed.shuffle.tests.test_rechunk

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

10 out of 11 runs failed: test_rechunk_with_fully_unknown_dimension[x3-chunks3] (distributed.shuffle.tests.test_rechunk)

artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 103aea8fc8e7486296d569fc1f0805a8 failed during transfer phase
from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
            assert isinstance(barrier_task_spec, P2PBarrierTask)
            return barrier_task_spec.spec
    
        def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
            # FIXME: The current implementation relies on the barrier task to be
            # known by its name. If the name has been mangled, we cannot guarantee
            # that the shuffle works as intended and should fail instead.
            self._raise_if_barrier_unknown(shuffle_id)
            self._raise_if_task_not_processing(key)
            spec = self._retrieve_spec(shuffle_id)
            worker_for = self._calculate_worker_for(spec)
            self._ensure_output_tasks_are_non_rootish(spec)
            state = spec.create_new_run(
                worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
            )
            self.active_shuffles[shuffle_id] = state
            self._shuffles[shuffle_id].add(state)
            state.participating_workers.add(worker)
            logger.warning(
                "Shuffle %s initialized by task %r executed on worker %s",
                shuffle_id,
                key,
                worker,
            )
            return state.run_spec
    
        def get_or_create(
            self,
            shuffle_id: ShuffleId,
            key: Key,
            worker: str,
        ) -> RunSpecMessage | ErrorMessage:
            try:
>               run_spec = self._get(shuffle_id, worker)

distributed\shuffle\_scheduler_plugin.py:229: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
>           state = self.active_shuffles[id]
E           KeyError: '103aea8fc8e7486296d569fc1f0805a8'

distributed\shuffle\_scheduler_plugin.py:190: KeyError

During handling of the above exception, another exception occurred:

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards
    
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)
    
            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and …n _refresh
    result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed\shuffle\_worker_plugin.py:190: in _fetch
    response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed\core.py:1259: in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
distributed\core.py:1043: in send_recv
    raise exc.with_traceback(tb)
distributed\core.py:832: in _handle_comm
    result = handler(**msg)
distributed\shuffle\_scheduler_plugin.py:234: in get_or_create
    run_spec = self._create(shuffle_id, key, worker)
distributed\shuffle\_scheduler_plugin.py:205: in _create
    spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
>           assert isinstance(barrier_task_spec, P2PBarrierTask)
E           AssertionError

distributed\shuffle\_scheduler_plugin.py:196: AssertionError

The above exception was the direct cause of the following exception:

c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:64647', workers: 0, cores: 0, tasks: 0>
x = dask.array<ones_like, shape=(100, 10), dtype=float64, chunksize=(5, 10), chunktype=numpy.ndarray>
chunks = (None, 5)
ws = (<Worker 'tcp://127.0.0.1:64648', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:64651', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
dd = <module 'dask.dataframe' from 'C:\\Users\\runneradmin\\miniconda3\\envs\\dask-distributed\\Lib\\site-packages\\dask\\dataframe\\__init__.py'>

    @pytest.mark.parametrize(
        "x, chunks",
        [
            (da.ones(shape=(50, 10), chunks=(25, 10)), (None, 5)),
            (da.ones(shape=(50, 10), chunks=(25, 10)), {1: 5}),
            (da.ones(shape=(50, 10), chunks=(25, 10)), (None, (5, 5))),
            (da.ones(shape=(100, 10), chunks=(5, 10)), (None, 5)),
            (da.ones(shape=(100, 10), chunks=(5, 10)), {1: 5}),
            (da.ones(shape=(100, 10), chunks=(5, 10)), (None, (5, 5))),
            (da.ones(shape=(10, 10), chunks=(10, 10)), (None, 5)),
            (da.ones(shape=(10, 10), chunks=(10, 10)), {1: 5}),
            (da.ones(shape=(10, 10), chunks=(10, 10)), (None, (5, 5))),
            (da.ones(shape=(10, 10), chunks=(10, 2)), (None, 5)),
            (da.ones(shape=(10, 10), chunks=(10, 2)), {1: 5}),
            (da.ones(shape=(10, 10), chunks=(10, 2)), (None, (5, 5))),
        ],
    )
    @gen_cluster(client=True)
    async def test_rechunk_with_fully_unknown_dimension(c, s, *ws, x, chunks):
        """
        See Also
        --------
        dask.array.tests.test_rechunk.test_rechunk_with_fully_unknown_dimension
        """
        pytest.importorskip("pandas")
        dd = pytest.importorskip("dask.dataframe")
        y = dd.from_array(x).values
        result = y.rechunk(chunks, method="p2p")
        expected = x.rechunk(chunks, method="p2p")
    
        assert_chunks_match(result.chunks, expected.chunks)
>       assert_eq(await c.compute(result), await c.compute(expected))

distributed\shuffle\tests\test_rechunk.py:757: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed\client.py:410: in _result
    raise exc.with_traceback(tb)
distributed\shuffle\_rechunk.py:169: in rechunk_transfer
    with handle_transfer_errors(id):
C:\Users\runneradmin\miniconda3\envs\dask-distributed\Lib\contextlib.py:158: in __exit__
    self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards
    
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)
    
            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and in the "p2p" contexts
            self.validate_data(data)
            with self._capture_metrics("foreground"):
                with (
                    context_meter.meter("p2p-shard-partition-noncpu"),
                    context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
                ):
                    shards = self._shard_partition(data, partition_id)
                sync(self._loop, self._write_to_comm, shards)
            return self.run_id
    
        @abc.abstractmethod
        def _shard_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> dict[str, tuple[_T_partition_id, Any]]:
            """Shard an input partition by the assigned output workers"""
    
        def get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            self.raise_if_closed()
            sync(self._loop, self._ensure_output_worker, partition_id, key)
            if not self.transferred:
                raise RuntimeError("`get_output_partition` called before barrier task")
            sync(self._loop, self.flush_receive)
            with (
                # Log metrics both in the "execute" and in the "p2p" contexts
                self._capture_metrics("foreground"),
                context_meter.meter("p2p-get-output-noncpu"),
                context_meter.meter("p2p-get-output-cpu", func=thread_time),
            ):
                return self._get_output_partition(partition_id, key, **kwargs)
    
        @abc.abstractmethod
        def _get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Get an output partition to the shuffle run"""
    
        @abc.abstractmethod
        def read(self, path: Path) -> tuple[Any, int]:
            """Read shards from disk"""
    
        @abc.abstractmethod
        def deserialize(self, buffer: Any) -> Any:
            """Deserialize shards"""
    
        def validate_data(self, data: Any) -> None:
            """Validate payload data before shuffling"""
    
    
    def get_worker_plugin() -> ShuffleWorkerPlugin:
        from distributed import get_worker
    
        try:
            worker = get_worker()
        except ValueError as e:
            raise RuntimeError(
                "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
                "please confirm that you've created a distributed Client and are submitting this computation through it."
            ) from e
        try:
            return worker.plugins["shuffle"]  # type: ignore
        except KeyError as e:
            raise RuntimeError(
                f"The worker {worker.address} does not have a P2P shuffle plugin."
            ) from e
    
    
    _BARRIER_PREFIX = "shuffle-barrier-"
    
    
    def barrier_key(shuffle_id: ShuffleId) -> str:
        return _BARRIER_PREFIX + shuffle_id
    
    
    def id_from_key(key: Key) -> ShuffleId | None:
        if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
            return None
        return ShuffleId(key[len(_BARRIER_PREFIX) :])
    
    
    class ShuffleType(Enum):
        DATAFRAME = "DataFrameShuffle"
        ARRAY_RECHUNK = "ArrayRechunk"
    
    
    @dataclass(frozen=True)
    class ShuffleRunSpec(Generic[_T_partition_id]):
        run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
        spec: ShuffleSpec
        worker_for: dict[_T_partition_id, str]
        span_id: str | None
    
        @property
        def id(self) -> ShuffleId:
            return self.spec.id
    
    
    @dataclass(frozen=True)
    class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
        id: ShuffleId
        disk: bool
    
        @property
        @abc.abstractmethod
        def output_partitions(self) -> Generator[_T_partition_id]:
            """Output partitions"""
    
        @abc.abstractmethod
        def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
            """Pick a worker for a partition"""
    
        def create_new_run(
            self,
            worker_for: dict[_T_partition_id, str],
            span_id: str | None,
        ) -> SchedulerShuffleState:
            return SchedulerShuffleState(
                run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
                participating_workers=set(worker_for.values()),
            )
    
        @abc.abstractmethod
        def create_run_on_worker(
            self,
            run_id: int,
            span_id: str | None,
            worker_for: dict[_T_partition_id, str],
            plugin: ShuffleWorkerPlugin,
        ) -> ShuffleRun:
            """Create the new shuffle run on the worker."""
    
    
    @dataclass(eq=False)
    class SchedulerShuffleState(Generic[_T_partition_id]):
        run_spec: ShuffleRunSpec
        participating_workers: set[str]
        _archived_by: str | None = field(default=None, init=False)
        _failed: bool = False
    
        @property
        def id(self) -> ShuffleId:
            return self.run_spec.id
    
        @property
        def run_id(self) -> int:
            return self.run_spec.run_id
    
        @property
        def archived(self) -> bool:
            return self._archived_by is not None
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
    
        def __hash__(self) -> int:
            return hash(self.run_id)
    
    
    @contextlib.contextmanager
    def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
        try:
            yield
        except ShuffleClosedError:
            raise Reschedule()
        except P2PConsistencyError:
            raise
        except P2POutOfDiskError:
            raise
        except Exception as e:
>           raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E           RuntimeError: P2P 103aea8fc8e7486296d569fc1f0805a8 failed during transfer phase

distributed\shuffle\_core.py:531: RuntimeError

Check warning on line 0 in distributed.shuffle.tests.test_rechunk

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

10 out of 11 runs failed: test_rechunk_with_fully_unknown_dimension[x4-chunks4] (distributed.shuffle.tests.test_rechunk)

artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 103aea8fc8e7486296d569fc1f0805a8 failed during transfer phase
from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
            assert isinstance(barrier_task_spec, P2PBarrierTask)
            return barrier_task_spec.spec
    
        def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
            # FIXME: The current implementation relies on the barrier task to be
            # known by its name. If the name has been mangled, we cannot guarantee
            # that the shuffle works as intended and should fail instead.
            self._raise_if_barrier_unknown(shuffle_id)
            self._raise_if_task_not_processing(key)
            spec = self._retrieve_spec(shuffle_id)
            worker_for = self._calculate_worker_for(spec)
            self._ensure_output_tasks_are_non_rootish(spec)
            state = spec.create_new_run(
                worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
            )
            self.active_shuffles[shuffle_id] = state
            self._shuffles[shuffle_id].add(state)
            state.participating_workers.add(worker)
            logger.warning(
                "Shuffle %s initialized by task %r executed on worker %s",
                shuffle_id,
                key,
                worker,
            )
            return state.run_spec
    
        def get_or_create(
            self,
            shuffle_id: ShuffleId,
            key: Key,
            worker: str,
        ) -> RunSpecMessage | ErrorMessage:
            try:
>               run_spec = self._get(shuffle_id, worker)

distributed\shuffle\_scheduler_plugin.py:229: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
>           state = self.active_shuffles[id]
E           KeyError: '103aea8fc8e7486296d569fc1f0805a8'

distributed\shuffle\_scheduler_plugin.py:190: KeyError

During handling of the above exception, another exception occurred:

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards
    
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)
    
            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            """Ensure output partition *i* is produced on its assigned worker.

            When another worker is assigned, the scheduler is asked to restrict
            task *key* to that worker and ``Reschedule`` is raised.

            Raises
            ------
            RuntimeError
                If the scheduler answers the restriction request with an error.
            Reschedule
                If this worker is not the assigned one.
            """
            target = self._get_assigned_worker(i)
            if target == self.local_address:
                return
            response = await self.scheduler.shuffle_restrict_task(
                id=self.id, run_id=self.run_id, key=key, worker=target
            )
            if response["status"] == "error":
                raise RuntimeError(response["message"])
            assert response["status"] == "OK"
            raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition

            Implemented by concrete runs; ``_ensure_output_worker`` compares the
            returned address with ``local_address``.
            """
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run

            Called by ``receive`` with already-unpickled ``(partition_id, shard)``
            pairs; a ``P2PConsistencyError`` raised here is turned into an error
            message by ``receive``.
            """
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and …: in _refresh
    result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed\shuffle\_worker_plugin.py:190: in _fetch
    response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed\core.py:1259: in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
distributed\core.py:1043: in send_recv
    raise exc.with_traceback(tb)
distributed\core.py:832: in _handle_comm
    result = handler(**msg)
distributed\shuffle\_scheduler_plugin.py:234: in get_or_create
    run_spec = self._create(shuffle_id, key, worker)
distributed\shuffle\_scheduler_plugin.py:205: in _create
    spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    # Module logger; ShuffleSchedulerPlugin.barrier emits its warnings through it.
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            """Install the shuffle RPC handlers and register this plugin on *scheduler*."""
            self.scheduler = scheduler
            rpc_handlers = {
                "shuffle_barrier": self.barrier,
                "shuffle_get": self.get,
                "shuffle_get_or_create": self.get_or_create,
                "shuffle_restrict_task": self.restrict_task,
            }
            self.scheduler.handlers.update(rpc_handlers)
            # heartbeats[shuffle_id][worker_address] -> merged heartbeat payload
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            # NOTE: the plugin is registered before the remaining bookkeeping
            # attributes are created; this ordering is kept as-is.
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a pickled ``ShuffleWorkerPlugin`` under the name "shuffle"."""
            pickled_plugin = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None, pickled_plugin, name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return a new set holding the IDs of all currently active shuffles."""
            return {*self.active_shuffles}
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Complete the P2P barrier of shuffle *id* for run *run_id*.

            If *consistent* is false, the shuffle is restarted instead. Otherwise
            ``shuffle_inputs_done`` is broadcast to all participating workers, and
            workers whose broadcast failed with ``OSError`` are retried after a
            short sleep.

            Raises
            ------
            ValueError
                If *run_id* does not match the active run of the shuffle.
            P2PIllegalStateError
                If a failing worker is unknown to the scheduler although the
                shuffle has not been archived.
            RuntimeError
                If a failing worker left the cluster, a worker returned a
                non-``OSError`` error, or three retry rounds (counted
                cumulatively) did not shrink the set of failing workers.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # Build the retry list: ``None`` results are treated as success,
                # OSErrors are retried, and any other error aborts the barrier.
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            """Restrict task *key* of shuffle *id* so it runs on *worker*.

            A *run_id* older or newer than the active run yields an error message
            rather than an exception.
            """
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id != run_id:
                    verdict = "stale" if shuffle.run_id > run_id else "invalid"
                    raise P2PConsistencyError(
                        f"Request {verdict}, expected {run_id=} for {shuffle}"
                    )
                self._set_restriction(self.scheduler.tasks[key], worker)
            except P2PConsistencyError as e:
                return error_message(e)
            return {"status": "OK"}
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            """Merge worker *ws*'s per-shuffle heartbeat *data* into ``heartbeats``.

            Entries for shuffles that are not currently active are ignored.
            """
            # Compute the active IDs once instead of rebuilding the set for every
            # entry of *data*.
            active_ids = self.shuffle_ids()
            for shuffle_id, d in data.items():
                if shuffle_id in active_ids:
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the pickled run spec of active shuffle *id* for *worker*.

            A missing shuffle (``KeyError``) or any other consistency problem is
            reported as an error message instead of being raised.
            """
            try:
                try:
                    reply: RunSpecMessage = {
                        "status": "OK",
                        "run_spec": ToPickle(self._get(id, worker)),
                    }
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
            return reply
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            """Add *worker* to the participants of shuffle *id* and return its run spec.

            Raises ``P2PConsistencyError`` for a worker the scheduler does not
            know, and ``KeyError`` if no shuffle with *id* is active.
            """
            known_workers = self.scheduler.workers
            if worker not in known_workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            shuffle_state = self.active_shuffles[id]
            shuffle_state.participating_workers.add(worker)
            return shuffle_state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
>           assert isinstance(barrier_task_spec, P2PBarrierTask)
E           AssertionError

distributed\shuffle\_scheduler_plugin.py:196: AssertionError

The above exception was the direct cause of the following exception:

c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:64662', workers: 0, cores: 0, tasks: 0>
x = dask.array<ones_like, shape=(100, 10), dtype=float64, chunksize=(5, 10), chunktype=numpy.ndarray>
chunks = {1: 5}
ws = (<Worker 'tcp://127.0.0.1:64663', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:64666', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
dd = <module 'dask.dataframe' from 'C:\\Users\\runneradmin\\miniconda3\\envs\\dask-distributed\\Lib\\site-packages\\dask\\dataframe\\__init__.py'>

    # Each case pairs an input array with a target chunk spec for P2P rechunking.
    @pytest.mark.parametrize(
        "x, chunks",
        [
            (da.ones(shape=(50, 10), chunks=(25, 10)), (None, 5)),
            (da.ones(shape=(50, 10), chunks=(25, 10)), {1: 5}),
            (da.ones(shape=(50, 10), chunks=(25, 10)), (None, (5, 5))),
            (da.ones(shape=(100, 10), chunks=(5, 10)), (None, 5)),
            (da.ones(shape=(100, 10), chunks=(5, 10)), {1: 5}),
            (da.ones(shape=(100, 10), chunks=(5, 10)), (None, (5, 5))),
            (da.ones(shape=(10, 10), chunks=(10, 10)), (None, 5)),
            (da.ones(shape=(10, 10), chunks=(10, 10)), {1: 5}),
            (da.ones(shape=(10, 10), chunks=(10, 10)), (None, (5, 5))),
            (da.ones(shape=(10, 10), chunks=(10, 2)), (None, 5)),
            (da.ones(shape=(10, 10), chunks=(10, 2)), {1: 5}),
            (da.ones(shape=(10, 10), chunks=(10, 2)), (None, (5, 5))),
        ],
    )
    @gen_cluster(client=True)
    async def test_rechunk_with_fully_unknown_dimension(c, s, *ws, x, chunks):
        """
        P2P-rechunk an array round-tripped through ``dd.from_array(x).values``
        and check it matches P2P-rechunking *x* directly.

        See Also
        --------
        dask.array.tests.test_rechunk.test_rechunk_with_fully_unknown_dimension
        """
        pytest.importorskip("pandas")
        dd = pytest.importorskip("dask.dataframe")
        # Presumably y has unknown chunk sizes along axis 0 (per the test name).
        y = dd.from_array(x).values
        result = y.rechunk(chunks, method="p2p")
        expected = x.rechunk(chunks, method="p2p")
    
        assert_chunks_match(result.chunks, expected.chunks)
>       assert_eq(await c.compute(result), await c.compute(expected))

distributed\shuffle\tests\test_rechunk.py:757: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed\client.py:410: in _result
    raise exc.with_traceback(tb)
distributed\shuffle\_rechunk.py:169: in rechunk_transfer
    with handle_transfer_errors(id):
C:\Users\runneradmin\miniconda3\envs\dask-distributed\Lib\contextlib.py:158: in __exit__
    self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    # Distinct str type identifying a shuffle, for static type checking.
    ShuffleId = NewType("ShuffleId", str)
    # Multi-dimensional partition index such as (0, 3); joined with "_" to form
    # disk-buffer keys.
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    # Generic parameters: partition identifier, partition payload, generic result.
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        # OK reply carrying a shuffle run spec, either plain or wrapped in
        # ToPickle for transport (ShuffleSchedulerPlugin.get sends the latter).
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        """Worker-side state of a single run of a P2P shuffle.
    
        Input partitions passed to ``add_partition`` are split by
        ``_shard_partition`` and queued in a ``CommShardsBuffer`` whose ``send``
        callback ships them to peer workers. Shards arriving from peers enter via
        ``receive``/``_receive``; presumably they are stored through the disk (or
        in-memory) buffer and later assembled by ``get_output_partition``.
    
        Subclasses implement the hooks marked with ``@abc.abstractmethod``.
    
        NOTE(review): the only base is ``Generic`` (no ``abc.ABCMeta``), so the
        abstract methods are not enforced when a subclass is instantiated.
        """
    
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
        # False until inputs_done(); add_partition refuses input afterwards.
        transferred: bool
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            """Create the run and its comm and receive buffers.
    
            Parameters
            ----------
            directory
                Passed to ``DiskShardsBuffer``; only used when *disk* is true.
            disk
                If true, received shards go to a ``DiskShardsBuffer``, otherwise
                to a ``MemoryShardsBuffer``.
            loop
                Event loop on which ``add_partition`` and
                ``get_output_partition`` run async buffer operations via ``sync``.
    
            Retry settings used by ``send`` are read from the
            ``distributed.p2p.comm.retry.*`` config keys.
            """
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        # NOTE(review): hashes by run_id while equality stays identity-based
        # (no __eq__ is defined); consistent, but runs with equal run_id collide.
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                # Drop the "p2p-" prefix, e.g. "p2p-shard-partition-cpu" is
                # recorded as "shard-partition-cpu" under the "p2p" namespace.
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            # Offloading is only allowed for the background-* contexts.
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            """Report the barrier to the scheduler and return this run's ID.
    
            The barrier is reported as consistent only if every entry of
            *run_ids* equals ``self.run_id``.
            """
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Call the ``shuffle_receive`` handler of the worker at *address* once."""
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            """Send *shards* to *address*, retrying per the configured retry policy."""
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards
    
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)
    
            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            """Run ``func(*args, **kwargs)`` in ``self.executor``, metered as "offload"."""
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            """Return disk and comm buffer statistics plus the run's start time.
    
            ``total_recvd`` is reported as ``comm["read"]``.
            """
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            """Queue *data* (keyed by destination) in the outgoing comm buffer."""
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            """Write *data* to the receive buffer, keying each index as "i_j_..."."""
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            """Raise the recorded failure, or ``ShuffleClosedError``, if closed."""
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            """Mark the input transfer as finished and flush outgoing shards.
    
            An exception stored in the comm buffer is recorded as this run's
            exception and re-raised.
            """
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            """Close both buffers once; later calls wait for the first close."""
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            """Record *exception* as the failure cause unless already closed."""
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Accept shards from a peer; consistency errors become error messages."""
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    # NOTE(review): pickle.loads trusts the sending peer worker.
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            """Raise ``Reschedule`` after restricting *key* if another worker owns *i*."""
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            """Shard input partition *data* and queue the shards for sending.
    
            Blocks on ``self._loop`` via ``sync`` and returns ``self.run_id``.
            Raises ``RuntimeError`` once ``inputs_done`` has been called.
            """
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and in the "p2p" contexts
            self.validate_data(data)
            with self._capture_metrics("foreground"):
                with (
                    context_meter.meter("p2p-shard-partition-noncpu"),
                    context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
                ):
                    shards = self._shard_partition(data, partition_id)
                sync(self._loop, self._write_to_comm, shards)
            return self.run_id
    
        @abc.abstractmethod
        def _shard_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> dict[str, tuple[_T_partition_id, Any]]:
            """Shard an input partition by the assigned output workers"""
    
        def get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Return output partition *partition_id* computed by task *key*.
    
            May raise ``Reschedule`` if the partition belongs to another worker,
            and ``RuntimeError`` if called before the barrier (``inputs_done``).
            """
            self.raise_if_closed()
            sync(self._loop, self._ensure_output_worker, partition_id, key)
            if not self.transferred:
                raise RuntimeError("`get_output_partition` called before barrier task")
            sync(self._loop, self.flush_receive)
            with (
                # Log metrics both in the "execute" and in the "p2p" contexts
                self._capture_metrics("foreground"),
                context_meter.meter("p2p-get-output-noncpu"),
                context_meter.meter("p2p-get-output-cpu", func=thread_time),
            ):
                return self._get_output_partition(partition_id, key, **kwargs)
    
        @abc.abstractmethod
        def _get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Get an output partition to the shuffle run"""
    
        @abc.abstractmethod
        def read(self, path: Path) -> tuple[Any, int]:
            """Read shards from disk"""
    
        @abc.abstractmethod
        def deserialize(self, buffer: Any) -> Any:
            """Deserialize shards"""
    
        def validate_data(self, data: Any) -> None:
            """Validate payload data before shuffling"""
    
    
    def get_worker_plugin() -> ShuffleWorkerPlugin:
        """Return the P2P shuffle plugin of the worker executing the current task.

        Raises
        ------
        RuntimeError
            If the caller is not running on a worker, or the worker has no
            plugin registered under "shuffle".
        """
        from distributed import get_worker

        try:
            current_worker = get_worker()
        except ValueError as e:
            raise RuntimeError(
                "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
                "please confirm that you've created a distributed Client and are submitting this computation through it."
            ) from e

        registered = current_worker.plugins
        try:
            return registered["shuffle"]  # type: ignore
        except KeyError as e:
            raise RuntimeError(
                f"The worker {current_worker.address} does not have a P2P shuffle plugin."
            ) from e
    
    
    # Key prefix of every shuffle's barrier task (see barrier_key / id_from_key).
    _BARRIER_PREFIX = "shuffle-barrier-"
    
    
    def barrier_key(shuffle_id: ShuffleId) -> str:
        """Return the task key of the barrier task belonging to *shuffle_id*."""
        key = _BARRIER_PREFIX + shuffle_id
        return key
    
    
    def id_from_key(key: Key) -> ShuffleId | None:
        """Extract the shuffle ID from a barrier task *key*.

        Returns None for keys that are not strings or lack the barrier prefix.
        """
        if isinstance(key, str) and key.startswith(_BARRIER_PREFIX):
            return ShuffleId(key.removeprefix(_BARRIER_PREFIX))
        return None
    
    
    class ShuffleType(Enum):
        """Kinds of P2P shuffle operation: dataframe shuffle or array rechunk."""

        DATAFRAME = "DataFrameShuffle"
        ARRAY_RECHUNK = "ArrayRechunk"
    
    
    @dataclass(frozen=True)
    class ShuffleRunSpec(Generic[_T_partition_id]):
        """Immutable description of one run of a shuffle.

        Pairs the shuffle-level ``spec`` with the partition-to-worker mapping
        ``worker_for`` and an optional span ID.
        """

        # Not an __init__ argument: drawn from a counter created once at class
        # definition, so each instance gets the next integer starting at 1.
        run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
        spec: ShuffleSpec
        # Output partition ID -> address of the worker producing it.
        worker_for: dict[_T_partition_id, str]
        span_id: str | None
    
        @property
        def id(self) -> ShuffleId:
            """ID of the shuffle this run belongs to (delegates to ``spec.id``)."""
            return self.spec.id
    
    
    @dataclass(frozen=True)
    class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
        """Abstract, immutable description of a shuffle, independent of any run.

        Subclasses define the output partitions, how partitions are assigned to
        workers, and how a run is created on a worker.
        """

        id: ShuffleId
        # Selects DiskShardsBuffer (True) or MemoryShardsBuffer (False) for runs,
        # presumably forwarded via create_run_on_worker — TODO confirm.
        disk: bool
    
        @property
        @abc.abstractmethod
        def output_partitions(self) -> Generator[_T_partition_id]:
            """Output partitions"""
    
        @abc.abstractmethod
        def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
            """Pick a worker for a partition"""
    
        def create_new_run(
            self,
            worker_for: dict[_T_partition_id, str],
            span_id: str | None,
        ) -> SchedulerShuffleState:
            """Create scheduler-side state for a new run of this shuffle.

            The workers named in *worker_for* become the initial participants.
            """
            return SchedulerShuffleState(
                run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
                participating_workers=set(worker_for.values()),
            )
    
        @abc.abstractmethod
        def create_run_on_worker(
            self,
            run_id: int,
            span_id: str | None,
            worker_for: dict[_T_partition_id, str],
            plugin: ShuffleWorkerPlugin,
        ) -> ShuffleRun:
            """Create the new shuffle run on the worker."""
    
    
    @dataclass(eq=False)
    class SchedulerShuffleState(Generic[_T_partition_id]):
        """Scheduler-side state of one shuffle run.

        ``eq=False`` keeps identity-based equality; ``__hash__`` uses the run ID.
        """

        run_spec: ShuffleRunSpec
        # Workers that take part in this run; ShuffleSchedulerPlugin._get adds
        # workers that request the run spec.
        participating_workers: set[str]
        # Identifier of what archived this run; None while the run is not archived.
        _archived_by: str | None = field(default=None, init=False)
        _failed: bool = False
    
        @property
        def id(self) -> ShuffleId:
            return self.run_spec.id
    
        @property
        def run_id(self) -> int:
            return self.run_spec.run_id
    
        @property
        def archived(self) -> bool:
            return self._archived_by is not None
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
    
        def __hash__(self) -> int:
            return hash(self.run_id)
    
    
    @contextlib.contextmanager
    def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
        """Translate exceptions raised in the transfer phase of shuffle *id*.

        - ``ShuffleClosedError`` becomes ``Reschedule``.
        - ``P2PConsistencyError`` and ``P2POutOfDiskError`` propagate unchanged.
        - Any other ``Exception`` becomes a ``RuntimeError`` chained to it.
        """
        try:
            yield
        except ShuffleClosedError:
            # The local run was closed; surface it as Reschedule instead.
            raise Reschedule()
        except P2PConsistencyError:
            raise
        except P2POutOfDiskError:
            raise
        except Exception as e:
>           raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E           RuntimeError: P2P 103aea8fc8e7486296d569fc1f0805a8 failed during transfer phase

distributed\shuffle\_core.py:531: RuntimeError

Check warning on line 0 in distributed.shuffle.tests.test_rechunk

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

10 out of 11 runs failed: test_rechunk_with_fully_unknown_dimension[x5-chunks5] (distributed.shuffle.tests.test_rechunk)

artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P ba094fb6021231c810a48e86cbce8897 failed during transfer phase
from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
            assert isinstance(barrier_task_spec, P2PBarrierTask)
            return barrier_task_spec.spec
    
        def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
            # FIXME: The current implementation relies on the barrier task to be
            # known by its name. If the name has been mangled, we cannot guarantee
            # that the shuffle works as intended and should fail instead.
            self._raise_if_barrier_unknown(shuffle_id)
            self._raise_if_task_not_processing(key)
            spec = self._retrieve_spec(shuffle_id)
            worker_for = self._calculate_worker_for(spec)
            self._ensure_output_tasks_are_non_rootish(spec)
            state = spec.create_new_run(
                worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
            )
            self.active_shuffles[shuffle_id] = state
            self._shuffles[shuffle_id].add(state)
            state.participating_workers.add(worker)
            logger.warning(
                "Shuffle %s initialized by task %r executed on worker %s",
                shuffle_id,
                key,
                worker,
            )
            return state.run_spec
    
        def get_or_create(
            self,
            shuffle_id: ShuffleId,
            key: Key,
            worker: str,
        ) -> RunSpecMessage | ErrorMessage:
            try:
>               run_spec = self._get(shuffle_id, worker)

distributed\shuffle\_scheduler_plugin.py:229: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
>           state = self.active_shuffles[id]
E           KeyError: 'ba094fb6021231c810a48e86cbce8897'

distributed\shuffle\_scheduler_plugin.py:190: KeyError

During handling of the above exception, another exception occurred:

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards
    
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)
    
            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and …fresh
    result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed\shuffle\_worker_plugin.py:190: in _fetch
    response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed\core.py:1259: in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
distributed\core.py:1043: in send_recv
    raise exc.with_traceback(tb)
distributed\core.py:832: in _handle_comm
    result = handler(**msg)
distributed\shuffle\_scheduler_plugin.py:234: in get_or_create
    run_spec = self._create(shuffle_id, key, worker)
distributed\shuffle\_scheduler_plugin.py:205: in _create
    spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
>           assert isinstance(barrier_task_spec, P2PBarrierTask)
E           AssertionError

distributed\shuffle\_scheduler_plugin.py:196: AssertionError

The above exception was the direct cause of the following exception:

c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:64677', workers: 0, cores: 0, tasks: 0>
x = dask.array<ones_like, shape=(100, 10), dtype=float64, chunksize=(5, 10), chunktype=numpy.ndarray>
chunks = (None, (5, 5))
ws = (<Worker 'tcp://127.0.0.1:64678', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:64681', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
dd = <module 'dask.dataframe' from 'C:\\Users\\runneradmin\\miniconda3\\envs\\dask-distributed\\Lib\\site-packages\\dask\\dataframe\\__init__.py'>

    @pytest.mark.parametrize(
        "x, chunks",
        [
            (da.ones(shape=(50, 10), chunks=(25, 10)), (None, 5)),
            (da.ones(shape=(50, 10), chunks=(25, 10)), {1: 5}),
            (da.ones(shape=(50, 10), chunks=(25, 10)), (None, (5, 5))),
            (da.ones(shape=(100, 10), chunks=(5, 10)), (None, 5)),
            (da.ones(shape=(100, 10), chunks=(5, 10)), {1: 5}),
            (da.ones(shape=(100, 10), chunks=(5, 10)), (None, (5, 5))),
            (da.ones(shape=(10, 10), chunks=(10, 10)), (None, 5)),
            (da.ones(shape=(10, 10), chunks=(10, 10)), {1: 5}),
            (da.ones(shape=(10, 10), chunks=(10, 10)), (None, (5, 5))),
            (da.ones(shape=(10, 10), chunks=(10, 2)), (None, 5)),
            (da.ones(shape=(10, 10), chunks=(10, 2)), {1: 5}),
            (da.ones(shape=(10, 10), chunks=(10, 2)), (None, (5, 5))),
        ],
    )
    @gen_cluster(client=True)
    async def test_rechunk_with_fully_unknown_dimension(c, s, *ws, x, chunks):
        """
        See Also
        --------
        dask.array.tests.test_rechunk.test_rechunk_with_fully_unknown_dimension
        """
        pytest.importorskip("pandas")
        dd = pytest.importorskip("dask.dataframe")
        y = dd.from_array(x).values
        result = y.rechunk(chunks, method="p2p")
        expected = x.rechunk(chunks, method="p2p")
    
        assert_chunks_match(result.chunks, expected.chunks)
>       assert_eq(await c.compute(result), await c.compute(expected))

distributed\shuffle\tests\test_rechunk.py:757: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed\client.py:410: in _result
    raise exc.with_traceback(tb)
distributed\shuffle\_rechunk.py:169: in rechunk_transfer
    with handle_transfer_errors(id):
C:\Users\runneradmin\miniconda3\envs\dask-distributed\Lib\contextlib.py:158: in __exit__
    self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards
    
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)
    
            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and in the "p2p" contexts
            self.validate_data(data)
            with self._capture_metrics("foreground"):
                with (
                    context_meter.meter("p2p-shard-partition-noncpu"),
                    context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
                ):
                    shards = self._shard_partition(data, partition_id)
                sync(self._loop, self._write_to_comm, shards)
            return self.run_id
    
        @abc.abstractmethod
        def _shard_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> dict[str, tuple[_T_partition_id, Any]]:
            """Shard an input partition by the assigned output workers"""
    
        def get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            self.raise_if_closed()
            sync(self._loop, self._ensure_output_worker, partition_id, key)
            if not self.transferred:
                raise RuntimeError("`get_output_partition` called before barrier task")
            sync(self._loop, self.flush_receive)
            with (
                # Log metrics both in the "execute" and in the "p2p" contexts
                self._capture_metrics("foreground"),
                context_meter.meter("p2p-get-output-noncpu"),
                context_meter.meter("p2p-get-output-cpu", func=thread_time),
            ):
                return self._get_output_partition(partition_id, key, **kwargs)
    
        @abc.abstractmethod
        def _get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Get an output partition to the shuffle run"""
    
        @abc.abstractmethod
        def read(self, path: Path) -> tuple[Any, int]:
            """Read shards from disk"""
    
        @abc.abstractmethod
        def deserialize(self, buffer: Any) -> Any:
            """Deserialize shards"""
    
        def validate_data(self, data: Any) -> None:
            """Validate payload data before shuffling"""
    
    
    def get_worker_plugin() -> ShuffleWorkerPlugin:
        from distributed import get_worker
    
        try:
            worker = get_worker()
        except ValueError as e:
            raise RuntimeError(
                "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
                "please confirm that you've created a distributed Client and are submitting this computation through it."
            ) from e
        try:
            return worker.plugins["shuffle"]  # type: ignore
        except KeyError as e:
            raise RuntimeError(
                f"The worker {worker.address} does not have a P2P shuffle plugin."
            ) from e
    
    
    _BARRIER_PREFIX = "shuffle-barrier-"
    
    
    def barrier_key(shuffle_id: ShuffleId) -> str:
        return _BARRIER_PREFIX + shuffle_id
    
    
    def id_from_key(key: Key) -> ShuffleId | None:
        if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
            return None
        return ShuffleId(key[len(_BARRIER_PREFIX) :])
    
    
    class ShuffleType(Enum):
        DATAFRAME = "DataFrameShuffle"
        ARRAY_RECHUNK = "ArrayRechunk"
    
    
    @dataclass(frozen=True)
    class ShuffleRunSpec(Generic[_T_partition_id]):
        run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
        spec: ShuffleSpec
        worker_for: dict[_T_partition_id, str]
        span_id: str | None
    
        @property
        def id(self) -> ShuffleId:
            return self.spec.id
    
    
    @dataclass(frozen=True)
    class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
        id: ShuffleId
        disk: bool
    
        @property
        @abc.abstractmethod
        def output_partitions(self) -> Generator[_T_partition_id]:
            """Output partitions"""
    
        @abc.abstractmethod
        def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
            """Pick a worker for a partition"""
    
        def create_new_run(
            self,
            worker_for: dict[_T_partition_id, str],
            span_id: str | None,
        ) -> SchedulerShuffleState:
            return SchedulerShuffleState(
                run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
                participating_workers=set(worker_for.values()),
            )
    
        @abc.abstractmethod
        def create_run_on_worker(
            self,
            run_id: int,
            span_id: str | None,
            worker_for: dict[_T_partition_id, str],
            plugin: ShuffleWorkerPlugin,
        ) -> ShuffleRun:
            """Create the new shuffle run on the worker."""
    
    
    @dataclass(eq=False)
    class SchedulerShuffleState(Generic[_T_partition_id]):
        run_spec: ShuffleRunSpec
        participating_workers: set[str]
        _archived_by: str | None = field(default=None, init=False)
        _failed: bool = False
    
        @property
        def id(self) -> ShuffleId:
            return self.run_spec.id
    
        @property
        def run_id(self) -> int:
            return self.run_spec.run_id
    
        @property
        def archived(self) -> bool:
            return self._archived_by is not None
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
    
        def __hash__(self) -> int:
            return hash(self.run_id)
    
    
    @contextlib.contextmanager
    def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
        try:
            yield
        except ShuffleClosedError:
            raise Reschedule()
        except P2PConsistencyError:
            raise
        except P2POutOfDiskError:
            raise
        except Exception as e:
>           raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E           RuntimeError: P2P ba094fb6021231c810a48e86cbce8897 failed during transfer phase

distributed\shuffle\_core.py:531: RuntimeError

Check warning on line 0 in distributed.shuffle.tests.test_rechunk

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

10 out of 11 runs failed: test_rechunk_with_fully_unknown_dimension[x6-chunks6] (distributed.shuffle.tests.test_rechunk)

artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P b5a56a723b010b3ef762d5c14fe583f9 failed during transfer phase
from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
            assert isinstance(barrier_task_spec, P2PBarrierTask)
            return barrier_task_spec.spec
    
        def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
            # FIXME: The current implementation relies on the barrier task to be
            # known by its name. If the name has been mangled, we cannot guarantee
            # that the shuffle works as intended and should fail instead.
            self._raise_if_barrier_unknown(shuffle_id)
            self._raise_if_task_not_processing(key)
            spec = self._retrieve_spec(shuffle_id)
            worker_for = self._calculate_worker_for(spec)
            self._ensure_output_tasks_are_non_rootish(spec)
            state = spec.create_new_run(
                worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
            )
            self.active_shuffles[shuffle_id] = state
            self._shuffles[shuffle_id].add(state)
            state.participating_workers.add(worker)
            logger.warning(
                "Shuffle %s initialized by task %r executed on worker %s",
                shuffle_id,
                key,
                worker,
            )
            return state.run_spec
    
        def get_or_create(
            self,
            shuffle_id: ShuffleId,
            key: Key,
            worker: str,
        ) -> RunSpecMessage | ErrorMessage:
            try:
>               run_spec = self._get(shuffle_id, worker)

distributed\shuffle\_scheduler_plugin.py:229: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
>           state = self.active_shuffles[id]
E           KeyError: 'b5a56a723b010b3ef762d5c14fe583f9'

distributed\shuffle\_scheduler_plugin.py:190: KeyError

During handling of the above exception, another exception occurred:

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards
    
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)
    
            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and …n _refresh
    result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed\shuffle\_worker_plugin.py:190: in _fetch
    response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed\core.py:1259: in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
distributed\core.py:1043: in send_recv
    raise exc.with_traceback(tb)
distributed\core.py:832: in _handle_comm
    result = handler(**msg)
distributed\shuffle\_scheduler_plugin.py:234: in get_or_create
    run_spec = self._create(shuffle_id, key, worker)
distributed\shuffle\_scheduler_plugin.py:205: in _create
    spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
>           assert isinstance(barrier_task_spec, P2PBarrierTask)
E           AssertionError

distributed\shuffle\_scheduler_plugin.py:196: AssertionError

The above exception was the direct cause of the following exception:

c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:64697', workers: 0, cores: 0, tasks: 0>
x = dask.array<ones_like, shape=(10, 10), dtype=float64, chunksize=(10, 10), chunktype=numpy.ndarray>
chunks = (None, 5)
ws = (<Worker 'tcp://127.0.0.1:64698', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:64701', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
dd = <module 'dask.dataframe' from 'C:\\Users\\runneradmin\\miniconda3\\envs\\dask-distributed\\Lib\\site-packages\\dask\\dataframe\\__init__.py'>

    @pytest.mark.parametrize(
        "x, chunks",
        [
            (da.ones(shape=(50, 10), chunks=(25, 10)), (None, 5)),
            (da.ones(shape=(50, 10), chunks=(25, 10)), {1: 5}),
            (da.ones(shape=(50, 10), chunks=(25, 10)), (None, (5, 5))),
            (da.ones(shape=(100, 10), chunks=(5, 10)), (None, 5)),
            (da.ones(shape=(100, 10), chunks=(5, 10)), {1: 5}),
            (da.ones(shape=(100, 10), chunks=(5, 10)), (None, (5, 5))),
            (da.ones(shape=(10, 10), chunks=(10, 10)), (None, 5)),
            (da.ones(shape=(10, 10), chunks=(10, 10)), {1: 5}),
            (da.ones(shape=(10, 10), chunks=(10, 10)), (None, (5, 5))),
            (da.ones(shape=(10, 10), chunks=(10, 2)), (None, 5)),
            (da.ones(shape=(10, 10), chunks=(10, 2)), {1: 5}),
            (da.ones(shape=(10, 10), chunks=(10, 2)), (None, (5, 5))),
        ],
    )
    @gen_cluster(client=True)
    async def test_rechunk_with_fully_unknown_dimension(c, s, *ws, x, chunks):
        """
        See Also
        --------
        dask.array.tests.test_rechunk.test_rechunk_with_fully_unknown_dimension
        """
        pytest.importorskip("pandas")
        dd = pytest.importorskip("dask.dataframe")
        y = dd.from_array(x).values
        result = y.rechunk(chunks, method="p2p")
        expected = x.rechunk(chunks, method="p2p")
    
        assert_chunks_match(result.chunks, expected.chunks)
>       assert_eq(await c.compute(result), await c.compute(expected))

distributed\shuffle\tests\test_rechunk.py:757: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed\client.py:410: in _result
    raise exc.with_traceback(tb)
distributed\shuffle\_rechunk.py:169: in rechunk_transfer
    with handle_transfer_errors(id):
C:\Users\runneradmin\miniconda3\envs\dask-distributed\Lib\contextlib.py:158: in __exit__
    self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards
    
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)
    
            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and in the "p2p" contexts
            self.validate_data(data)
            with self._capture_metrics("foreground"):
                with (
                    context_meter.meter("p2p-shard-partition-noncpu"),
                    context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
                ):
                    shards = self._shard_partition(data, partition_id)
                sync(self._loop, self._write_to_comm, shards)
            return self.run_id
    
        @abc.abstractmethod
        def _shard_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> dict[str, tuple[_T_partition_id, Any]]:
            """Shard an input partition by the assigned output workers"""
    
        def get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            self.raise_if_closed()
            sync(self._loop, self._ensure_output_worker, partition_id, key)
            if not self.transferred:
                raise RuntimeError("`get_output_partition` called before barrier task")
            sync(self._loop, self.flush_receive)
            with (
                # Log metrics both in the "execute" and in the "p2p" contexts
                self._capture_metrics("foreground"),
                context_meter.meter("p2p-get-output-noncpu"),
                context_meter.meter("p2p-get-output-cpu", func=thread_time),
            ):
                return self._get_output_partition(partition_id, key, **kwargs)
    
        @abc.abstractmethod
        def _get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Get an output partition to the shuffle run"""
    
        @abc.abstractmethod
        def read(self, path: Path) -> tuple[Any, int]:
            """Read shards from disk"""
    
        @abc.abstractmethod
        def deserialize(self, buffer: Any) -> Any:
            """Deserialize shards"""
    
        def validate_data(self, data: Any) -> None:
            """Validate payload data before shuffling"""
    
    
    def get_worker_plugin() -> ShuffleWorkerPlugin:
        from distributed import get_worker
    
        try:
            worker = get_worker()
        except ValueError as e:
            raise RuntimeError(
                "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
                "please confirm that you've created a distributed Client and are submitting this computation through it."
            ) from e
        try:
            return worker.plugins["shuffle"]  # type: ignore
        except KeyError as e:
            raise RuntimeError(
                f"The worker {worker.address} does not have a P2P shuffle plugin."
            ) from e
    
    
    _BARRIER_PREFIX = "shuffle-barrier-"
    
    
    def barrier_key(shuffle_id: ShuffleId) -> str:
        return _BARRIER_PREFIX + shuffle_id
    
    
    def id_from_key(key: Key) -> ShuffleId | None:
        if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
            return None
        return ShuffleId(key[len(_BARRIER_PREFIX) :])
    
    
    class ShuffleType(Enum):
        DATAFRAME = "DataFrameShuffle"
        ARRAY_RECHUNK = "ArrayRechunk"
    
    
    @dataclass(frozen=True)
    class ShuffleRunSpec(Generic[_T_partition_id]):
        run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
        spec: ShuffleSpec
        worker_for: dict[_T_partition_id, str]
        span_id: str | None
    
        @property
        def id(self) -> ShuffleId:
            return self.spec.id
    
    
    @dataclass(frozen=True)
    class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
        id: ShuffleId
        disk: bool
    
        @property
        @abc.abstractmethod
        def output_partitions(self) -> Generator[_T_partition_id]:
            """Output partitions"""
    
        @abc.abstractmethod
        def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
            """Pick a worker for a partition"""
    
        def create_new_run(
            self,
            worker_for: dict[_T_partition_id, str],
            span_id: str | None,
        ) -> SchedulerShuffleState:
            return SchedulerShuffleState(
                run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
                participating_workers=set(worker_for.values()),
            )
    
        @abc.abstractmethod
        def create_run_on_worker(
            self,
            run_id: int,
            span_id: str | None,
            worker_for: dict[_T_partition_id, str],
            plugin: ShuffleWorkerPlugin,
        ) -> ShuffleRun:
            """Create the new shuffle run on the worker."""
    
    
    @dataclass(eq=False)
    class SchedulerShuffleState(Generic[_T_partition_id]):
        run_spec: ShuffleRunSpec
        participating_workers: set[str]
        _archived_by: str | None = field(default=None, init=False)
        _failed: bool = False
    
        @property
        def id(self) -> ShuffleId:
            return self.run_spec.id
    
        @property
        def run_id(self) -> int:
            return self.run_spec.run_id
    
        @property
        def archived(self) -> bool:
            return self._archived_by is not None
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
    
        def __hash__(self) -> int:
            return hash(self.run_id)
    
    
    @contextlib.contextmanager
    def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
        try:
            yield
        except ShuffleClosedError:
            raise Reschedule()
        except P2PConsistencyError:
            raise
        except P2POutOfDiskError:
            raise
        except Exception as e:
>           raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E           RuntimeError: P2P b5a56a723b010b3ef762d5c14fe583f9 failed during transfer phase

distributed\shuffle\_core.py:531: RuntimeError