Remove recursion in task spec (#8920) #2504
46 fail, 110 skipped, 3 974 pass in 10h 12m 37s
25 files 25 suites 10h 12m 37s ⏱️
4 130 tests 3 974 ✅ 110 💤 46 ❌
47 692 runs 45 130 ✅ 2 121 💤 441 ❌
Results for commit 750cb91.
Annotations
Check warning on line 0 in distributed.tests.test_client
github-actions / Unit Test Results
11 out of 12 runs failed: test_persist_async (distributed.tests.test_client)
artifacts/macos-latest-3.12-default-ci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-ci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-ci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-ci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-ci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-ci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-ci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-ci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-ci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-ci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-ci1/pytest.xml [took 0s]
Raw output
IndexError: tuple index out of range
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:39653', workers: 0, cores: 0, tasks: 0>
a = <Worker 'tcp://127.0.0.1:37647', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>
b = <Worker 'tcp://127.0.0.1:44433', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>
@gen_cluster(client=True)
async def test_persist_async(c, s, a, b):
pytest.importorskip("numpy")
da = pytest.importorskip("dask.array")
x = da.ones((10, 10), chunks=(5, 10))
y = 2 * (x + 1)
assert len(y.dask) == 6
yy = c.persist(y)
assert len(y.dask) == 6
assert len(yy.dask) == 2
assert all(isinstance(v, Future) for v in yy.dask.values())
assert yy.__dask_keys__() == y.__dask_keys__()
g, h = c.compute([y, yy])
> gg, hh = await c.gather([g, h])
distributed/tests/test_client.py:2565:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:2427: in _gather
raise exception.with_traceback(traceback)
../../../miniconda3/envs/dask-distributed/lib/python3.10/site-packages/dask/array/core.py:1328: in finalize
return concatenate3(results)
../../../miniconda3/envs/dask-distributed/lib/python3.10/site-packages/dask/array/core.py:5496: in concatenate3
chunks = chunks_from_arrays(arrays)
../../../miniconda3/envs/dask-distributed/lib/python3.10/site-packages/dask/array/core.py:5281: in chunks_from_arrays
result.append(tuple(shape(deepfirst(a))[dim] for a in arrays))
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import contextlib
import math
import operator
import os
import pickle
import re
import sys
import traceback
import uuid
import warnings
from bisect import bisect
from collections import defaultdict
from collections.abc import Collection, Iterable, Iterator, Mapping, Sequence
from functools import lru_cache, partial, reduce, wraps
from itertools import product, zip_longest
from numbers import Integral, Number
from operator import add, mul
from threading import Lock
from typing import Any, Literal, TypeVar, Union, cast
import numpy as np
from numpy.typing import ArrayLike
from packaging.version import Version
from tlz import accumulate, concat, first, groupby, partition
from tlz.curried import pluck
from toolz import frequencies
from dask import compute, config, core
from dask.array import chunk
from dask.array.chunk import getitem
from dask.array.chunk_types import is_valid_array_chunk, is_valid_chunk_type
# Keep einsum_lookup and tensordot_lookup here for backwards compatibility
from dask.array.dispatch import ( # noqa: F401
concatenate_lookup,
einsum_lookup,
tensordot_lookup,
)
from dask.array.numpy_compat import NUMPY_GE_200, _Recurser
from dask.array.slicing import replace_ellipsis, setitem_array, slice_array
from dask.array.utils import compute_meta, meta_from_array
from dask.base import (
DaskMethodsMixin,
compute_as_if_collection,
dont_optimize,
is_dask_collection,
named_schedulers,
persist,
tokenize,
)
from dask.blockwise import blockwise as core_blockwise
from dask.blockwise import broadcast_dimensions
from dask.context import globalmethod
from dask.core import quote
from dask.delayed import Delayed, delayed
from dask.highlevelgraph import HighLevelGraph, MaterializedLayer
from dask.layers import ArraySliceDep, reshapelist
from dask.sizeof import sizeof
from dask.typing import Graph, Key, NestedKeys
from dask.utils import (
IndexCallable,
SerializableLock,
cached_cumsum,
cached_property,
concrete,
derived_from,
format_bytes,
funcname,
has_keyword,
is_arraylike,
is_dataframe_like,
is_index_like,
is_integer,
is_series_like,
maybe_pluralize,
ndeepmap,
ndimlist,
parse_bytes,
typename,
)
from dask.widgets import get_template
# Alias accepting either an int or a float; per the trailing note the float
# case is only meant to stand for NaN, which the type system cannot express.
T_IntOrNaN = Union[int, float] # Should be Union[int, Literal[np.nan]]
# Scheduler entry used as the module default: the "threads" entry when it is
# registered, otherwise the "sync" entry. Note that ``named_schedulers["sync"]``
# is looked up eagerly, so that key must always exist.
DEFAULT_GET = named_schedulers.get("threads", named_schedulers["sync"])
# Help text directing users to the unknown-chunks documentation and to the two
# calls that compute chunk sizes.
# NOTE(review): presumably appended to error messages raised for arrays with
# unknown chunk sizes -- confirm at the use sites.
unknown_chunk_message = (
"\n\n"
"A possible solution: "
"https://docs.dask.org/en/latest/array-chunks.html#unknown-chunks\n"
"Summary: to compute chunks sizes, use\n\n"
" x.compute_chunk_sizes() # for Dask Array `x`\n"
" ddf.to_dask_array(lengths=True) # for Dask DataFrame `ddf`"
)
class PerformanceWarning(Warning):
    """Warning category for chunking choices that are likely to hurt performance."""
def getter(a, b, asarray=True, lock=None):
    """Index ``a`` with ``b``, optionally under a lock and with array coercion.

    Parameters
    ----------
    a : indexable
        Object supporting ``a[b]``.
    b : index
        Index expression. When it is a tuple containing ``None`` entries,
        the ``None`` entries are removed, the remaining index is applied
        first, and the new axes are then inserted by a second indexing step.
    asarray : bool, optional
        If true, a result that is not array-like (or that is an
        ``np.matrix``) is converted with ``np.asarray``.
    lock : lock-like or falsy, optional
        If truthy, ``lock.acquire()`` is called before indexing and
        ``lock.release()`` afterwards, even when indexing raises.

    Returns
    -------
    The indexed (and possibly converted) value.
    """
    if isinstance(b, tuple) and any(item is None for item in b):
        # First pass: the index without any ``None`` entries.
        plain_index = tuple(item for item in b if item is not None)
        # Second pass: one entry per non-integer item of ``b`` -- ``None``
        # where a new axis was requested, a full slice elsewhere. Integer
        # entries are left out (presumably because they drop their axis in
        # the first pass).
        newaxis_index = tuple(
            None if item is None else slice(None, None)
            for item in b
            if not isinstance(item, Integral)
        )
        base = getter(a, plain_index, asarray=asarray, lock=lock)
        return base[newaxis_index]

    if lock:
        lock.acquire()
    try:
        result = a[b]
        # ``np.matrix`` passes the array-like check, so it is special-cased
        # here to keep forcing a conversion to a plain ``np.ndarray``.
        needs_conversion = asarray and (
            not is_arraylike(result) or isinstance(result, np.matrix)
        )
        if needs_conversion:
            result = np.asarray(result)
    finally:
        if lock:
            lock.release()
    return result
def getter_nofancy(a, b, asarray=True, lock=None):
    """Index ``a`` with ``b`` exactly as ``getter`` does.

    This exists as a separate function so that optimization passes can tell
    that the backend does not support fancy indexing.
    """
    options = {"asarray": asarray, "lock": lock}
    return getter(a, b, **options)
def getter_inline(a, b, asarray=True, lock=None):
    """Index ``a`` with ``b`` like ``getter``, marked as safe to inline.

    Optimizations may inline slicing tasks that use this function into the
    tasks consuming them, for example rewriting

    **Before**

    >>> a = x[:10]  # doctest: +SKIP
    >>> b = a + 1  # doctest: +SKIP
    >>> c = a * 2  # doctest: +SKIP

    **After**

    >>> b = x[:10] + 1  # doctest: +SKIP
    >>> c = x[:10] * 2  # doctest: +SKIP

    Such inlining can matter for operations that read their data from disk.
    """
    options = {"asarray": asarray, "lock": lock}
    return getter(a, b, **options)
from dask.array.optimization import fuse_slice, optimize
# Registry used for ``__array_function__`` dispatch: maps a NumPy function
# (including aliases and functions whose names differ) to the dask.array
# function implementing it.
_HANDLED_FUNCTIONS = {}


def implements(*numpy_functions):
    """Return a decorator registering a dask function for NumPy functions.

    Every function in ``numpy_functions`` is recorded in
    ``_HANDLED_FUNCTIONS`` as being implemented by the decorated function,
    so several aliases can share one implementation.

    Parameters
    ----------
    ``*numpy_functions`` : callables
        One or more NumPy functions that are handled by
        ``__array_function__`` and should map to the decorated
        ``dask.array`` function.

    Returns
    -------
    callable
        A decorator that registers its argument and returns it unchanged.
    """

    def decorator(dask_func):
        _HANDLED_FUNCTIONS.update(dict.fromkeys(numpy_functions, dask_func))
        return dask_func

    return decorator
def _should_delegate(self, other) -> bool:
    """Decide whether an operation on ``self`` should be left to ``other``.

    The rules follow NEP-13:
    https://numpy.org/neps/nep-0013-ufunc-overrides.html#behavior-in-combination-with-python-s-binary-operations

    Returns
    -------
    bool
        ``True`` when ``other`` opts out of ufuncs (``__array_ufunc__`` is
        ``None``), or when it defines ``__array_ufunc__``, is not a valid
        array chunk, and ``self`` is not an instance of its type.
        ``False`` otherwise, including when ``other`` has no
        ``__array_ufunc__`` attribute at all.
    """
    if not hasattr(other, "__array_ufunc__"):
        return False
    if other.__array_ufunc__ is None:
        return True
    other_type = type(other)
    return (
        not is_valid_array_chunk(other)
        # never delegate to one of our own parent classes
        and not isinstance(self, other_type)
        and type(self) is not other_type
    )
def check_if_handled_given_other(f):
    """Wrap a binary dunder method so it defers when ``other`` should handle it.

    The wrapper returns ``NotImplemented`` whenever ``_should_delegate``
    says the operation belongs to ``other``; otherwise it calls ``f``. This
    lets upcast types take over without assuming that every unknown type is
    a downcast type.
    """

    @wraps(f)
    def wrapper(self, other):
        return NotImplemented if _should_delegate(self, other) else f(self, other)

    return wrapper
def slices_from_chunks(chunks):
    """Return the slices selecting every block, in product order.

    ``chunks`` holds one tuple of block sizes per dimension. The result is a
    list with one tuple of slices per block, with the last dimension varying
    fastest.

    >>> slices_from_chunks(((2, 2), (3, 3, 3)))  # doctest: +NORMALIZE_WHITESPACE
    [(slice(0, 2, None), slice(0, 3, None)),
     (slice(0, 2, None), slice(3, 6, None)),
     (slice(0, 2, None), slice(6, 9, None)),
     (slice(2, 4, None), slice(0, 3, None)),
     (slice(2, 4, None), slice(3, 6, None)),
     (slice(2, 4, None), slice(6, 9, None))]
    """
    per_axis = []
    for sizes in chunks:
        # Running totals starting at zero give the start offset of each block.
        offsets = cached_cumsum(sizes, initial_zero=True)
        per_axis.append(
            [slice(start, start + size) for start, size in zip(offsets, sizes)]
        )
    return list(product(*per_axis))
def graph_from_arraylike(
    arr,  # Any array-like which supports slicing
    chunks,
    shape,
    name,
    getitem=getter,
    lock=False,
    asarray=True,
    dtype=None,
    inline_array=False,
) -> HighLevelGraph:
    """
    Build a HighLevelGraph that slices an array-like into chunks.

    With ``inline_array=True`` the graph is a single Blockwise layer of
    slicing tasks and the array-like itself is passed to every task.

    With ``inline_array=False`` the array-like is stored once, under the key
    ``"original-" + name``, in a MaterializedLayer; the Blockwise layer of
    slicing tasks then refers to that key.

    Parameters
    ----------
    arr : array-like
        Object supporting slicing.
    chunks, shape, dtype
        Passed to ``normalize_chunks`` to obtain the chunk pattern.
    name : str
        Name of the slicing layer.
    getitem : callable, optional
        Function performing the slicing.
    lock, asarray : optional
        Forwarded to ``getitem`` as keyword arguments only when ``getitem``
        accepts both keywords and a non-default value was requested
        (``asarray`` falsy or ``lock`` truthy).
    inline_array : bool, optional
        Selects between the two graph layouts described above.

    >>> dict(graph_from_arraylike(arr, chunks=(2, 3), shape=(4, 6), name="X", inline_array=True))  # doctest: +SKIP
    {('X', 0, 0): (getter, arr, (slice(0, 2), slice(0, 3))),
     ('X', 1, 0): (getter, arr, (slice(2, 4), slice(0, 3))),
     ('X', 1, 1): (getter, arr, (slice(2, 4), slice(3, 6))),
     ('X', 0, 1): (getter, arr, (slice(0, 2), slice(3, 6)))}

    >>> dict(  # doctest: +SKIP
    ...     graph_from_arraylike(arr, chunks=((2, 2), (3, 3)), shape=(4, 6), name="X", inline_array=False)
    ... )
    {"original-X": arr,
     ('X', 0, 0): (getter, 'original-X', (slice(0, 2), slice(0, 3))),
     ('X', 1, 0): (getter, 'original-X', (slice(2, 4), slice(0, 3))),
     ('X', 1, 1): (getter, 'original-X', (slice(2, 4), slice(3, 6))),
     ('X', 0, 1): (getter, 'original-X', (slice(0, 2), slice(3, 6)))}
    """
    chunks = normalize_chunks(chunks, shape, dtype=dtype)
    out_ind = tuple(range(len(shape)))

    # In the common case the extra parameters are dropped entirely.
    kwargs = {}
    if (
        has_keyword(getitem, "asarray")
        and has_keyword(getitem, "lock")
        and (not asarray or lock)
    ):
        kwargs = {"asarray": asarray, "lock": lock}

    original_name = "original-" + name
    # The slicing tasks either embed the array or refer to it by key.
    source = arr if inline_array else original_name
    slicing_layer = core_blockwise(
        getitem,
        name,
        out_ind,
        source,
        None,
        ArraySliceDep(chunks),
        out_ind,
        numblocks={},
        **kwargs,
    )

    if inline_array:
        return HighLevelGraph.from_collections(name, slicing_layer)

    layers = {
        original_name: MaterializedLayer({original_name: arr}),
        name: slicing_layer,
    }
    deps = {original_name: set(), name: {original_name}}
    return HighLevelGraph(layers, deps)
def dotmany(A, B, leftfunc=None, rightfunc=None, **kwargs):
    """Sum the pairwise dot products of two aligned sequences of chunks.

    ``A`` and ``B`` are walked in lockstep, ``np.dot`` is applied to each
    pair (with ``**kwargs`` forwarded) and the products are added up.

    >>> x = np.array([[1, 2], [1, 2]])
    >>> y = np.array([[10, 20], [10, 20]])
    >>> dotmany([x, x, x], [y, y, y])
    array([[ 90, 180],
           [ 90, 180]])

    ``leftfunc`` and ``rightfunc``, when given, are applied to each chunk of
    ``A`` and ``B`` respectively before the product is taken.

    >>> dotmany([x, x, x], [y, y, y], rightfunc=np.transpose)
    array([[150, 150],
           [150, 150]])
    """
    lefts = map(leftfunc, A) if leftfunc else A
    rights = map(rightfunc, B) if rightfunc else B
    dot = partial(np.dot, **kwargs)
    return sum(dot(left, right) for left, right in zip(lefts, rights))
def _concatenate2(arrays, axes=None):
    """Recursively concatenate nested lists of arrays along axes

    Each entry in axes corresponds to each level of the nested list. The
    length of axes should correspond to the level of nesting of arrays.
    If axes is empty (``None``, an empty list or an empty tuple), return
    arrays, or arrays[0] if arrays is a list.

    Parameters
    ----------
    arrays : nested list/tuple/iterator of arrays (or of dicts of arrays)
        Anything that is not a list, tuple or iterator is returned as is.
    axes : sequence of int, optional
        One axis per nesting level, outermost level first.

    >>> x = np.array([[1, 2], [3, 4]])
    >>> _concatenate2([x, x], axes=[0])
    array([[1, 2],
           [3, 4],
           [1, 2],
           [3, 4]])

    >>> _concatenate2([x, x], axes=[1])
    array([[1, 2, 1, 2],
           [3, 4, 3, 4]])

    >>> _concatenate2([[x, x], [x, x]], axes=[0, 1])
    array([[1, 2, 1, 2],
           [3, 4, 3, 4],
           [1, 2, 1, 2],
           [3, 4, 3, 4]])

    Supports Iterators

    >>> _concatenate2(iter([x, x]), axes=[1])
    array([[1, 2, 1, 2],
           [3, 4, 3, 4]])

    Special Case

    >>> _concatenate2([x, x], axes=())
    array([[1, 2],
           [3, 4]])
    >>> _concatenate2([x, x], axes=[])
    array([[1, 2],
           [3, 4]])
    """
    if axes is None:
        axes = []

    # No axes means nothing to concatenate along. Checking the length rather
    # than ``axes == ()`` makes an empty list (and hence the ``None``
    # default) take this documented early exit instead of failing later
    # with an IndexError on ``axes[0]``.
    if len(axes) == 0:
        if isinstance(arrays, list):
            return arrays[0]
        return arrays

    if isinstance(arrays, Iterator):
        arrays = list(arrays)
    if not isinstance(arrays, (list, tuple)):
        return arrays
    if len(axes) > 1:
        # Collapse the inner nesting levels first.
        arrays = [_concatenate2(a, axes=axes[1:]) for a in arrays]

    # Use the concatenate implementation of the highest-priority array type.
    concatenate = concatenate_lookup.dispatch(
        type(max(arrays, key=lambda x: getattr(x, "__array_priority__", 0)))
    )

    if isinstance(arrays[0], dict):
        # Handle concatenation of `dict`s, used as a replacement for structured
        # arrays when that's not supported by the array library (e.g., CuPy).
        keys = list(arrays[0].keys())
        assert all(list(a.keys()) == keys for a in arrays)
        return {k: concatenate([a[k] for a in arrays], axis=axes[0]) for k in keys}

    return concatenate(arrays, axis=axes[0])
def apply_infer_dtype(func, args, kwargs, funcname, suggest_dtype="dtype", nout=None):
    """
    Infer the output dtype of ``func`` by calling it on tiny stand-in inputs.

    Parameters
    ----------
    func: Callable
        Function for which output dtype is to be determined
    args: List of array like
        Arguments to the function, which would usually be used. Only attributes
        ``ndim`` and ``dtype`` are used.
    kwargs: dict
        Additional ``kwargs`` to the ``func``
    funcname: String
        Name of calling function to improve potential error messages
    suggest_dtype: None/False or String
        If truthy, the error message suggests specifying a dtype via the
        keyword argument of this name. Defaults to ``'dtype'``.
    nout: None or Int
        ``None`` if function returns single output, integer if many.
        Defaults to ``None``.

    Returns
    -------
    : dtype or tuple of dtype
        One dtype when ``nout`` is ``None`` (falling back to ``type(result)``
        if the result has no ``dtype``), otherwise one dtype per output.

    Raises
    ------
    ValueError
        If calling ``func`` on the stand-in inputs raises any ``Exception``;
        the message embeds the original error and its traceback.
    """
    from dask.array.utils import meta_from_array

    def _stand_in(x):
        # Non-array arguments are passed through untouched.
        if not is_arraylike(x):
            return x
        # Arrays are replaced by an all-ones array of matching dtype with
        # length one along every dimension.
        return np.ones_like(meta_from_array(x), shape=((1,) * x.ndim), dtype=x.dtype)

    args = [_stand_in(x) for x in args]

    failure = None
    try:
        with np.errstate(all="ignore"):
            o = func(*args, **kwargs)
    except Exception as e:
        tb = "".join(traceback.format_tb(e.__traceback__))
        if suggest_dtype:
            suggest = (
                "Please specify the dtype explicitly using the "
                "`{dtype}` kwarg.\n\n".format(dtype=suggest_dtype)
            )
        else:
            suggest = ""
        failure = (
            f"`dtype` inference failed in `{funcname}`.\n\n"
            f"{suggest}"
            "Original error is below:\n"
            "------------------------\n"
            f"{e!r}\n\n"
            "Traceback:\n"
            "---------\n"
            f"{tb}"
        )

    # Raised outside the ``except`` block so the new error is not chained to
    # the original exception; its details are already in the message.
    if failure is not None:
        raise ValueError(failure)

    if nout is None:
        return getattr(o, "dtype", type(o))
    return tuple(out.dtype for out in o)
def normalize_arg(x):
    """Normalize a user-provided argument to blockwise or map_blocks.

    Dask collections are returned unchanged. Otherwise the argument is
    wrapped with ``dask.delayed`` when any of the following holds:

    1. it is a string starting with an underscore followed by digits, which
       might collide with a blockwise token;
    2. it is a list with ten or more elements;
    3. it is large, meaning ``sizeof`` reports more than 1e6.

    Anything else is returned as is.
    """
    if is_dask_collection(x):
        return x

    token_like = isinstance(x, str) and re.match(r"_\d+", x)
    long_list = isinstance(x, list) and len(x) >= 10
    # ``sizeof`` is only consulted when the cheaper checks did not match.
    if token_like or long_list or sizeof(x) > 1e6:
        return delayed(x)
    return x
def _pass_extra_kwargs(func, keys, *args, **kwargs):
    """Helper for :func:`dask.array.map_blocks` to pass `block_info` or `block_id`.

    The leading positional arguments, one per element of `keys`, are turned
    into keyword arguments named by those keys. The remaining positional
    arguments and all keyword arguments are then passed on to `func`.
    """
    n_extra = len(keys)
    for key, value in zip(keys, args):
        kwargs[key] = value
    remaining = args[n_extra:]
    return func(*remaining, **kwargs)
def map_blocks(
func,
*args,
name=None,
token=None,
dtype=None,
chunks=None,
drop_axis=None,
new_axis=None,
enforce_ndim=False,
meta=None,
**kwargs,
):
"""Map a function across all blocks of a dask array.
Note that ``map_blocks`` will attempt to automatically determine the output
array type by calling ``func`` on 0-d versions of the inputs. Please refer to
the ``meta`` keyword argument below if you expect that the function will not
succeed when operating on 0-d arrays.
Parameters
----------
func : callable
Function to apply to every block in the array.
If ``func`` accepts ``block_info=`` or ``block_id=``
as keyword arguments, these will be passed dictionaries
containing information about input and output chunks/arrays
during computation. See examples for details.
args : dask arrays or other objects
dtype : np.dtype, optional
The ``dtype`` of the output array. It is recommended to provide this.
If not provided, will be inferred by applying the function to a small
set of fake data.
chunks : tuple, optional
Chunk shape of resulting blocks if the function does not preserve
shape. If not provided, the resulting array is assumed to have the same
block structure as the first input array.
drop_axis : number or iterable, optional
Dimensions lost by the function.
new_axis : number or iterable, optional
New dimensions created by the function. Note that these are applied
after ``drop_axis`` (if present). The size of each chunk along this
dimension will be set to 1. Please specify ``chunks`` if the individual
chunks have a different size.
enforce_ndim : bool, default False
Whether to enforce at runtime that the dimensionality of the array
produced by ``func`` actually matches that of the array returned by
``map_blocks``.
If True, this will raise an error when there is a mismatch.
token : string, optional
The key prefix to use for the output array. If not provided, will be
determined from the function name.
name : string, optional
The key name to use for the output array. Note that this fully
specifies the output key name, and must be unique. If not provided,
will be determined by a hash of the arguments.
meta : array-like, optional
The ``meta`` of the output array, when specified is expected to be an
array of the same type and dtype of that returned when calling ``.compute()``
on the array returned by this function. When not provided, ``meta`` will be
inferred by applying the function to a small set of fake data, usually a
0-d array. It's important to ensure that ``func`` can successfully complete
computation without raising exceptions when 0-d is passed to it, providing
``meta`` will be required otherwise. If the output type is known beforehand
(e.g., ``np.ndarray``, ``cupy.ndarray``), an empty array of such type dtype
can be passed, for example: ``meta=np.array((), dtype=np.int32)``.
**kwargs :
Other keyword arguments to pass to function. Values must be constants
(not dask.arrays)
See Also
--------
dask.array.map_overlap : Generalized operation with overlap between neighbors.
dask.array.blockwise : Generalized operation with control over block alignment.
Examples
--------
>>> import dask.array as da
>>> x = da.arange(6, chunks=3)
>>> x.map_blocks(lambda x: x * 2).compute()
array([ 0, 2, 4, 6, 8, 10])
The ``da.map_blocks`` function can also accept multiple arrays.
>>> d = da.arange(5, chunks=2)
>>> e = da.arange(5, chunks=2)
>>> f = da.map_blocks(lambda a, b: a + b**2, d, e)
>>> f.compute()
array([ 0, 2, 6, 12, 20])
If the function changes shape of the blocks then you must provide chunks
explicitly.
>>> y = x.map_blocks(lambda x: x[::2], chunks=((2, 2),))
You have a bit of freedom in specifying chunks. If all of the output chunk
sizes are the same, you can provide just that chunk size as a single tuple.
>>> a = da.arange(18, chunks=(6,))
>>> b = a.map_blocks(lambda x: x[:3], chunks=(3,))
If the function changes the dimension of the blocks you must specify the
created or destroyed dimensions.
>>> b = a.map_blocks(lambda x: x[None, :, None], chunks=(1, 6, 1),
... new_axis=[0, 2])
If ``chunks`` is specified but ``new_axis`` is not, then it is inferred to
add the necessary number of axes on the left.
Note that ``map_blocks()`` will concatenate chunks along axes specified by
the keyword parameter ``drop_axis`` prior to applying the function.
This is illustrated in the figure below:
.. image:: /images/map_blocks_drop_axis.png
Due to memory-size-constraints, it is often not advisable to use ``drop_axis``
on an axis that is chunked. In that case, it is better not to use
``map_blocks`` but rather
``dask.array.reduction(..., axis=dropped_axes, concatenate=False)`` which
maintains a leaner memory footprint while it drops any axis.
Map_blocks aligns blocks by block positions without regard to shape. In the
following example we have two arrays with the same number of blocks but
with different shape and chunk sizes.
>>> x = da.arange(1000, chunks=(100,))
>>> y = da.arange(100, chunks=(10,))
The relevant attribute to match is numblocks.
>>> x.numblocks
(10,)
>>> y.numblocks
(10,)
If these match (up to broadcasting rules) then we can map arbitrary
functions across blocks
>>> def func(a, b):
... return np.array([a.max(), b.max()])
>>> da.map_blocks(func, x, y, chunks=(2,), dtype='i8')
dask.array<func, shape=(20,), dtype=int64, chunksize=(2,), chunktype=numpy.ndarray>
>>> _.compute()
array([ 99, 9, 199, 19, 299, 29, 399, 39, 499, 49, 599, 59, 699,
69, 799, 79, 899, 89, 999, 99])
Your block function can get information about where it is in the array by
accepting a special ``block_info`` or ``block_id`` keyword argument.
During computation, they will contain information about each of the input
and output chunks (and dask arrays) relevant to each call of ``func``.
>>> def func(block_info=None):
... pass
This will receive the following information:
>>> block_info # doctest: +SKIP
{0: {'shape': (1000,),
'num-chunks': (10,),
'chunk-location': (4,),
'array-location': [(400, 500)]},
None: {'shape': (1000,),
'num-chunks': (10,),
'chunk-location': (4,),
'array-location': [(400, 500)],
'chunk-shape': (100,),
'dtype': dtype('float64')}}
The keys to the ``block_info`` dictionary indicate which is the input and
output Dask array:
- **Input Dask array(s):** ``block_info[0]`` refers to the first input Dask array.
The dictionary key is ``0`` because that is the argument index corresponding
to the first input Dask array.
In cases where multiple Dask arrays have been passed as input to the function,
you can access them with the number corresponding to the input argument,
eg: ``block_info[1]``, ``block_info[2]``, etc.
(Note that if you pass multiple Dask arrays as input to map_blocks,
the arrays must match each other by having matching numbers of chunks,
along corresponding dimensions up to broadcasting rules.)
- **Output Dask array:** ``block_info[None]`` refers to the output Dask array,
and contains information about the output chunks.
The output chunk shape and dtype may may be different than the input chunks.
For each dask array, ``block_info`` describes:
- ``shape``: the shape of the full Dask array,
- ``num-chunks``: the number of chunks of the full array in each dimension,
- ``chunk-location``: the chunk location (for example the fourth chunk over
in the first dimension), and
- ``array-location``: the array location within the full Dask array
(for example the slice corresponding to ``40:50``).
In addition to these, there are two extra parameters described by
``block_info`` for the output array (in ``block_info[None]``):
- ``chunk-shape``: the output chunk shape, and
- ``dtype``: the output dtype.
These features can be combined to synthesize an array from scratch, for
example:
>>> def func(block_info=None):
... loc = block_info[None]['array-location'][0]
... return np.arange(loc[0], loc[1])
>>> da.map_blocks(func, chunks=((4, 4),), dtype=np.float64)
dask.array<func, shape=(8,), dtype=float64, chunksize=(4,), chunktype=numpy.ndarray>
>>> _.compute()
array([0, 1, 2, 3, 4, 5, 6, 7])
``block_id`` is similar to ``block_info`` but contains only the ``chunk_location``:
>>> def func(block_id=None):
... pass
This will receive the following information:
>>> block_id # doctest: +SKIP
(4, 3)
You may specify the key name prefix of the resulting task in the graph with
the optional ``token`` keyword argument.
>>> x.map_blocks(lambda x: x + 1, name='increment')
dask.array<increment, shape=(1000,), dtype=int64, chunksize=(100,), chunktype=numpy.ndarray>
For functions that may not handle 0-d arrays, it's also possible to specify
``meta`` with an empty array matching the type of the expected result. In
the example below, ``func`` will result in an ``IndexError`` when computing
``meta``:
>>> rng = da.random.default_rng()
>>> da.map_blocks(lambda x: x[2], rng.random(5), meta=np.array(()))
dask.array<lambda, shape=(5,), dtype=float64, chunksize=(5,), chunktype=numpy.ndarray>
Similarly, it's possible to specify a non-NumPy array to ``meta``, and provide
a ``dtype``:
>>> import cupy # doctest: +SKIP
>>> rng = da.random.default_rng(cupy.random.default_rng()) # doctest: +SKIP
>>> dt = np.float32
>>> da.map_blocks(lambda x: x[2], rng.random(5, dtype=dt), meta=cupy.array((), dtype=dt)) # doctest: +SKIP
dask.array<lambda, shape=(5,), dtype=float32, chunksize=(5,), chunktype=cupy.ndarray>
"""
if drop_axis is None:
drop_axis = []
if not callable(func):
msg = (
"First argument must be callable function, not %s\n"
"Usage: da.map_blocks(function, x)\n"
" or: da.map_blocks(function, x, y, z)"
)
raise TypeError(msg % type(func).__name__)
if token:
warnings.warn(
"The `token=` keyword to `map_blocks` has been moved to `name=`. "
"Please use `name=` instead as the `token=` keyword will be removed "
"in a future release.",
category=FutureWarning,
)
name = token
name = f"{name or funcname(func)}-{tokenize(func, dtype, chunks, drop_axis, new_axis, *args, **kwargs)}"
new_axes = {}
if isinstance(drop_axis, Number):
drop_axis = [drop_axis]
if isinstance(new_axis, Number):
…tack
"""
from dask.array import wrap
seq = [asarray(a, allow_unknown_chunksizes=allow_unknown_chunksizes) for a in seq]
if not seq:
raise ValueError("Need array(s) to concatenate")
if axis is None:
seq = [a.flatten() for a in seq]
axis = 0
seq_metas = [meta_from_array(s) for s in seq]
_concatenate = concatenate_lookup.dispatch(
type(max(seq_metas, key=lambda x: getattr(x, "__array_priority__", 0)))
)
meta = _concatenate(seq_metas, axis=axis)
# Promote types to match meta
seq = [a.astype(meta.dtype) for a in seq]
# Find output array shape
ndim = len(seq[0].shape)
shape = tuple(
sum(a.shape[i] for a in seq) if i == axis else seq[0].shape[i]
for i in range(ndim)
)
# Drop empty arrays
seq2 = [a for a in seq if a.size]
if not seq2:
seq2 = seq
if axis < 0:
axis = ndim + axis
if axis >= ndim:
msg = (
"Axis must be less than than number of dimensions"
"\nData has %d dimensions, but got axis=%d"
)
raise ValueError(msg % (ndim, axis))
n = len(seq2)
if n == 0:
try:
return wrap.empty_like(meta, shape=shape, chunks=shape, dtype=meta.dtype)
except TypeError:
return wrap.empty(shape, chunks=shape, dtype=meta.dtype)
elif n == 1:
return seq2[0]
if not allow_unknown_chunksizes and not all(
i == axis or all(x.shape[i] == seq2[0].shape[i] for x in seq2)
for i in range(ndim)
):
if any(map(np.isnan, seq2[0].shape)):
raise ValueError(
"Tried to concatenate arrays with unknown"
" shape %s.\n\nTwo solutions:\n"
" 1. Force concatenation pass"
" allow_unknown_chunksizes=True.\n"
" 2. Compute shapes with "
"[x.compute_chunk_sizes() for x in seq]" % str(seq2[0].shape)
)
raise ValueError("Shapes do not align: %s", [x.shape for x in seq2])
inds = [list(range(ndim)) for i in range(n)]
for i, ind in enumerate(inds):
ind[axis] = -(i + 1)
uc_args = list(concat(zip(seq2, inds)))
_, seq2 = unify_chunks(*uc_args, warn=False)
bds = [a.chunks for a in seq2]
chunks = (
seq2[0].chunks[:axis]
+ (sum((bd[axis] for bd in bds), ()),)
+ seq2[0].chunks[axis + 1 :]
)
cum_dims = [0] + list(accumulate(add, [len(a.chunks[axis]) for a in seq2]))
names = [a.name for a in seq2]
name = "concatenate-" + tokenize(names, axis)
keys = list(product([name], *[range(len(bd)) for bd in chunks]))
values = [
(names[bisect(cum_dims, key[axis + 1]) - 1],)
+ key[1 : axis + 1]
+ (key[axis + 1] - cum_dims[bisect(cum_dims, key[axis + 1]) - 1],)
+ key[axis + 2 :]
for key in keys
]
dsk = dict(zip(keys, values))
graph = HighLevelGraph.from_collections(name, dsk, dependencies=seq2)
return Array(graph, name, chunks, meta=meta)
def load_store_chunk(
    x: Any,
    out: Any,
    index: slice,
    lock: Any,
    return_stored: bool,
    load_stored: bool,
):
    """
    Store a chunk into ``out`` and/or read it back; used as a graph task.

    Parameters
    ----------
    x: array-like or None
        Chunk to store. Nothing is written when it is ``None`` or when its
        ``size`` is zero.
    out: array-like
        Where to store results.
    index: slice-like
        Where to store result from ``x`` in ``out``.
    lock: Lock-like or False
        If truthy, acquired before touching ``out`` and released afterwards,
        even when an error is raised.
    return_stored: bool
        Whether to return ``out``.
    load_stored: bool
        Whether to return the array stored in ``out``.
        Ignored if ``return_stored`` is not ``True``.

    Returns
    -------

    If return_stored=True and load_stored=False
        out
    If return_stored=True and load_stored=True
        out[index]
    If return_stored=False
        None

    Examples
    --------

    >>> a = np.ones((5, 6))
    >>> b = np.empty(a.shape)
    >>> load_store_chunk(a, b, (slice(None), slice(None)), False, False, False)
    """
    if lock:
        lock.acquire()
    try:
        has_data = x is not None and x.size != 0
        if has_data:
            # Values that are not array-like are converted before assignment.
            out[index] = x if is_arraylike(x) else np.asanyarray(x)

        if not return_stored:
            return None
        return out[index] if load_stored else out
    finally:
        if lock:
            lock.release()
def store_chunk(
    x: ArrayLike, out: ArrayLike, index: slice, lock: Any, return_stored: bool
):
    """Store ``x`` into ``out[index]`` without loading the stored data back.

    Thin wrapper around ``load_store_chunk`` with ``load_stored`` fixed to
    ``False``, so the result is ``out`` when ``return_stored`` is true and
    ``None`` otherwise.
    """
    return load_store_chunk(x, out, index, lock, return_stored, load_stored=False)
# Type variable tying the return type of ``load_chunk`` to the type of ``out``.
A = TypeVar("A", bound=ArrayLike)


def load_chunk(out: A, index: slice, lock: Any) -> A:
    """Return ``out[index]``, read under ``lock`` when a lock is given.

    Thin wrapper around ``load_store_chunk`` that stores nothing (``x`` is
    ``None``) and asks for the stored region to be loaded and returned.
    """
    return load_store_chunk(
        None, out, index, lock, return_stored=True, load_stored=True
    )
def insert_to_ooc(
    keys: list,
    chunks: tuple[tuple[int, ...], ...],
    out: ArrayLike,
    name: str,
    *,
    lock: Lock | bool = True,
    region: tuple[slice, ...] | slice | None = None,
    return_stored: bool = False,
    load_stored: bool = False,
) -> dict:
    """
    Build a Dask graph whose tasks store the chunks of an array in ``out``.

    Parameters
    ----------
    keys: list
        Dask keys of the input array
    chunks: tuple
        Dask chunks of the input array
    out: array-like
        Where to store results to
    name: str
        First element of dask keys
    lock: Lock-like or bool, optional
        Whether to lock or with what (default is ``True``,
        which means a :class:`threading.Lock` instance).
    region: slice-like, optional
        Where in ``out`` to store the results
        (default is ``None``, meaning all of ``out``).
    return_stored: bool, optional
        Whether to return ``out``
        (default is ``False``, meaning ``None`` is returned).
    load_stored: bool, optional
        Whether to handle loading from ``out`` at the same time.
        Ignored if ``return_stored`` is not ``True``.
        (default is ``False``, meaning defer to ``return_stored``).

    Returns
    -------
    dask graph of store operation

    Examples
    --------
    >>> import dask.array as da
    >>> d = da.ones((5, 6), chunks=(2, 3))
    >>> a = np.empty(d.shape)
    >>> insert_to_ooc(d.__dask_keys__(), d.chunks, a, "store-123")  # doctest: +SKIP
    """
    if lock is True:
        lock = Lock()

    slices = slices_from_chunks(chunks)
    if region:
        # Express each chunk's slice relative to the target region.
        slices = [fuse_slice(region, slc) for slc in slices]

    # Loading back is only done when the stored data is also to be returned.
    load_too = return_stored and load_stored
    func = load_store_chunk if load_too else store_chunk
    extra_args = (load_stored,) if load_too else ()

    dsk = {}
    for in_key, slc in zip(core.flatten(keys), slices):
        out_key = (name,) + in_key[1:]
        dsk[out_key] = (func, in_key, out, slc, lock, return_stored) + extra_args
    return dsk
def retrieve_from_ooc(
    keys: Collection[Key], dsk_pre: Graph, dsk_post: Graph
) -> dict[tuple, Any]:
    """
    Creates a Dask graph for loading stored ``keys`` from ``dsk``.

    Parameters
    ----------
    keys: Collection
        A sequence containing Dask graph keys to load
    dsk_pre: Mapping
        A Dask graph corresponding to a Dask Array before computation
    dsk_post: Mapping
        A Dask graph corresponding to a Dask Array after computation

    Returns
    -------
    dict mapping ``("load-" + key[0],) + key[1:]`` to a ``load_chunk`` task

    Examples
    --------
    >>> import dask.array as da
    >>> d = da.ones((5, 6), chunks=(2, 3))
    >>> a = np.empty(d.shape)
    >>> g = insert_to_ooc(d.__dask_keys__(), d.chunks, a, "store-123")
    >>> retrieve_from_ooc(g.keys(), g, {k: k for k in g.keys()})  # doctest: +SKIP
    """
    load_graph = {}
    for key in keys:
        load_key = ("load-" + key[0],) + key[1:]  # type: ignore
        # Reuse the arguments of the pre-computation task from position 3 up
        # to (not including) the last one. For tasks built by
        # ``insert_to_ooc`` with ``store_chunk`` these are the slice and lock.
        trailing_args = dsk_pre[key][3:-1]  # type: ignore
        load_graph[load_key] = (load_chunk, dsk_post[key]) + trailing_args
    return load_graph
def _as_dtype(a, dtype):
if dtype is None:
return a
else:
return a.astype(dtype)
def asarray(
    a, allow_unknown_chunksizes=False, dtype=None, order=None, *, like=None, **kwargs
):
    """Convert the input to a dask array.

    Parameters
    ----------
    a : array-like
        Input data, in any form that can be converted to a dask array. This
        includes lists, lists of tuples, tuples, tuples of tuples, tuples of
        lists and ndarrays.
    allow_unknown_chunksizes: bool
        Allow unknown chunksizes, such as come from converting from dask
        dataframes.  Dask.array is unable to verify that chunks line up.  If
        data comes from differently aligned sources then this can cause
        unexpected results.
    dtype : data-type, optional
        By default, the data-type is inferred from the input data.
        When given, the returned array is cast to it on every code path,
        including inputs that already expose a ``shape`` (e.g. ndarrays).
    order : {'C', 'F', 'A', 'K'}, optional
        Memory layout. 'A' and 'K' depend on the order of input array a.
        'C' row-major (C-style), 'F' column-major (Fortran-style) memory
        representation. 'A' (any) means 'F' if a is Fortran contiguous, 'C'
        otherwise 'K' (keep) preserve input order. Defaults to 'C'.
    like: array-like
        Reference object to allow the creation of Dask arrays with chunks
        that are not NumPy arrays. If an array-like passed in as ``like``
        supports the ``__array_function__`` protocol, the chunk type of the
        resulting array will be defined by it. In this case, it ensures the
        creation of a Dask array compatible with that passed in via this
        argument. If ``like`` is a Dask array, the chunk type of the
        resulting array will be defined by the chunk type of ``like``.
        Requires NumPy 1.20.0 or higher.
    **kwargs
        Forwarded to ``from_array`` when the input is wrapped there.

    Returns
    -------
    out : dask array
        Dask array interpretation of a.

    Examples
    --------
    >>> import dask.array as da
    >>> import numpy as np
    >>> x = np.arange(3)
    >>> da.asarray(x)
    dask.array<array, shape=(3,), dtype=int64, chunksize=(3,), chunktype=numpy.ndarray>

    >>> y = [[1, 2, 3], [4, 5, 6]]
    >>> da.asarray(y)
    dask.array<array, shape=(2, 3), dtype=int64, chunksize=(2, 3), chunktype=numpy.ndarray>

    .. warning::
        `order` is ignored if `a` is an `Array`, has the attribute ``to_dask_array``,
        or is a list or tuple of `Array`'s.
    """
    if like is None:
        if isinstance(a, Array):
            return _as_dtype(a, dtype)
        elif hasattr(a, "to_dask_array"):
            return _as_dtype(a.to_dask_array(), dtype)
        elif type(a).__module__.split(".")[0] == "xarray" and hasattr(a, "data"):
            return _as_dtype(asarray(a.data, order=order), dtype)
        elif isinstance(a, (list, tuple)) and any(isinstance(i, Array) for i in a):
            return _as_dtype(
                stack(a, allow_unknown_chunksizes=allow_unknown_chunksizes), dtype
            )
        elif not isinstance(getattr(a, "shape", None), Iterable):
            # Only objects without an iterable ``shape`` are converted with
            # NumPy here; anything array-like is handed to from_array as is.
            a = np.asarray(a, dtype=dtype, order=order)
    else:
        like_meta = meta_from_array(like)
        if isinstance(a, Array):
            return a.map_blocks(np.asarray, like=like_meta, dtype=dtype, order=order)
        else:
            a = np.asarray(a, like=like_meta, dtype=dtype, order=order)
    # Inputs that already have a shape skipped the np.asarray conversion
    # above, so ``dtype`` has not been applied to them yet. Enforce it here so
    # that every path honours the requested dtype (no-op when dtype is None).
    return _as_dtype(from_array(a, getitem=getter_inline, **kwargs), dtype)
def asanyarray(a, dtype=None, order=None, *, like=None, inline_array=False):
    """Convert the input to a dask array.

    Subclasses of ``np.ndarray`` will be passed through as chunks unchanged.

    Parameters
    ----------
    a : array-like
        Input data, in any form that can be converted to a dask array. This
        includes lists, lists of tuples, tuples, tuples of tuples, tuples of
        lists and ndarrays.
    dtype : data-type, optional
        By default, the data-type is inferred from the input data.
        When given, the returned array is cast to it on every code path,
        including inputs that already expose a ``shape`` (e.g. ndarrays).
    order : {'C', 'F', 'A', 'K'}, optional
        Memory layout. 'A' and 'K' depend on the order of input array a.
        'C' row-major (C-style), 'F' column-major (Fortran-style) memory
        representation. 'A' (any) means 'F' if a is Fortran contiguous, 'C'
        otherwise 'K' (keep) preserve input order. Defaults to 'C'.
    like: array-like
        Reference object to allow the creation of Dask arrays with chunks
        that are not NumPy arrays. If an array-like passed in as ``like``
        supports the ``__array_function__`` protocol, the chunk type of the
        resulting array will be defined by it. In this case, it ensures the
        creation of a Dask array compatible with that passed in via this
        argument. If ``like`` is a Dask array, the chunk type of the
        resulting array will be defined by the chunk type of ``like``.
        Requires NumPy 1.20.0 or higher.
    inline_array:
        Whether to inline the array in the resulting dask graph. For more information,
        see the documentation for ``dask.array.from_array()``.

    Returns
    -------
    out : dask array
        Dask array interpretation of a.

    Examples
    --------
    >>> import dask.array as da
    >>> import numpy as np
    >>> x = np.arange(3)
    >>> da.asanyarray(x)
    dask.array<array, shape=(3,), dtype=int64, chunksize=(3,), chunktype=numpy.ndarray>

    >>> y = [[1, 2, 3], [4, 5, 6]]
    >>> da.asanyarray(y)
    dask.array<array, shape=(2, 3), dtype=int64, chunksize=(2, 3), chunktype=numpy.ndarray>

    .. warning::
        `order` is ignored if `a` is an `Array`, has the attribute ``to_dask_array``,
        or is a list or tuple of `Array`'s.
    """
    if like is None:
        if isinstance(a, Array):
            return _as_dtype(a, dtype)
        elif hasattr(a, "to_dask_array"):
            return _as_dtype(a.to_dask_array(), dtype)
        elif type(a).__module__.split(".")[0] == "xarray" and hasattr(a, "data"):
            # NOTE(review): this delegates to ``asarray`` rather than
            # ``asanyarray``; confirm that is intended for xarray inputs.
            return _as_dtype(asarray(a.data, order=order), dtype)
        elif isinstance(a, (list, tuple)) and any(isinstance(i, Array) for i in a):
            return _as_dtype(stack(a), dtype)
        elif not isinstance(getattr(a, "shape", None), Iterable):
            # Only objects without an iterable ``shape`` are converted with
            # NumPy here; anything array-like is handed to from_array as is.
            a = np.asanyarray(a, dtype=dtype, order=order)
    else:
        like_meta = meta_from_array(like)
        if isinstance(a, Array):
            return a.map_blocks(np.asanyarray, like=like_meta, dtype=dtype, order=order)
        else:
            a = np.asanyarray(a, like=like_meta, dtype=dtype, order=order)
    wrapped = from_array(
        a,
        chunks=a.shape,
        getitem=getter_inline,
        asarray=False,
        inline_array=inline_array,
    )
    # Inputs that already have a shape skipped the np.asanyarray conversion
    # above, so ``dtype`` has not been applied to them yet. Enforce it here so
    # that every path honours the requested dtype (no-op when dtype is None).
    return _as_dtype(wrapped, dtype)
def is_scalar_for_elemwise(arg):
    """Tell whether ``arg`` should be handled as a scalar by ``elemwise``.

    >>> is_scalar_for_elemwise(42)
    True
    >>> is_scalar_for_elemwise('foo')
    True
    >>> is_scalar_for_elemwise(True)
    True
    >>> is_scalar_for_elemwise(np.array(42))
    True
    >>> is_scalar_for_elemwise([1, 2, 3])
    True
    >>> is_scalar_for_elemwise(np.array([1, 2, 3]))
    False
    >>> is_scalar_for_elemwise(from_array(np.array(0), chunks=()))
    False
    >>> is_scalar_for_elemwise(np.dtype('i4'))
    True
    """
    maybe_shape = getattr(arg, "shape", None)

    if np.isscalar(arg):
        return True
    # No usable shape at all: not array-like, so treat as a scalar.
    if not isinstance(maybe_shape, Iterable):
        return True
    # A shape made of dask collections is essentially just there to ensure
    # that dask series / frame are treated as scalars in elemwise.
    if any(is_dask_collection(dim) for dim in maybe_shape):
        return True
    if isinstance(arg, np.dtype):
        return True
    return isinstance(arg, np.ndarray) and arg.ndim == 0
def broadcast_shapes(*shapes):
"""
Determines output shape from broadcasting arrays.
Parameters
----------
shapes : tuples
The shapes of the arguments.
Returns
-------
output_shape : tuple
Raises
------
ValueError
If the input shapes cannot be successfully broadcast together.
"""
if len(shapes) == 1:
return shapes[0]
out = []
for sizes in zip_longest(*map(reversed, shapes), fillvalue=-1):
if np.isnan(sizes).any():
dim = np.nan
else:
dim = 0 if 0 in sizes else np.max(sizes).item()
if any(i not in [-1, 0, 1, dim] and not np.isnan(i) for i in sizes):
raise ValueError(
"operands could not be broadcast together with "
"shapes {}".format(" ".join(map(str, shapes)))
)
out.append(dim)
return tuple(reversed(out))
def elemwise(op, *args, out=None, where=True, dtype=None, name=None, **kwargs):
    """Apply an elementwise ufunc-like function blockwise across arguments.

    Like numpy ufuncs, broadcasting rules are respected.

    Parameters
    ----------
    op : callable
        The function to apply. Should be numpy ufunc-like in the parameters
        that it accepts.
    *args : Any
        Arguments to pass to `op`. Non-dask array-like objects are first
        converted to dask arrays, then all arrays are broadcast together before
        applying the function blockwise across all arguments. Any scalar
        arguments are passed as-is following normal numpy ufunc behavior.
    out : dask array, optional
        If out is a dask.array then this overwrites the contents of that array
        with the result.
    where : array_like, optional
        An optional boolean mask marking locations where the ufunc should be
        applied. Can be a scalar, dask array, or any other array-like object.
        Mirrors the ``where`` argument to numpy ufuncs, see e.g. ``numpy.add``
        for more information.
    dtype : dtype, optional
        If provided, overrides the output array dtype.
    name : str, optional
        A unique key name to use when building the backing dask graph. If not
        provided, one will be automatically generated based on the input
        arguments.

    Returns
    -------
    The blockwise result passed through ``handle_out``, or ``NotImplemented``
    when the output dtype cannot be inferred.

    Raises
    ------
    TypeError
        If any unexpected keyword argument is passed.
    ValueError
        If the argument shapes cannot be broadcast together.

    Examples
    --------
    >>> elemwise(add, x, y)  # doctest: +SKIP
    >>> elemwise(sin, x)  # doctest: +SKIP
    >>> elemwise(sin, x, out=dask_array)  # doctest: +SKIP

    See Also
    --------
    blockwise
    """
    if kwargs:
        raise TypeError(
            f"{op.__name__} does not take the following keyword arguments "
            f"{sorted(kwargs)}"
        )

    out = _elemwise_normalize_out(out)
    where = _elemwise_normalize_where(where)
    args = [np.asarray(a) if isinstance(a, (list, tuple)) else a for a in args]

    shapes = []
    for arg in args:
        shape = getattr(arg, "shape", ())
        # A ``shape`` attribute that is not iterable (e.g. ``None``) must be
        # discarded *before* iterating over it, otherwise the ``any`` below
        # raises TypeError. Such arguments count as scalars, matching
        # ``is_scalar_for_elemwise``.
        if not isinstance(shape, Iterable) or any(
            is_dask_collection(x) for x in shape
        ):
            # Want to exclude Delayed shapes and dd.Scalar
            shape = ()
        shapes.append(shape)
    if isinstance(where, Array):
        shapes.append(where.shape)
    if isinstance(out, Array):
        shapes.append(out.shape)

    shapes = [s if isinstance(s, Iterable) else () for s in shapes]
    out_ndim = len(
        broadcast_shapes(*shapes)
    )  # Raises ValueError if dimensions mismatch
    expr_inds = tuple(range(out_ndim))[::-1]

    if dtype is not None:
        need_enforce_dtype = True
    else:
        # We follow NumPy's rules for dtype promotion, which special cases
        # scalars and 0d ndarrays (which it considers equivalent) by using
        # their values to compute the result dtype:
        # https://github.com/numpy/numpy/issues/6240
        # We don't inspect the values of 0d dask arrays, because these could
        # hold potentially very expensive calculations. Instead, we treat
        # them just like other arrays, and if necessary cast the result of op
        # to match.
        vals = [
            (
                np.empty((1,) * max(1, a.ndim), dtype=a.dtype)
                if not is_scalar_for_elemwise(a)
                else a
            )
            for a in args
        ]
        try:
            dtype = apply_infer_dtype(op, vals, {}, "elemwise", suggest_dtype=False)
        except Exception:
            return NotImplemented
        need_enforce_dtype = any(
            not is_scalar_for_elemwise(a) and a.ndim == 0 for a in args
        )

    if not name:
        name = f"{funcname(op)}-{tokenize(op, dtype, *args, where)}"

    blockwise_kwargs = dict(dtype=dtype, name=name, token=funcname(op).strip("_"))

    if where is not True:
        # Wrap ``op`` so that ``where`` and ``out`` travel as the last two
        # positional arguments and are unpacked per block.
        blockwise_kwargs["elemwise_where_function"] = op
        op = _elemwise_handle_where
        args.extend([where, out])

    if need_enforce_dtype:
        # Wrap whatever ``op`` currently is so each block is cast to ``dtype``.
        blockwise_kwargs["enforce_dtype"] = dtype
        blockwise_kwargs["enforce_dtype_function"] = op
        op = _enforce_dtype

    result = blockwise(
        op,
        expr_inds,
        *concat(
            (a, tuple(range(a.ndim)[::-1]) if not is_scalar_for_elemwise(a) else None)
            for a in args
        ),
        **blockwise_kwargs,
    )

    return handle_out(out, result)
def _elemwise_normalize_where(where):
if where is True:
return True
elif where is False or where is None:
return False
return asarray(where)
def _elemwise_handle_where(*args, **kwargs):
function = kwargs.pop("elemwise_where_function")
*args, where, out = args
if hasattr(out, "copy"):
out = out.copy()
return function(*args, where=where, out=out, **kwargs)
def _elemwise_normalize_out(out):
if isinstance(out, tuple):
if len(out) == 1:
out = out[0]
elif len(out) > 1:
raise NotImplementedError("The out parameter is not fully supported")
else:
out = None
if not (out is None or isinstance(out, Array)):
raise NotImplementedError(
f"The out parameter is not fully supported."
f" Received type {type(out).__name__}, expected Dask Array"
)
return out
def handle_out(out, result):
    """Handle out parameters

    If out is a dask.array then this overwrites the contents of that array with
    the result and returns it; otherwise ``result`` is returned unchanged.

    Raises
    ------
    ValueError
        If ``out`` is an array whose shape differs from ``result``'s.
    """
    target = _elemwise_normalize_out(out)
    if not isinstance(target, Array):
        return result

    if target.shape != result.shape:
        raise ValueError(
            "Mismatched shapes between result and out parameter. "
            "out=%s, result=%s" % (str(target.shape), str(result.shape))
        )
    # Re-point the existing Array object at the graph of the result.
    target._chunks = result.chunks
    target.dask = result.dask
    target._meta = result._meta
    target._name = result.name
    return target
def _enforce_dtype(*args, **kwargs):
"""Calls a function and converts its result to the given dtype.
The parameters have deliberately been given unwieldy names to avoid
clashes with keyword arguments consumed by blockwise
A dtype of `object` is treated as a special case and not enforced,
because it is used as a dummy value in some places when the result will
not be a block in an Array.
Parameters
----------
enforce_dtype : dtype
Result dtype
enforce_dtype_function : callable
The wrapped function, which will be passed the remaining arguments
"""
dtype = kwargs.pop("enforce_dtype")
function = kwargs.pop("enforce_dtype_function")
result = function(*args, **kwargs)
if hasattr(result, "dtype") and dtype != result.dtype and dtype != object:
if not np.can_cast(result, dtype, casting="same_kind"):
raise ValueError(
"Inferred dtype from function %r was %r "
"but got %r, which can't be cast using "
"casting='same_kind'"
% (funcname(function), str(dtype), str(result.dtype))
)
if np.isscalar(result):
# scalar astype method doesn't take the keyword arguments, so
# have to convert via 0-dimensional array and back.
result = result.astype(dtype)
else:
try:
result = result.astype(dtype, copy=False)
except TypeError:
# Missing copy kwarg
result = result.astype(dtype)
return result
def broadcast_to(x, shape, chunks=None, meta=None):
    """Broadcast an array to a new shape.

    Parameters
    ----------
    x : array_like
        The array to broadcast.
    shape : tuple
        The shape of the desired array.
    chunks : tuple, optional
        If provided, then the result will use these chunks instead of the same
        chunks as the source array. Setting chunks explicitly as part of
        broadcast_to is more efficient than rechunking afterwards. Chunks are
        only allowed to differ from the original shape along dimensions that
        are new on the result or have size 1 the input array.
    meta : empty ndarray
        empty ndarray created with same NumPy backend, ndim and dtype as the
        Dask Array being created (overrides dtype)

    Returns
    -------
    broadcast : dask array

    Raises
    ------
    ValueError
        If ``x`` cannot be broadcast to ``shape``, or if ``chunks`` differ
        from the source chunks along a dimension that is not of size 1.

    See Also
    --------
    :func:`numpy.broadcast_to`
    """
    x = asarray(x)
    shape = tuple(shape)

    if meta is None:
        meta = meta_from_array(x)

    # Nothing to do when both shape and chunking already match.
    if x.shape == shape and (chunks is None or chunks == x.chunks):
        return x

    # Number of leading axes that exist only on the result.
    ndim_new = len(shape) - x.ndim
    incompatible = ndim_new < 0 or any(
        new != old for new, old in zip(shape[ndim_new:], x.shape) if old != 1
    )
    if incompatible:
        raise ValueError(f"cannot broadcast shape {x.shape} to shape {shape}")

    if chunks is None:
        # New axes get one chunk each; existing axes keep their chunks unless
        # they are stretched, in which case they become a single chunk.
        leading = tuple((s,) for s in shape[:ndim_new])
        trailing = tuple(
            bd if old > 1 else (new,)
            for bd, old, new in zip(x.chunks, x.shape, shape[ndim_new:])
        )
        chunks = leading + trailing
    else:
        chunks = normalize_chunks(
            chunks, shape, dtype=x.dtype, previous_chunks=x.chunks
        )
        for old_bd, new_bd in zip(x.chunks, chunks[ndim_new:]):
            if old_bd != new_bd and old_bd != (1,):
                raise ValueError(
                    "cannot broadcast chunks %s to chunks %s: "
                    "new chunks must either be along a new "
                    "dimension or a dimension of size 1" % (x.chunks, chunks)
                )

    name = "broadcast_to-" + tokenize(x, shape, chunks)
    dsk = {}

    for block in product(*(enumerate(bds) for bds in chunks)):
        # ``block`` pairs (block index, block length) for every output axis.
        new_index, chunk_shape = zip(*block)
        # Size-1 source axes always read from block 0.
        old_index = tuple(
            0 if bd == (1,) else i for bd, i in zip(x.chunks, new_index[ndim_new:])
        )
        dsk[(name,) + new_index] = (
            np.broadcast_to,
            (x.name,) + old_index,
            quote(chunk_shape),
        )

    graph = HighLevelGraph.from_collections(name, dsk, dependencies=[x])
    return Array(graph, name, chunks, dtype=x.dtype, meta=meta)
@derived_from(np)
def broadcast_arrays(*args, subok=False):
    # ``subok`` selects the converter applied to every input.
    convert = asanyarray if bool(subok) else asarray
    arrays = tuple(convert(e) for e in args)

    # Unify uneven chunking
    indices = [list(reversed(range(arr.ndim))) for arr in arrays]
    _, arrays = unify_chunks(*concat(zip(arrays, indices)), warn=False)

    shape = broadcast_shapes(*(arr.shape for arr in arrays))
    chunks = broadcast_chunks(*(arr.chunks for arr in arrays))

    broadcasted = (broadcast_to(arr, shape=shape, chunks=chunks) for arr in arrays)
    # The container type depends on the NumPy version flag: tuple when
    # NUMPY_GE_200 is true, list otherwise.
    return tuple(broadcasted) if NUMPY_GE_200 else list(broadcasted)
def offset_func(func, offset, *args):
    """Offsets inputs by offset

    Returns a wrapper that adds ``offset`` elementwise to its positional
    arguments before calling ``func``.

    >>> double = lambda x: x * 2
    >>> f = offset_func(double, (10,))
    >>> f(1)
    22
    >>> f(300)
    620
    """

    def _offset(*args):
        shifted = [add(value, delta) for value, delta in zip(args, offset)]
        return func(*shifted)

    # Best effort only: some callables have no ``__name__``.
    with contextlib.suppress(Exception):
        _offset.__name__ = "offset_" + func.__name__

    return _offset
def chunks_from_arrays(arrays):
"""Chunks tuple from nested list of arrays
>>> x = np.array([1, 2])
>>> chunks_from_arrays([x, x])
((2, 2),)
>>> x = np.array([[1, 2]])
>>> chunks_from_arrays([[x], [x]])
((1, 1), (2,))
>>> x = np.array([[1, 2]])
>>> chunks_from_arrays([[x, x]])
((1,), (2, 2))
>>> chunks_from_arrays([1, 1])
((1, 1),)
"""
if not arrays:
return ()
result = []
dim = 0
def shape(x):
try:
return x.shape if x.shape else (1,)
except AttributeError:
return (1,)
while isinstance(arrays, (list, tuple)):
> result.append(tuple(shape(deepfirst(a))[dim] for a in arrays))
E IndexError: tuple index out of range
../../../miniconda3/envs/dask-distributed/lib/python3.10/site-packages/dask/array/core.py:5281: IndexError
Check warning on line 0 in distributed.tests.test_client
github-actions / Unit Test Results
1 out of 12 runs failed: test_call_stack_future (distributed.tests.test_client)
artifacts/windows-latest-3.11-default-ci1/pytest.xml [took 1s]
Raw output
KeyError: 'slowinc-1ea6f50d354d90f0a53d5f5250e2c25f'
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:61467', workers: 0, cores: 0, tasks: 0>
a = <Worker 'tcp://127.0.0.1:61468', name: 0, status: closed, stored: 0, running: 1/4, ready: 0, comm: 0, waiting: 0>
b = <Worker 'tcp://127.0.0.1:61471', name: 1, status: closed, stored: 0, running: 1/4, ready: 0, comm: 0, waiting: 0>
@gen_cluster([("127.0.0.1", 4)] * 2, client=True)
async def test_call_stack_future(c, s, a, b):
x = c.submit(slowdec, 1, delay=0.5)
future = c.submit(slowinc, 1, delay=0.5)
await asyncio.sleep(0.1)
> results = await asyncio.gather(
c.call_stack(future), c.call_stack(keys=[future.key])
)
distributed\tests\test_client.py:5467:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed\core.py:1259: in send_recv_from_rpc
return await send_recv(comm=comm, op=key, **kwargs)
distributed\core.py:1043: in send_recv
raise exc.with_traceback(tb)
distributed\core.py:834: in _handle_comm
result = await result
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import dataclasses
import heapq
import inspect
import itertools
import json
import logging
import math
import operator
import os
import pickle
import random
import textwrap
import uuid
import warnings
import weakref
from abc import abstractmethod
from collections import defaultdict, deque
from collections.abc import (
Callable,
Collection,
Container,
Hashable,
Iterable,
Iterator,
Mapping,
Sequence,
Set,
)
from contextlib import suppress
from functools import partial
from typing import TYPE_CHECKING, Any, ClassVar, Literal, NamedTuple, cast, overload
import psutil
import tornado.web
from sortedcontainers import SortedDict, SortedSet
from tlz import (
concat,
first,
groupby,
merge,
merge_sorted,
merge_with,
partition,
pluck,
second,
take,
valmap,
)
from tornado.ioloop import IOLoop
import dask
import dask.utils
from dask._task_spec import DependenciesMapping, GraphNode, convert_legacy_graph
from dask.base import TokenizationError, normalize_token, tokenize
from dask.core import istask, validate_key
from dask.typing import Key, no_default
from dask.utils import (
_deprecated,
_deprecated_kwarg,
ensure_dict,
format_bytes,
format_time,
key_split,
parse_bytes,
parse_timedelta,
tmpfile,
)
from dask.widgets import get_template
from distributed import cluster_dump, preloading, profile
from distributed import versions as version_module
from distributed._asyncio import RLock
from distributed._stories import scheduler_story
from distributed.active_memory_manager import ActiveMemoryManagerExtension, RetireWorker
from distributed.batched import BatchedSend
from distributed.broker import Broker
from distributed.client import SourceCode
from distributed.collections import HeapSet
from distributed.comm import (
Comm,
CommClosedError,
get_address_host,
normalize_address,
resolve_address,
unparse_host_port,
)
from distributed.comm.addressing import addresses_from_user_args
from distributed.compatibility import PeriodicCallback
from distributed.core import (
ErrorMessage,
OKMessage,
Status,
clean_exception,
error_message,
rpc,
send_recv,
)
from distributed.diagnostics.memory_sampler import MemorySamplerExtension
from distributed.diagnostics.plugin import SchedulerPlugin, _get_plugin_name
from distributed.event import EventExtension
from distributed.gc import disable_gc_diagnosis, enable_gc_diagnosis
from distributed.http import get_handlers
from distributed.metrics import monotonic, time
from distributed.multi_lock import MultiLockExtension
from distributed.node import ServerNode
from distributed.proctitle import setproctitle
from distributed.protocol import deserialize
from distributed.protocol.pickle import dumps, loads
from distributed.protocol.serialize import Serialized, ToPickle, serialize
from distributed.publish import PublishExtension
from distributed.pubsub import PubSubSchedulerExtension
from distributed.queues import QueueExtension
from distributed.recreate_tasks import ReplayTaskScheduler
from distributed.security import Security
from distributed.semaphore import SemaphoreExtension
from distributed.shuffle import ShuffleSchedulerPlugin
from distributed.spans import SpanMetadata, SpansSchedulerExtension
from distributed.stealing import WorkStealing
from distributed.utils import (
All,
Deadline,
TimeoutError,
format_dashboard_link,
get_fileno_limit,
key_split_group,
log_errors,
offload,
recursive_to_dict,
wait_for,
)
from distributed.utils_comm import (
gather_from_workers,
retry_operation,
scatter_to_workers,
)
from distributed.variable import VariableExtension
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
# TODO import from typing (requires Python >=3.11)
from typing_extensions import Self, TypeAlias
from dask.highlevelgraph import HighLevelGraph
# Not to be confused with distributed.worker_state_machine.TaskStateState
# String literals naming the possible task states.
TaskStateState: TypeAlias = Literal[
    "released",
    "waiting",
    "no-worker",
    "queued",
    "processing",
    "memory",
    "erred",
    "forgotten",
]

# Runtime set holding every literal value listed in TaskStateState
ALL_TASK_STATES: Set[TaskStateState] = set(TaskStateState.__args__)  # type: ignore

# {task key -> finish state}
# Not to be confused with distributed.worker_state_machine.Recs
Recs: TypeAlias = dict[Key, TaskStateState]
# {client or worker address: [{op: <key>, ...}, ...]}
Msgs: TypeAlias = dict[str, list[dict[str, Any]]]
# (recommendations, client messages, worker messages)
RecsMsgs: TypeAlias = tuple[Recs, Msgs, Msgs]

# Alias for dask's GraphNode
T_runspec: TypeAlias = GraphNode

# Module-level logger, named after this module
logger = logging.getLogger(__name__)
# Value of the ``distributed.admin.pdb-on-err`` config option
LOG_PDB = dask.config.get("distributed.admin.pdb-on-err")
# ``distributed.scheduler.default-data-size`` config option, parsed by
# ``parse_bytes``
DEFAULT_DATA_SIZE = parse_bytes(
    dask.config.get("distributed.scheduler.default-data-size")
)
# Placeholder string for a stimulus_id that has not been set
STIMULUS_ID_UNSET = "<stimulus_id unset>"

# Extension name -> extension class.
# NOTE(review): presumably the extensions a scheduler loads by default;
# confirm against the code that consumes this mapping.
DEFAULT_EXTENSIONS = {
    "multi_locks": MultiLockExtension,
    "publish": PublishExtension,
    "replay-tasks": ReplayTaskScheduler,
    "queues": QueueExtension,
    "variables": VariableExtension,
    "pubsub": PubSubSchedulerExtension,
    "semaphores": SemaphoreExtension,
    "events": EventExtension,
    "amm": ActiveMemoryManagerExtension,
    "memory_sampler": MemorySamplerExtension,
    "shuffle": ShuffleSchedulerPlugin,
    "spans": SpansSchedulerExtension,
    "stealing": WorkStealing,
}
class ClientState:
    """Scheduler-side record of the information held about one client."""

    #: A unique identifier for this client. This is generally an opaque
    #: string generated by the client itself.
    client_key: str

    #: Cached hash of :attr:`~ClientState.client_key`
    _hash: int

    #: A set of tasks this client wants to be kept in memory, so that it can download
    #: its result when desired. This is the reverse mapping of
    #: :class:`TaskState.who_wants`. Tasks are typically removed from this set when the
    #: corresponding object in the client's space (for example a ``Future`` or a Dask
    #: collection) gets garbage-collected.
    wants_what: set[TaskState]

    #: The last time we received a heartbeat from this client, in local scheduler time.
    last_seen: float

    #: Output of :func:`distributed.versions.get_versions` on the client
    versions: dict[str, Any]

    # One slot per annotated attribute above.
    __slots__ = tuple(__annotations__)

    def __init__(self, client: str, *, versions: dict[str, Any] | None = None):
        self.client_key = client
        # The hash is computed once because the key is used as identity.
        self._hash = hash(client)
        self.versions = versions if versions else {}
        self.wants_what = set()
        self.last_seen = time()

    def __hash__(self) -> int:
        return self._hash

    def __eq__(self, other: object) -> bool:
        # Two records are equal when they describe the same client key.
        return isinstance(other, ClientState) and self.client_key == other.client_key

    def __repr__(self) -> str:
        return f"<Client {self.client_key!r}>"

    def __str__(self) -> str:
        return self.client_key

    def _to_dict_no_nest(self, *, exclude: Container[str] = ()) -> dict:
        """Dictionary representation for debugging purposes.
        Not type stable and not intended for roundtrips.

        ``versions`` is always added to the excluded names.

        See also
        --------
        Client.dump_cluster_state
        distributed.utils.recursive_to_dict
        TaskState._to_dict
        """
        excluded = set(exclude) | {"versions"}  # type: ignore
        return recursive_to_dict(self, exclude=excluded, members=True)
class MemoryState:
    """Memory readings on a worker or on the whole cluster.

    See :doc:`worker-memory`.

    Attributes / properties:

    managed_total
        Sum of the output of sizeof() for all dask keys held by the worker in memory,
        plus number of bytes spilled to disk

        managed
            Sum of the output of sizeof() for the dask keys held in RAM. Note that this may
            be inaccurate, which may cause inaccurate unmanaged memory (see below).
        spilled
            Number of bytes for the dask keys spilled to the hard drive.
            Note that this is the size on disk; size in memory may be different due to
            compression and inaccuracies in sizeof(). In other words, given the same keys,
            'managed' will change depending on the keys being in memory or spilled.

    process
        Total RSS memory measured by the OS on the worker process.
        This is always exactly equal to managed + unmanaged.

        unmanaged
            process - managed. This is the sum of

            - Python interpreter and modules
            - global variables
            - memory temporarily allocated by the dask tasks that are currently running
            - memory fragmentation
            - memory leaks
            - memory not yet garbage collected
            - memory not yet free()'d by the Python memory manager to the OS

            unmanaged_old
                Minimum of the 'unmanaged' measures over the last
                ``distributed.memory.recent-to-old-time`` seconds
            unmanaged_recent
                unmanaged - unmanaged_old; in other words process memory that has been recently
                allocated but is not accounted for by dask; hopefully it's mostly a temporary
                spike.

    optimistic
        managed + unmanaged_old; in other words the memory held long-term by
        the process under the hopeful assumption that all unmanaged_recent memory is a
        temporary spike
    """

    process: int
    unmanaged_old: int
    managed: int
    spilled: int

    # One slot per annotated attribute above.
    __slots__ = tuple(__annotations__)

    def __init__(
        self,
        *,
        process: int,
        unmanaged_old: int,
        managed: int,
        spilled: int,
    ):
        # Some data arrives with the heartbeat, some other arrives in realtime as the
        # tasks progress. Also, sizeof() is not guaranteed to return correct results.
        # This can cause glitches where a partial measure is larger than the whole, so
        # we need to force all numbers to add up exactly by definition.
        capped_managed = min(process, managed)
        # Subtractions between unsigned ints guaranteed by construction to be >= 0
        capped_unmanaged_old = min(unmanaged_old, process - capped_managed)

        self.process = process
        self.managed = capped_managed
        self.spilled = spilled
        self.unmanaged_old = capped_unmanaged_old

    @staticmethod
    def sum(*infos: MemoryState) -> MemoryState:
        # Add up each stored field across all readings, then let __init__
        # re-apply its consistency caps to the totals.
        totals = {"process": 0, "unmanaged_old": 0, "managed": 0, "spilled": 0}
        for info in infos:
            for field in totals:
                totals[field] += getattr(info, field)
        return MemoryState(**totals)

    @property
    def managed_total(self) -> int:
        return self.managed + self.spilled

    @property
    def unmanaged(self) -> int:
        # This is never negative thanks to __init__
        return self.process - self.managed

    @property
    def unmanaged_recent(self) -> int:
        # This is never negative thanks to __init__
        return self.process - self.managed - self.unmanaged_old

    @property
    def optimistic(self) -> int:
        return self.managed + self.unmanaged_old

    @property
    def managed_in_memory(self) -> int:
        # Deprecated alias of ``managed``.
        warnings.warn("managed_in_memory has been renamed to managed", FutureWarning)
        return self.managed

    @property
    def managed_spilled(self) -> int:
        # Deprecated alias of ``spilled``.
        warnings.warn("managed_spilled has been renamed to spilled", FutureWarning)
        return self.spilled

    def __repr__(self) -> str:
        lines = (
            f"Process memory (RSS)  : {format_bytes(self.process)}\n",
            f"  - managed by Dask   : {format_bytes(self.managed)}\n",
            f"  - unmanaged (old)   : {format_bytes(self.unmanaged_old)}\n",
            f"  - unmanaged (recent): {format_bytes(self.unmanaged_recent)}\n",
            f"Spilled to disk       : {format_bytes(self.spilled)}\n",
        )
        return "".join(lines)

    def _to_dict(self, *, exclude: Container[str] = ()) -> dict:
        """Dictionary representation for debugging purposes.

        Contains every public attribute and property except ``sum`` and the
        two deprecated aliases.

        See also
        --------
        Client.dump_cluster_state
        distributed.utils.recursive_to_dict
        """
        skipped = {"sum", "managed_in_memory", "managed_spilled"}
        snapshot = {}
        for attr in dir(self):
            if attr.startswith("_") or attr in skipped:
                continue
            snapshot[attr] = getattr(self, attr)
        return snapshot
class WorkerState:
"""A simple object holding information about a worker.
Not to be confused with :class:`distributed.worker_state_machine.WorkerState`.
"""
#: This worker's unique key. This can be its connected address
#: (such as ``"tcp://127.0.0.1:8891"``) or an alias (such as ``"alice"``).
address: str
pid: int
name: Hashable
#: The number of CPU threads made available on this worker
nthreads: int
#: Memory available to the worker, in bytes
memory_limit: int
local_directory: str
services: dict[str, int]
#: Output of :meth:`distributed.versions.get_versions` on the worker
versions: dict[str, Any]
#: Address of the associated :class:`~distributed.nanny.Nanny`, if present
nanny: str | None
#: Read-only worker status, synced one way from the remote Worker object
status: Status
#: Cached hash of :attr:`~WorkerState.server_id`
_hash: int
#: The total memory size, in bytes, used by the tasks this worker holds in memory
#: (i.e. the tasks in this worker's :attr:`~WorkerState.has_what`).
nbytes: int
#: Worker memory unknown to the worker, in bytes, which has been there for more than
#: 30 seconds. See :class:`MemoryState`.
_memory_unmanaged_old: int
#: History of the last 30 seconds' worth of unmanaged memory. Used to differentiate
#: between "old" and "new" unmanaged memory.
#: Format: ``[(timestamp, bytes), (timestamp, bytes), ...]``
_memory_unmanaged_history: deque[tuple[float, int]]
metrics: dict[str, Any]
#: The last time we received a heartbeat from this worker, in local scheduler time.
last_seen: float
time_delay: float
bandwidth: float
#: A set of all TaskStates on this worker that are actors. This only includes those
#: actors whose state actually lives on this worker, not actors to which this worker
#: has a reference.
actors: set[TaskState]
#: Underlying data of :meth:`WorkerState.has_what`
_has_what: dict[TaskState, None]
#: A set of tasks that have been submitted to this worker. Multiple tasks may be
# submitted to a worker in advance and the worker will run them eventually,
# depending on its execution resources (but see :doc:`work-stealing`).
#:
#: All the tasks here are in the "processing" state.
#: This attribute is kept in sync with :attr:`TaskState.processing_on`.
processing: set[TaskState]
#: Running tasks that invoked :func:`distributed.secede`
long_running: set[TaskState]
#: A dictionary of tasks that are currently being run on this worker.
#: Each task state is associated with the duration in seconds which the task has
#: been running.
executing: dict[TaskState, float]
#: The available resources on this worker, e.g. ``{"GPU": 2}``.
#: These are abstract quantities that constrain certain tasks from running at the
#: same time on this worker.
resources: dict[str, float]
#: The sum of each resource used by all tasks allocated to this worker.
#: The numbers in this dictionary can only be less or equal than those in this
#: worker's :attr:`~WorkerState.resources`.
used_resources: dict[str, float]
#: Arbitrary additional metadata to be added to :meth:`~WorkerState.identity`
extra: dict[str, Any]
# The unique server ID this WorkerState is referencing
server_id: str
# Reference to scheduler task_groups
scheduler_ref: weakref.ref[SchedulerState] | None
task_prefix_count: defaultdict[str, int]
_network_occ: float
_occupancy_cache: float | None
#: Keys that may need to be fetched to this worker, and the number of tasks that need them.
#: All tasks are currently in `memory` on a worker other than this one.
#: Much like `processing`, this does not exactly reflect worker state:
#: keys here may be queued to fetch, in flight, or already in memory
#: on the worker.
needs_what: dict[TaskState, int]
__slots__ = tuple(__annotations__)
def __init__(
    self,
    *,
    address: str,
    status: Status,
    pid: int,
    name: object,
    nthreads: int = 0,
    memory_limit: int,
    local_directory: str,
    nanny: str | None,
    server_id: str,
    services: dict[str, int] | None = None,
    versions: dict[str, Any] | None = None,
    extra: dict[str, Any] | None = None,
    scheduler: SchedulerState | None = None,
):
    """Create the record describing a single worker.

    All arguments are keyword-only. ``services``, ``versions`` and ``extra``
    are replaced by a new empty dict when omitted or empty. ``scheduler``,
    when given, is only held through a weak reference.
    """
    # Descriptive fields, stored exactly as received
    self.server_id = server_id
    self.address = address
    self.name = name
    self.pid = pid
    self.nanny = nanny
    self.status = status
    self.nthreads = nthreads
    self.memory_limit = memory_limit
    self.local_directory = local_directory
    self.services = services if services else {}
    self.versions = versions if versions else {}
    self.extra = extra if extra else {}
    # The hash is derived from the server ID once and then reused
    self._hash = hash(server_id)

    # Memory accounting and metrics
    self.nbytes = 0
    self.metrics = {}
    self._memory_unmanaged_old = 0
    self._memory_unmanaged_history = deque()

    # Liveness and network estimates
    self.last_seen = time()
    self.time_delay = 0
    self.bandwidth = parse_bytes(dask.config.get("distributed.scheduler.bandwidth"))

    # Task bookkeeping; every collection starts out empty
    self.actors = set()
    self._has_what = {}
    self.needs_what = {}
    self.processing = set()
    self.long_running = set()
    self.executing = {}
    self.task_prefix_count = defaultdict(int)

    # Abstract resources
    self.resources = {}
    self.used_resources = {}

    # Occupancy caches
    self._network_occ = 0
    self._occupancy_cache = None

    # Weak back-reference to the owning scheduler state, if any
    if scheduler:
        self.scheduler_ref = weakref.ref(scheduler)
    else:
        self.scheduler_ref = None
def __hash__(self) -> int:
    """Return the hash value stored in ``self._hash``."""
    return self._hash
def __eq__(self, other: object) -> bool:
    """Compare by identity first, then by ``server_id``.

    Objects that are not ``WorkerState`` instances never compare equal.
    """
    if other is self:
        return True
    if not isinstance(other, WorkerState):
        return False
    return other.server_id == self.server_id
@property
def has_what(self) -> Set[TaskState]:
    """An insertion-sorted set-like of tasks which currently reside on this worker.

    All the tasks here are in the "memory" state.
    This is the reverse mapping of :attr:`TaskState.who_has`.

    This is a read-only public accessor. The backing store is a dict without
    values, because rebalance() relies on dicts being insertion-sorted; what
    is returned is the key view of that dict.
    """
    replicas = self._has_what
    return replicas.keys()
@property
def host(self) -> str:
    """Result of ``get_address_host`` applied to :attr:`address`."""
    return get_address_host(self.address)
@property
def memory(self) -> MemoryState:
    """Polished memory metrics for the worker.

    **Design note on managed memory**

    Managed memory can be read from two places:

    - ``self.nbytes``
    - ``self.metrics["managed_bytes"]``

    At rest, the two numbers must be identical. ``self.nbytes``, however, is
    updated through the batched comms as soon as each task lands in memory on
    the worker, whereas ``self.metrics["managed_bytes"]`` is only updated by
    the heartbeat, which can lag several seconds behind.

    This property deliberately mixes the (likely newer) managed memory total
    from ``self.nbytes`` with process and spilled memory from the heartbeat,
    so that the managed memory total is refreshed more frequently.

    Managed memory directly and immediately contributes to optimistic memory,
    which is in turn used in Active Memory Manager heuristics (at the moment
    of writing; more uses will likely be added in the future). Keeping it up
    to date therefore matters much more than it does for process memory.
    Having up-to-date managed memory as soon as the scheduler learns about
    task completion also substantially simplifies unit tests.

    The flip side of this design is some noise in the unmanaged_recent
    measure, e.g.:

    1. Delete 100MB of managed data
    2. The updated managed memory reaches the scheduler faster than the
       updated process memory
    3. There's a blip where the scheduler thinks that there's a sudden 100MB
       increase in unmanaged_recent, since process memory hasn't changed but
       managed memory has decreased by 100MB
    4. When the heartbeat arrives, process memory goes down and so does the
       unmanaged_recent.

    This is OK - one of the main reasons for the unmanaged_recent /
    unmanaged_old split is exactly to concentrate all the noise in
    unmanaged_recent and exclude it from optimistic memory, which is used for
    heuristics.

    Something that is less OK, but also less frequent, is that the sudden
    deletion of spilled keys will cause a negative blip in managed memory:

    1. Delete 100MB of spilled data
    2. The updated managed memory *total* reaches the scheduler faster than
       the updated spilled portion
    3. This causes the managed memory to temporarily plummet and be replaced
       by unmanaged_recent, while spilled memory remains unaltered
    4. When the heartbeat arrives, managed goes back up, unmanaged_recent
       goes back down, and spilled goes down by 100MB as it should have to
       begin with.

    :issue:`6002` will let us solve this.
    """
    heartbeat = self.metrics
    process_bytes = heartbeat["memory"]
    spilled_bytes = heartbeat["spilled_bytes"]
    # Clamp at zero: nbytes may already reflect a deletion that the
    # heartbeat's spilled figure has not caught up with (see docstring).
    managed_in_memory = max(0, self.nbytes - spilled_bytes["memory"])
    return MemoryState(
        process=process_bytes,
        managed=managed_in_memory,
        spilled=spilled_bytes["disk"],
        unmanaged_old=self._memory_unmanaged_old,
    )
def clean(self) -> WorkerState:
    """Return a version of this object that is appropriate for serialization

    The result is a new ``WorkerState`` built from this one's descriptive
    fields. Its occupancy cache is set to the current occupancy and its
    ``executing`` mapping is keyed by task key rather than by ``TaskState``.
    """
    init_kwargs = dict(
        address=self.address,
        status=self.status,
        pid=self.pid,
        name=self.name,
        nthreads=self.nthreads,
        memory_limit=self.memory_limit,
        local_directory=self.local_directory,
        services=self.services,
        nanny=self.nanny,
        extra=self.extra,
        server_id=self.server_id,
    )
    snapshot = WorkerState(**init_kwargs)
    snapshot._occupancy_cache = self.occupancy
    durations_by_key = {}
    for ts, duration in self.executing.items():
        durations_by_key[ts.key] = duration
    snapshot.executing = durations_by_key  # type: ignore
    return snapshot
def __repr__(self) -> str:
    """One-line summary: address, name (only when it differs from the address),
    status name, and the number of tasks in memory and in processing.
    """
    if self.name != self.address:
        name = f", name: {self.name}"
    else:
        name = ""
    pieces = (
        f"<WorkerState {self.address!r}{name}, ",
        f"status: {self.status.name}, ",
        f"memory: {len(self.has_what)}, ",
        f"processing: {len(self.processing)}>",
    )
    return "".join(pieces)
def _repr_html_(self) -> str:
    """Render this worker through the ``worker_state.html.j2`` template."""
    template = get_template("worker_state.html.j2")
    context = dict(
        address=self.address,
        name=self.name,
        status=self.status.name,
        has_what=self.has_what,
        processing=self.processing,
    )
    return template.render(**context)
def identity(self) -> dict[str, Any]:
    """Return a dict describing this worker.

    The fixed fields come first; the entries of :attr:`extra` are merged in
    last and therefore override a fixed field of the same name.
    """
    info = {
        "type": "Worker",
        "id": self.name,
        "host": self.host,
        "resources": self.resources,
        "local_directory": self.local_directory,
        "name": self.name,
        "nthreads": self.nthreads,
        "memory_limit": self.memory_limit,
        "last_seen": self.last_seen,
        "services": self.services,
        "metrics": self.metrics,
        "status": self.status.name,
        "nanny": self.nanny,
    }
    info.update(self.extra)
    return info
def _to_dict_no_nest(self, *, exclude: Container[str] = ()) -> dict[str, Any]:
    """Dictionary representation for debugging purposes.

    Not type stable and not intended for roundtrips. ``"versions"`` is always
    excluded, in addition to whatever the caller passes in ``exclude``.

    See also
    --------
    Client.dump_cluster_state
    distributed.utils.recursive_to_dict
    TaskState._to_dict
    """
    excluded = set(exclude) | {"versions"}  # type: ignore
    return recursive_to_dict(self, exclude=excluded, members=True)
@property
def scheduler(self) -> SchedulerState:
    """Dereference :attr:`scheduler_ref` and return the scheduler state.

    Asserts that a reference was set and that calling it yields a truthy object.
    """
    ref = self.scheduler_ref
    assert ref
    state = ref()
    assert state
    return state
def add_to_processing(self, ts: TaskState) -> None:
    """Assign a task to this worker for compute.

    Bumps the per-worker and scheduler-wide counters for the task's prefix,
    adds the task to :attr:`processing`, and registers a needed replica for
    every dependency that this worker does not hold.
    """
    scheduler = self.scheduler
    if scheduler.validate:
        assert ts not in self.processing

    prefix_name = ts.prefix.name
    self.task_prefix_count[prefix_name] += 1
    scheduler._task_prefix_count_global[prefix_name] += 1
    self.processing.add(ts)

    for dep in ts.dependencies:
        holders = dep.who_has
        assert holders
        if self in holders:
            continue
        self._inc_needs_replica(dep)
def add_to_long_running(self, ts: TaskState) -> None:
    """Mark a task that is processing on this worker as long-running.

    The task stops counting towards the task prefix counters but stays in
    :attr:`processing`.
    """
    validating = self.scheduler.validate
    if validating:
        assert ts in self.processing
        assert ts not in self.long_running

    self._remove_from_task_prefix_count(ts)
    # Cannot remove from processing since we're using this for things like
    # idleness detection. Idle workers are typically targeted for
    # downscaling but we should not downscale workers with long running
    # tasks
    self.long_running.add(ts)
def remove_from_processing(self, ts: TaskState) -> None:
    """Remove a task from a workers processing

    A long-running task is only dropped from :attr:`long_running`; any other
    task has its prefix counters decremented. Each dependency still listed in
    :attr:`needs_what` then has its needed-replica count decremented.
    """
    if self.scheduler.validate:
        assert ts in self.processing

    was_long_running = ts in self.long_running
    if not was_long_running:
        self._remove_from_task_prefix_count(ts)
    else:
        self.long_running.discard(ts)

    self.processing.remove(ts)

    for dep in ts.dependencies:
        # Membership is re-checked on every iteration, as in the original loop
        if dep not in self.needs_what:
            continue
        self._dec_needs_replica(dep)
def _remove_from_task_prefix_count(self, ts: TaskState) -> None:
    """Decrement the per-worker and scheduler-wide counters for ``ts``'s prefix.

    A counter that reaches zero is deleted instead of being stored as 0.
    """
    name = ts.prefix.name
    local_counts = self.task_prefix_count
    local_remaining = local_counts[name] - 1
    global_counts = self.scheduler._task_prefix_count_global

    if not local_remaining:
        del local_counts[name]
    else:
        local_counts[name] = local_remaining

    global_remaining = global_counts[name] - 1
    if not global_remaining:
        del global_counts[name]
    else:
        global_counts[name] = global_remaining
def remove_replica(self, ts: TaskState) -> None:
    """The worker no longer has a task in memory

    Subtracts the task's size from :attr:`nbytes`, drops it from the has-what
    mapping and removes this worker from ``ts.who_has``, which is reset to
    ``None`` once it becomes empty.
    """
    if self.scheduler.validate:
        assert ts.who_has
        assert self in ts.who_has
        assert ts in self.has_what
        assert ts not in self.needs_what

    self.nbytes = self.nbytes - ts.get_nbytes()
    del self._has_what[ts]

    holders = ts.who_has
    holders.remove(self)  # type: ignore
    if not holders:
        ts.who_has = None
def _inc_needs_replica(self, ts: TaskState) -> None:
"""Assign a task fetch to this worker and update network occupancies"""
if self.scheduler.validate:
assert ts.who_has
assert self not in ts.who_has
assert ts not in self.has_what
if ts not in self.needs_what:
self.needs_what[ts] = 1
nbyte… # See definition of recipients above
heapq.heapreplace(
recipients,
(rec_bytes_max, rec_bytes_min, id(rec_ws), rec_ws),
)
else:
heapq.heappop(recipients)
# Move to next sender with the most data to lose.
# It may or may not be the same sender again.
break
else: # for ts in ts_iter
# Exhausted tasks on this sender
heapq.heappop(senders)
return msgs
async def _rebalance_move_data(
    self, msgs: list[tuple[WorkerState, WorkerState, TaskState]], stimulus_id: str
) -> dict:
    """Perform the actual transfer of data across the network in rebalance().

    ``msgs`` is the output of _rebalance_find_msgs(), that is a list of tuples:

    - sender worker
    - recipient worker
    - task to be transferred

    Recipients gather first; a key is deleted from its sender only when the
    recipient did not report it as failed.

    Returns ``{"status": "OK"}`` when no recipient reported a failed key,
    otherwise ``{"status": "partial-fail", "keys": [...]}``.

    FIXME this method is not robust when the cluster is not idle.
    """
    # {recipient address: {key: [sender address, ...]}}
    to_recipients: defaultdict[str, defaultdict[Key, list[str]]] = defaultdict(
        lambda: defaultdict(list)
    )
    for sender, recipient, task in msgs:
        to_recipients[recipient.address][task.key].append(sender.address)

    # Note: gather_on_worker never raises exceptions
    gather_coros = [
        self.gather_on_worker(address, who_has)
        for address, who_has in to_recipients.items()
    ]
    failed_per_recipient = await asyncio.gather(*gather_coros)
    failed_keys_by_recipient = dict(zip(to_recipients, failed_per_recipient))

    to_senders = defaultdict(list)
    for sender, recipient, task in msgs:
        if task.key in failed_keys_by_recipient[recipient.address]:
            continue
        to_senders[sender.address].append(task.key)

    # Note: delete_worker_data never raises exceptions
    delete_coros = [
        self.delete_worker_data(address, keys, stimulus_id)
        for address, keys in to_senders.items()
    ]
    await asyncio.gather(*delete_coros)

    for address, who_has in to_recipients.items():
        self.log_event(address, {"action": "rebalance", "who_has": who_has})

    summary = {
        "action": "rebalance",
        "senders": valmap(len, to_senders),
        "recipients": valmap(len, to_recipients),
        "moved_keys": len(msgs),
    }
    self.log_event("all", summary)

    missing_keys = set()
    for failed in failed_keys_by_recipient.values():
        missing_keys.update(failed)
    if not missing_keys:
        return {"status": "OK"}
    return {"status": "partial-fail", "keys": list(missing_keys)}
async def replicate(
    self,
    comm=None,
    keys=None,
    n=None,
    workers=None,
    branching_factor=2,
    delete=True,
    stimulus_id=None,
):
    """Replicate data throughout cluster

    This performs a tree copy of the data throughout the network
    individually on each piece of data.

    Parameters
    ----------
    comm:
        Not used by this method.
    keys: Iterable
        list of keys to replicate
    n: int
        Number of replications we expect to see within the cluster.
        Defaults to, and is capped at, the number of selected workers.
    workers: optional
        Workers to replicate onto, resolved through ``self.workers_list``;
        only those whose status is running are used. Defaults to
        ``self.running``.
    branching_factor: int, optional
        The number of workers that can copy data in each generation.
        The larger the branching factor, the more data we copy in
        a single step, but the more a given worker risks being
        swamped by data requests.
    delete: bool, optional
        If true, replicas in excess of ``n`` on the selected workers are
        first deleted from randomly sampled holders.
    stimulus_id: str, optional
        Defaults to ``f"replicate-{time()}"``.

    Returns
    -------
    ``{"status": "partial-fail", "keys": [...]}`` if any requested key has no
    replica at all; otherwise the method falls off the end and returns None.

    Raises
    ------
    ValueError
        If the resolved ``n`` is 0.

    See also
    --------
    Scheduler.rebalance
    """
    stimulus_id = stimulus_id or f"replicate-{time()}"
    assert branching_factor > 0
    # Downgrade reentrant lock to non-reentrant
    async with self._replica_lock(("replicate", object())):
        if workers is not None:
            workers = {self.workers[w] for w in self.workers_list(workers)}
            workers = {ws for ws in workers if ws.status == Status.running}
        else:
            workers = self.running

        if n is None:
            n = len(workers)
        else:
            n = min(n, len(workers))
        if n == 0:
            raise ValueError("Can not use replicate to delete data")

        tasks = {self.tasks[k] for k in keys}
        missing_data = [ts.key for ts in tasks if not ts.who_has]
        if missing_data:
            return {"status": "partial-fail", "keys": missing_data}

        # Delete extraneous data
        if delete:
            del_worker_tasks = defaultdict(set)
            for ts in tasks:
                del_candidates = tuple(ts.who_has & workers)
                if len(del_candidates) > n:
                    for ws in random.sample(
                        del_candidates, len(del_candidates) - n
                    ):
                        del_worker_tasks[ws].add(ts)

            # Note: this never raises exceptions
            # (``tasks`` below is the comprehension's own loop variable; it
            # does not rebind the outer ``tasks`` set)
            await asyncio.gather(
                *[
                    self.delete_worker_data(
                        ws.address, [t.key for t in tasks], stimulus_id
                    )
                    for ws, tasks in del_worker_tasks.items()
                ]
            )

        # Copy not-yet-filled data
        # Each pass of the loop is one generation of the tree copy
        while tasks:
            gathers = defaultdict(dict)
            # Iterate over a copy because tasks is mutated inside the loop
            for ts in list(tasks):
                if ts.state == "forgotten":
                    # task is no longer needed by any client or dependent task
                    tasks.remove(ts)
                    continue
                assert ts.who_has is not None
                n_missing = n - len(ts.who_has & workers)
                if n_missing <= 0:
                    # Already replicated enough
                    tasks.remove(ts)
                    continue

                # At most branching_factor new copies per existing replica
                count = min(n_missing, branching_factor * len(ts.who_has))
                assert count > 0

                for ws in random.sample(tuple(workers - ts.who_has), count):
                    gathers[ws.address][ts.key] = [
                        wws.address for wws in ts.who_has
                    ]

            await asyncio.gather(
                *(
                    # Note: this never raises exceptions
                    self.gather_on_worker(w, who_has)
                    for w, who_has in gathers.items()
                )
            )
            for r, v in gathers.items():
                self.log_event(r, {"action": "replicate-add", "who_has": v})

        self.log_event(
            "all",
            {
                "action": "replicate",
                "workers": list(workers),
                "key-count": len(keys),
                "branching-factor": branching_factor,
            },
        )
@log_errors
def workers_to_close(
    self,
    memory_ratio: int | float | None = None,
    n: int | None = None,
    key: Callable[[WorkerState], Hashable] | bytes | None = None,
    minimum: int | None = None,
    target: int | None = None,
    attribute: str = "address",
) -> list[str]:
    """
    Find workers that we can close with low cost

    This returns a list of workers that are good candidates to retire.
    These workers are not running anything and are storing
    relatively little data relative to their peers.  If all workers are
    idle then we still maintain enough workers to have enough RAM to store
    our data, with a comfortable buffer.

    This is for use with systems like ``distributed.deploy.adaptive``.

    Parameters
    ----------
    memory_ratio : Number
        Amount of extra space we want to have for our stored data.
        Defaults to 2, or that we want to have twice as much memory as we
        currently have data. The default only applies when ``n`` (after
        deriving it from ``target``) is also None.
    n : int
        Number of workers to close
    minimum : int
        Minimum number of workers to keep around
    key : Callable(WorkerState)
        An optional callable mapping a WorkerState object to a group
        affiliation. Groups will be closed together. This is useful when
        closing workers must be done collectively, such as by hostname.
        A ``bytes`` value is unpickled into the callable.
    target : int
        Target number of workers to have after we close
    attribute : str
        The attribute of the WorkerState object to return, like "address"
        or "name".  Defaults to "address".

    Examples
    --------
    >>> scheduler.workers_to_close()
    ['tcp://192.168.0.1:1234', 'tcp://192.168.0.2:1234']

    Group workers by hostname prior to closing

    >>> scheduler.workers_to_close(key=lambda ws: ws.host)
    ['tcp://192.168.0.1:1234', 'tcp://192.168.0.1:4567']

    Remove two workers

    >>> scheduler.workers_to_close(n=2)

    Keep enough workers to have twice as much memory as we need.

    >>> scheduler.workers_to_close(memory_ratio=2)

    Returns
    -------
    to_close: list of the ``attribute`` value (by default the address) of
    each worker that is OK to close

    See Also
    --------
    Scheduler.retire_workers
    """
    # Normalise n / target so that, when either is given, both are set
    if target is not None and n is None:
        n = len(self.workers) - target
    if n is not None:
        if n < 0:
            n = 0
        target = len(self.workers) - n

    if n is None and memory_ratio is None:
        memory_ratio = 2

    # Nothing to close when no count was requested and every worker is busy
    if not n and all([ws.processing for ws in self.workers.values()]):
        return []

    if key is None:
        key = operator.attrgetter("address")
    # NOTE(review): this unpickles caller-supplied bytes; confirm that the
    # callers of this method are trusted.
    if isinstance(key, bytes):
        key = pickle.loads(key)

    # Long running tasks typically use a worker_client to schedule
    # other tasks. We should never shut down the worker they're
    # running on, as it would cause them to restart from scratch
    # somewhere else.
    valid_workers = [ws for ws in self.workers.values() if not ws.long_running]
    for plugin in list(self.plugins.values()):
        valid_workers = plugin.valid_workers_downscaling(self, valid_workers)

    groups = groupby(key, valid_workers)

    limit_bytes = {k: sum(ws.memory_limit for ws in v) for k, v in groups.items()}
    group_bytes = {k: sum(ws.nbytes for ws in v) for k, v in groups.items()}

    limit = sum(limit_bytes.values())
    total = sum(group_bytes.values())

    def _key(group):
        # Sort key: (is_idle, -bytes). With an ascending sort and pop() from
        # the end, idle groups come out first and, among them, the ones
        # holding the fewest bytes.
        is_idle = not any([wws.processing for wws in groups[group]])
        bytes = -group_bytes[group]
        return is_idle, bytes

    idle = sorted(groups, key=_key)

    to_close = []
    n_remain = len(self.workers)

    while idle:
        group = idle.pop()
        # Without an explicit count, stop at the first group that is busy
        if n is None and any([ws.processing for ws in groups[group]]):
            break

        if minimum and n_remain - len(groups[group]) < minimum:
            break

        # Memory limit that would remain if this group were closed
        limit -= limit_bytes[group]

        if (n is not None and n_remain - len(groups[group]) >= (target or 0)) or (
            memory_ratio is not None and limit >= memory_ratio * total
        ):
            to_close.append(group)
            n_remain -= len(groups[group])

        else:
            break

    result = [getattr(ws, attribute) for g in to_close for ws in groups[g]]
    if result:
        logger.debug("Suggest closing workers: %s", result)

    return result
@overload
async def retire_workers(
    self,
    workers: list[str],
    *,
    close_workers: bool = False,
    remove: bool = True,
    stimulus_id: str | None = None,
) -> list[str]: ...

@overload
async def retire_workers(
    self,
    *,
    names: list,
    close_workers: bool = False,
    remove: bool = True,
    stimulus_id: str | None = None,
) -> list[str]: ...

@overload
async def retire_workers(
    self,
    *,
    close_workers: bool = False,
    remove: bool = True,
    stimulus_id: str | None = None,
    # Parameters for workers_to_close()
    memory_ratio: int | float | None = None,
    n: int | None = None,
    key: Callable[[WorkerState], Hashable] | bytes | None = None,
    minimum: int | None = None,
    target: int | None = None,
    attribute: str = "address",
) -> list[str]: ...

@log_errors
async def retire_workers(
    self,
    workers: list[str] | None = None,
    *,
    names: list | None = None,
    close_workers: bool = False,
    remove: bool = True,
    stimulus_id: str | None = None,
    **kwargs: Any,
) -> list[str]:
    """Gracefully retire workers from cluster. Any key that is in memory exclusively
    on the retired workers is replicated somewhere else.

    Parameters
    ----------
    workers: list[str] (optional)
        List of worker addresses to retire. Addresses that are not in
        ``self.workers`` are ignored.
    names: list (optional)
        List of worker names to retire. Names are compared as strings.
        Mutually exclusive with ``workers``.
        If neither ``workers`` nor ``names`` are provided, we call
        ``workers_to_close`` which finds a good set.
    close_workers: bool (defaults to False)
        Whether to actually close the worker explicitly from here.
        Otherwise, we expect some external job scheduler to finish off the worker.
    remove: bool (defaults to True)
        Whether to remove the worker metadata immediately or else wait for the
        worker to contact us.

        If close_workers=False and remove=False, this method just flushes the tasks
        in memory out of the workers and then returns.
        If close_workers=True and remove=False, this method will return while the
        workers are still in the cluster, although they won't accept new tasks.
        If close_workers=False or for whatever reason a worker doesn't accept the
        close command, it will be left permanently unable to accept new tasks and
        it is expected to be closed in some other way.
    stimulus_id: str (optional)
        Defaults to ``f"retire-workers-{time()}"``.
    **kwargs: dict
        Extra options to pass to workers_to_close to determine which
        workers we should drop. Only accepted if ``workers`` and ``names`` are
        omitted.

    Returns
    -------
    List of the addresses of the workers that were retired. The list is empty
    when no worker was selected.

    If there are keys that exist in memory only on the workers being retired and it
    was impossible to replicate them somewhere else (e.g. because there aren't
    any other running workers), the workers holding such keys won't be retired and
    won't appear in the returned list.

    Raises
    ------
    TypeError
        If both ``names`` and ``workers`` are given, or if either is combined
        with ``workers_to_close`` keyword arguments.

    See Also
    --------
    Scheduler.workers_to_close
    """
    if names is not None and workers is not None:
        raise TypeError("names and workers are mutually exclusive")
    if (names is not None or workers is not None) and kwargs:
        raise TypeError(
            "Parameters for workers_to_close() are mutually exclusive with "
            f"names and workers: {kwargs}"
        )

    stimulus_id = stimulus_id or f"retire-workers-{time()}"
    # This lock makes retire_workers, rebalance, and replicate mutually
    # exclusive and will no longer be necessary once rebalance and replicate are
    # migrated to the Active Memory Manager.
    # However, it allows multiple instances of retire_workers to run in parallel.
    async with self._replica_lock("retire-workers"):
        if names is not None:
            logger.info("Retire worker names %s", names)
            # Support cases where names are passed through a CLI and become strings
            names_set = {str(name) for name in names}
            wss = {ws for ws in self.workers.values() if str(ws.name) in names_set}
        elif workers is not None:
            logger.info(
                "Retire worker addresses (stimulus_id='%s') %s",
                stimulus_id,
                workers,
            )
            wss = {
                self.workers[address]
                for address in workers
                if address in self.workers
            }
        else:
            wss = {
                self.workers[address] for address in self.workers_to_close(**kwargs)
            }
        if not wss:
            return []

        # Use the registered AMM extension when it is running; otherwise
        # start a temporary, unregistered one and stop it again at the end.
        stop_amm = False
        amm: ActiveMemoryManagerExtension | None = self.extensions.get("amm")
        if not amm or not amm.running:
            amm = ActiveMemoryManagerExtension(
                self, policies=set(), register=False, start=True, interval=2.0
            )
            stop_amm = True

        try:
            coros = []
            for ws in wss:
                policy = RetireWorker(ws.address)
                amm.add_policy(policy)

                # Change Worker.status to closing_gracefully. Immediately set
                # the same on the scheduler to prevent race conditions.
                prev_status = ws.status
                self.handle_worker_status_change(
                    Status.closing_gracefully, ws, stimulus_id
                )
                # FIXME: We should send a message to the nanny first;
                # eventually workers won't be able to close their own nannies.
                self.stream_comms[ws.address].send(
                    {
                        "op": "worker-status-change",
                        "status": ws.status.name,
                        "stimulus_id": stimulus_id,
                    }
                )

                coros.append(
                    self._track_retire_worker(
                        ws,
                        policy,
                        prev_status=prev_status,
                        close=close_workers,
                        remove=remove,
                        stimulus_id=stimulus_id,
                    )
                )

            # Give the AMM a kick, in addition to its periodic running. This is
            # to avoid unnecessarily waiting for a potentially arbitrarily long
            # time (depending on interval settings)
            amm.run_once()

            # Split the tracked workers by outcome: "OK" vs. anything else
            workers_info_ok = []
            workers_info_abort = []
            for addr, result in await asyncio.gather(*coros):
                if result == "OK":
                    workers_info_ok.append(addr)
                else:
                    workers_info_abort.append(addr)

        finally:
            if stop_amm:
                amm.stop()

    self.log_event(
        "all",
        {
            "action": "retire-workers",
            "retired": list(workers_info_ok),
            "could-not-retire": list(workers_info_abort),
            "stimulus_id": stimulus_id,
        },
    )
    self.log_event(
        list(workers_info_ok),
        {"action": "retired", "stimulus_id": stimulus_id},
    )
    self.log_event(
        list(workers_info_abort),
        {"action": "could-not-retire", "stimulus_id": stimulus_id},
    )

    return workers_info_ok
async def _track_retire_worker(
    self,
    ws: WorkerState,
    policy: RetireWorker,
    prev_status: Status,
    close: bool,
    remove: bool,
    stimulus_id: str,
) -> tuple[str, Literal["OK", "no-recipients"]]:
    """Wait for a retirement policy to finish, then finalise or abort.

    Polls ``policy.done()``. If the policy then reports ``no_recipients``, the
    worker is told to go back to ``prev_status`` and the result is
    ``(address, "no-recipients")``. Otherwise the worker is removed (when
    ``remove``) or just closed (when only ``close``) and the result is
    ``(address, "OK")``.
    """
    while True:
        if policy.done():
            break
        # Sleep 0.01s when there are 4 tasks or less
        # Sleep 0.5s when there are 200 or more
        n_in_memory = len(ws.has_what)
        delay = min(0.5, n_in_memory / 400)
        delay = max(0.01, delay)
        await asyncio.sleep(delay)

    if not policy.no_recipients:
        logger.debug(
            f"All unique keys on worker {ws.address!r} have been replicated elsewhere"
        )

        if remove:
            await self.remove_worker(
                ws.address, expected=True, close=close, stimulus_id=stimulus_id
            )
        elif close:
            self.close_worker(ws.address)

        logger.info(f"Retired worker {ws.address!r} ({stimulus_id=!r})")
        return ws.address, "OK"

    # Abort retirement. This time we don't need to worry about race
    # conditions and we can wait for a scheduler->worker->scheduler
    # round-trip.
    revert_msg = {
        "op": "worker-status-change",
        "status": prev_status.name,
        "stimulus_id": stimulus_id,
    }
    self.stream_comms[ws.address].send(revert_msg)
    logger.warning(
        f"Could not retire worker {ws.address!r}: unique data could not be "
        f"moved to any other worker ({stimulus_id=!r})"
    )
    return ws.address, "no-recipients"
def add_keys(
    self, worker: str, keys: Collection[Key] = (), stimulus_id: str | None = None
) -> Literal["OK", "not found"]:
    """
    Learn that a worker has certain keys

    This should not be used in practice and is mostly here for legacy
    reasons.  However, it is sent by workers from time to time.

    Keys whose task is known and in the "memory" state are recorded as
    replicas on the worker; for all the others the worker is sent a
    "remove-replicas" message. Returns "not found" if the worker is unknown,
    "OK" otherwise.
    """
    if worker not in self.workers:
        return "not found"
    ws = self.workers[worker]

    redundant_replicas = []
    for key in keys:
        ts = self.tasks.get(key)
        in_memory = ts is not None and ts.state == "memory"
        if not in_memory:
            redundant_replicas.append(key)
            continue
        self.add_replica(ts, ws)

    if not redundant_replicas:
        return "OK"

    stimulus_id = stimulus_id or f"redundant-replicas-{time()}"
    removal_msg = {
        "op": "remove-replicas",
        "keys": redundant_replicas,
        "stimulus_id": stimulus_id,
    }
    self.worker_send(worker, removal_msg)
    return "OK"
@log_errors
def update_data(
    self,
    *,
    who_has: dict[Key, list[str]],
    nbytes: dict[Key, int],
    client: str | None = None,
) -> None:
    """Learn that new data has entered the network from an external source

    For every key, a task is created if unknown, put in the "memory" state,
    given its size when a non-negative one is provided, and registered as a
    replica on each listed worker; a "key-in-memory" report is then emitted.
    If ``client`` is given, it is finally recorded as desiring all the keys.
    """
    coerced = {}
    for key, addresses in who_has.items():
        coerced[key] = [self.coerce_address(addr) for addr in addresses]
    who_has = coerced
    logger.debug("Update data %s", who_has)

    for key, addresses in who_has.items():
        ts = self.tasks.get(key)
        if ts is None:
            ts = self.new_task(key, None, "memory")
        ts.state = "memory"

        size = nbytes.get(key, -1)
        if size >= 0:
            ts.set_nbytes(size)

        for addr in addresses:
            self.add_replica(ts, self.workers[addr])

        self.report({"op": "key-in-memory", "key": key, "workers": list(addresses)})

    if client:
        self.client_desires_keys(keys=list(who_has), client=client)
@overload
def report_on_key(self, key: Key, *, client: str | None = None) -> None: ...

@overload
def report_on_key(self, *, ts: TaskState, client: str | None = None) -> None: ...

def report_on_key(self, key=None, *, ts=None, client=None):
    """Report the current status of a task, given either its key or its state.

    Exactly one of ``key`` and ``ts`` must be passed. A key with no matching
    task is reported as cancelled. Nothing is reported when
    ``_task_to_report_msg`` returns None for the task.
    """
    have_key = key is not None
    have_ts = ts is not None
    if have_key == have_ts:
        raise ValueError(  # pragma: nocover
            f"ts and key are mutually exclusive; received {key=!r}, {ts=!r}"
        )

    if have_ts:
        key = ts.key
    else:
        assert key is not None
        ts = self.tasks.get(key)

    if ts is None:
        report_msg = {"op": "cancelled-keys", "keys": [key]}
    else:
        report_msg = _task_to_report_msg(ts)

    if report_msg is None:
        return
    self.report(report_msg, ts=ts, client=client)
@log_errors
async def feed(
    self,
    comm: Comm,
    function: bytes | None = None,
    setup: bytes | None = None,
    teardown: bytes | None = None,
    interval: str | float = "1s",
    **kwargs: Any,
) -> None:
    """
    Provides a data Comm to external requester

    Caution: this runs arbitrary Python code on the scheduler.  This should
    eventually be phased out.  It is mostly used by diagnostics.

    ``function``, ``setup`` and ``teardown`` are pickled callables. While the
    scheduler status is running, the result of ``function`` is written to
    ``comm`` every ``interval``. ``teardown`` is always called at the end.
    """
    interval = parse_timedelta(interval)
    # NOTE(review): the payloads below are unpickled exactly as received
    # (see the caution in the docstring).
    produce = pickle.loads(function) if function else function
    prepare = pickle.loads(setup) if setup else setup
    cleanup = pickle.loads(teardown) if teardown else teardown

    state = prepare(self) if prepare else None  # type: ignore
    if inspect.isawaitable(state):
        state = await state

    try:
        while self.status == Status.running:
            # The setup state, when there is one, is handed to the function
            response = (
                produce(self)  # type: ignore
                if state is None
                else produce(self, state)  # type: ignore
            )
            await comm.write(response)
            await asyncio.sleep(interval)
    except OSError:
        # Deliberately ignored; presumably the requester went away
        pass
    finally:
        if cleanup:
            cleanup(self, state)  # type: ignore
def log_worker_event(
    self, worker: str, topic: str | Collection[str], msg: Any
) -> None:
    """Log ``msg`` under ``topic`` on behalf of a worker.

    A dict message is tagged in place with ``msg["worker"] = worker``, unless
    the topic is the worker itself.
    """
    tag_with_worker = isinstance(msg, dict) and worker != topic
    if tag_with_worker:
        msg["worker"] = worker
    self.log_event(topic, msg)
def subscribe_worker_status(self, comm: Comm) -> dict[str, Any]:
    """Create a ``WorkerStatusPlugin`` for ``comm`` and return the identity.

    The ``"metrics"`` and ``"last_seen"`` entries are removed from every
    worker in the returned identity.
    """
    WorkerStatusPlugin(self, comm)
    ident = self.identity()
    for worker_info in ident["workers"].values():
        for field in ("metrics", "last_seen"):
            del worker_info[field]
    return ident
def get_processing(
    self, workers: Iterable[str] | None = None
) -> dict[str, list[Key]]:
    """Map each worker address to the keys of the tasks it is processing.

    With ``workers`` omitted, every known worker is included. Otherwise the
    given addresses are coerced and de-duplicated; an unknown address raises
    ``KeyError``.
    """
    if workers is None:
        selected = self.workers.items()
    else:
        addresses = set(map(self.coerce_address, workers))
        selected = [(addr, self.workers[addr]) for addr in addresses]
    return {addr: [ts.key for ts in ws.processing] for addr, ws in selected}
def get_who_has(self, keys: Iterable[Key] | None = None) -> dict[Key, list[str]]:
    """Map each key to the addresses of the workers holding it.

    With ``keys`` omitted, every known task is included. A requested key that
    is unknown maps to an empty list, as does a task without any replica.
    """

    def holders(ts):
        return [ws.address for ws in ts.who_has or ()]

    if keys is None:
        return {key: holders(ts) for key, ts in self.tasks.items()}

    result = {}
    for key in keys:
        result[key] = holders(self.tasks[key]) if key in self.tasks else []
    return result
def get_has_what(
    self, workers: Iterable[str] | None = None
) -> dict[str, list[Key]]:
    """Map each worker address to the keys of the tasks it holds in memory.

    With ``workers`` omitted, every known worker is included. Otherwise the
    given addresses are coerced first; an unknown address maps to an empty list.
    """
    if workers is None:
        return {addr: [ts.key for ts in ws.has_what] for addr, ws in self.workers.items()}

    result = {}
    for addr in map(self.coerce_address, workers):
        if addr in self.workers:
            result[addr] = [ts.key for ts in self.workers[addr].has_what]
        else:
            result[addr] = []
    return result
def get_ncores(self, workers: Iterable[str] | None = None) -> dict[str, int]:
    """Map each worker address to its number of threads.

    With ``workers`` omitted, every known worker is included. Otherwise the
    given addresses are coerced first and unknown ones are silently dropped.
    """
    if workers is None:
        return {addr: ws.nthreads for addr, ws in self.workers.items()}

    result = {}
    for addr in map(self.coerce_address, workers):
        if addr in self.workers:
            result[addr] = self.workers[addr].nthreads
    return result
def get_ncores_running(
    self, workers: Iterable[str] | None = None
) -> dict[str, int]:
    """Like :meth:`get_ncores`, restricted to workers whose status is running."""
    running = {}
    for addr, nthreads in self.get_ncores(workers=workers).items():
        if self.workers[addr].status == Status.running:
            running[addr] = nthreads
    return running
async def get_call_stack(self, keys: Iterable[Key] | None = None) -> dict[str, Any]:
workers: dict[str, list[Key] | None]
if keys is not None:
stack = list(keys)
processing = set()
while stack:
key = stack.pop()
> ts = self.tasks[key]
E KeyError: 'slowinc-1ea6f50d354d90f0a53d5f5250e2c25f'
distributed\scheduler.py:7903: KeyError
Check warning on line 0 in distributed.tests.test_client
github-actions / Unit Test Results
11 out of 12 runs failed: test_release_persisted_collection (distributed.tests.test_client)
artifacts/macos-latest-3.12-default-ci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-ci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-ci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-ci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-ci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-ci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-ci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-ci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-ci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-ci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-ci1/pytest.xml [took 0s]
Raw output
IndexError: tuple index out of range
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:36527', workers: 0, cores: 0, tasks: 0>
a = <Worker 'tcp://127.0.0.1:46067', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>
b = <Worker 'tcp://127.0.0.1:38837', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>
@gen_cluster(client=True)
async def test_release_persisted_collection(c, s, a, b):
np = pytest.importorskip("numpy")
da = pytest.importorskip("dask.array")
arr = c.persist(da.random.random((10,), chunks=(10,)))
await wait(arr)
_release_persisted(arr)
while s.tasks:
await asyncio.sleep(0.01)
with pytest.raises(CancelledError):
> await c.compute(arr)
distributed/tests/test_client.py:8235:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:410: in _result
raise exc.with_traceback(tb)
../../../miniconda3/envs/dask-distributed/lib/python3.10/site-packages/dask/array/core.py:1328: in finalize
return concatenate3(results)
../../../miniconda3/envs/dask-distributed/lib/python3.10/site-packages/dask/array/core.py:5496: in concatenate3
chunks = chunks_from_arrays(arrays)
../../../miniconda3/envs/dask-distributed/lib/python3.10/site-packages/dask/array/core.py:5281: in chunks_from_arrays
result.append(tuple(shape(deepfirst(a))[dim] for a in arrays))
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import contextlib
import math
import operator
import os
import pickle
import re
import sys
import traceback
import uuid
import warnings
from bisect import bisect
from collections import defaultdict
from collections.abc import Collection, Iterable, Iterator, Mapping, Sequence
from functools import lru_cache, partial, reduce, wraps
from itertools import product, zip_longest
from numbers import Integral, Number
from operator import add, mul
from threading import Lock
from typing import Any, Literal, TypeVar, Union, cast
import numpy as np
from numpy.typing import ArrayLike
from packaging.version import Version
from tlz import accumulate, concat, first, groupby, partition
from tlz.curried import pluck
from toolz import frequencies
from dask import compute, config, core
from dask.array import chunk
from dask.array.chunk import getitem
from dask.array.chunk_types import is_valid_array_chunk, is_valid_chunk_type
# Keep einsum_lookup and tensordot_lookup here for backwards compatibility
from dask.array.dispatch import ( # noqa: F401
concatenate_lookup,
einsum_lookup,
tensordot_lookup,
)
from dask.array.numpy_compat import NUMPY_GE_200, _Recurser
from dask.array.slicing import replace_ellipsis, setitem_array, slice_array
from dask.array.utils import compute_meta, meta_from_array
from dask.base import (
DaskMethodsMixin,
compute_as_if_collection,
dont_optimize,
is_dask_collection,
named_schedulers,
persist,
tokenize,
)
from dask.blockwise import blockwise as core_blockwise
from dask.blockwise import broadcast_dimensions
from dask.context import globalmethod
from dask.core import quote
from dask.delayed import Delayed, delayed
from dask.highlevelgraph import HighLevelGraph, MaterializedLayer
from dask.layers import ArraySliceDep, reshapelist
from dask.sizeof import sizeof
from dask.typing import Graph, Key, NestedKeys
from dask.utils import (
IndexCallable,
SerializableLock,
cached_cumsum,
cached_property,
concrete,
derived_from,
format_bytes,
funcname,
has_keyword,
is_arraylike,
is_dataframe_like,
is_index_like,
is_integer,
is_series_like,
maybe_pluralize,
ndeepmap,
ndimlist,
parse_bytes,
typename,
)
from dask.widgets import get_template
# Type alias for a value that is either an int or a float. Per the trailing
# note the float is meant to be NaN only -- presumably marking an unknown
# length; confirm against the users of this alias.
T_IntOrNaN = Union[int, float] # Should be Union[int, Literal[np.nan]]
# Default scheduler entry point: the "threads" entry of ``named_schedulers``
# when one is registered, otherwise the "sync" entry. The "sync" lookup is
# evaluated eagerly as the fallback argument, so that key must always exist.
DEFAULT_GET = named_schedulers.get("threads", named_schedulers["sync"])
# Help text explaining how to resolve unknown chunk sizes.
# NOTE(review): presumably appended to error messages raised elsewhere in the
# module; no use of it is visible in this part of the file.
unknown_chunk_message = (
"\n\n"
"A possible solution: "
"https://docs.dask.org/en/latest/array-chunks.html#unknown-chunks\n"
"Summary: to compute chunks sizes, use\n\n"
" x.compute_chunk_sizes() # for Dask Array `x`\n"
" ddf.to_dask_array(lengths=True) # for Dask DataFrame `ddf`"
)
class PerformanceWarning(Warning):
    """Warning category signalling that bad chunking may cause poor performance."""
def getter(a, b, asarray=True, lock=None):
    """Index ``a`` with ``b``, optionally while holding ``lock``.

    Parameters
    ----------
    a : indexable
        Object supporting ``a[b]``.
    b : index
        Index expression. When it is a tuple containing ``None`` entries,
        those entries are removed before indexing ``a`` and re-applied to
        the result afterwards.
    asarray : bool, optional
        When true, a result that is not array-like (or that is an
        ``np.matrix``) is converted with ``np.asarray``.
    lock : lock-like or falsy, optional
        When truthy, ``acquire()`` is called before indexing and
        ``release()`` afterwards, even if indexing raises.

    Returns
    -------
    The indexed result, possibly converted to ``np.ndarray``.
    """
    if isinstance(b, tuple) and any(item is None for item in b):
        # Index without the ``None`` entries first, then put the new axes
        # back. Integer entries drop a dimension, so they get no placeholder
        # in the second index.
        plain_index = tuple(item for item in b if item is not None)
        newaxis_index = tuple(
            None if item is None else slice(None, None)
            for item in b
            if not isinstance(item, Integral)
        )
        return getter(a, plain_index, asarray=asarray, lock=lock)[newaxis_index]

    if lock:
        lock.acquire()
    try:
        result = a[b]
        # ``np.matrix`` is special-cased to force a conversion to
        # ``np.ndarray`` and preserve the original behavior of ``getter``:
        # a matrix is array-like, so ``is_arraylike`` alone would let it
        # through unconverted.
        if asarray and (not is_arraylike(result) or isinstance(result, np.matrix)):
            result = np.asarray(result)
    finally:
        if lock:
            lock.release()
    return result
def getter_nofancy(a, b, asarray=True, lock=None):
    """Delegate to ``getter`` under a distinct function identity.

    The separate name signals to the optimization passes that the backend
    does not support fancy indexing. Arguments and return value are exactly
    those of ``getter``.
    """
    options = {"asarray": asarray, "lock": lock}
    return getter(a, b, **options)
def getter_inline(a, b, asarray=True, lock=None):
    """Delegate to ``getter``; optimizations feel comfortable inlining this one.

    Slicing operations using this function may be inlined into a graph, as
    in the following rewrite.

    **Before**

    >>> a = x[:10]  # doctest: +SKIP
    >>> b = a + 1  # doctest: +SKIP
    >>> c = a * 2  # doctest: +SKIP

    **After**

    >>> b = x[:10] + 1  # doctest: +SKIP
    >>> c = x[:10] * 2  # doctest: +SKIP

    This inlining can be relevant to operations when running off of disk.
    Arguments and return value are exactly those of ``getter``.
    """
    options = {"asarray": asarray, "lock": lock}
    return getter(a, b, **options)
from dask.array.optimization import fuse_slice, optimize
# Registry used for ``__array_function__``: maps NumPy functions (including
# aliases and mismatching names) to the dask.array function implementing them.
_HANDLED_FUNCTIONS = {}


def implements(*numpy_functions):
    """Build a decorator registering a dask.array implementation of NumPy functions.

    The decorated function is recorded in ``_HANDLED_FUNCTIONS`` under each
    of the given NumPy functions (several may be given when they are
    aliases), which are handled with ``__array_function__``.

    Parameters
    ----------
    \\*numpy_functions : callables
        One or more NumPy functions that are handled by ``__array_function__``
        and will be mapped by `implements` to a `dask.array` function.

    Returns
    -------
    callable
        A decorator that registers its argument and returns it unchanged.
    """

    def register(dask_func):
        _HANDLED_FUNCTIONS.update(dict.fromkeys(numpy_functions, dask_func))
        return dask_func

    return register
def _should_delegate(self, other) -> bool:
    """Decide whether Dask should defer a binary operation to ``other``.

    This implementation follows NEP-13:
    https://numpy.org/neps/nep-0013-ufunc-overrides.html#behavior-in-combination-with-python-s-binary-operations

    Returns ``True`` when ``other`` sets ``__array_ufunc__ = None``, or when
    it defines ``__array_ufunc__``, is not a valid array chunk, and ``self``
    is not an instance of (nor of the same type as) ``other``'s type.
    Returns ``False`` otherwise, including when ``other`` has no
    ``__array_ufunc__`` attribute at all.
    """
    if not hasattr(other, "__array_ufunc__"):
        return False
    if other.__array_ufunc__ is None:
        return True
    return (
        not is_valid_array_chunk(other)
        # don't delegate to our own parent classes
        and not isinstance(self, type(other))
        and type(self) is not type(other)
    )
def check_if_handled_given_other(f):
    """Wrap a binary dunder method so it defers when ``other`` should handle it.

    The wrapper returns ``NotImplemented`` whenever
    ``_should_delegate(self, other)`` is true and calls ``f(self, other)``
    otherwise. This ensures proper deferral to upcast types in dunder
    operations without assuming unknown types are automatically downcast
    types.
    """

    @wraps(f)
    def wrapper(self, other):
        return NotImplemented if _should_delegate(self, other) else f(self, other)

    return wrapper
def slices_from_chunks(chunks):
    """Expand a chunks tuple into one tuple of slices per block, in product order.

    >>> slices_from_chunks(((2, 2), (3, 3, 3)))  # doctest: +NORMALIZE_WHITESPACE
    [(slice(0, 2, None), slice(0, 3, None)),
     (slice(0, 2, None), slice(3, 6, None)),
     (slice(0, 2, None), slice(6, 9, None)),
     (slice(2, 4, None), slice(0, 3, None)),
     (slice(2, 4, None), slice(3, 6, None)),
     (slice(2, 4, None), slice(6, 9, None))]
    """
    # Cumulative sums with a leading zero give each block's start offset;
    # the extra trailing entry is dropped by ``zip`` below.
    offsets_per_axis = [cached_cumsum(sizes, initial_zero=True) for sizes in chunks]
    slices_per_axis = []
    for offsets, sizes in zip(offsets_per_axis, chunks):
        slices_per_axis.append(
            [slice(start, start + size) for start, size in zip(offsets, sizes)]
        )
    return list(product(*slices_per_axis))
def graph_from_arraylike(
    arr,  # Any array-like which supports slicing
    chunks,
    shape,
    name,
    getitem=getter,
    lock=False,
    asarray=True,
    dtype=None,
    inline_array=False,
) -> HighLevelGraph:
    """
    Build a HighLevelGraph that slices chunks out of an array-like.

    If ``inline_array`` is True, this makes a Blockwise layer of slicing
    tasks where the array-like is embedded into every task.

    If ``inline_array`` is False, this inserts the array-like as a standalone
    value in a MaterializedLayer (under the key ``"original-" + name``), then
    generates a Blockwise layer of slicing tasks that refer to it.

    ``asarray`` and ``lock`` are forwarded to ``getitem`` only when it
    accepts both keywords and a non-default value was requested
    (``asarray`` false or ``lock`` truthy).

    >>> dict(graph_from_arraylike(arr, chunks=(2, 3), shape=(4, 6), name="X", inline_array=True))  # doctest: +SKIP
    {(arr, 0, 0): (getter, arr, (slice(0, 2), slice(0, 3))),
     (arr, 1, 0): (getter, arr, (slice(2, 4), slice(0, 3))),
     (arr, 1, 1): (getter, arr, (slice(2, 4), slice(3, 6))),
     (arr, 0, 1): (getter, arr, (slice(0, 2), slice(3, 6)))}

    >>> dict(  # doctest: +SKIP
    ...     graph_from_arraylike(arr, chunks=((2, 2), (3, 3)), shape=(4,6), name="X", inline_array=False)
    ... )
    {"original-X": arr,
     ('X', 0, 0): (getter, 'original-X', (slice(0, 2), slice(0, 3))),
     ('X', 1, 0): (getter, 'original-X', (slice(2, 4), slice(0, 3))),
     ('X', 1, 1): (getter, 'original-X', (slice(2, 4), slice(3, 6))),
     ('X', 0, 1): (getter, 'original-X', (slice(0, 2), slice(3, 6)))}
    """
    chunks = normalize_chunks(chunks, shape, dtype=dtype)
    out_ind = tuple(range(len(shape)))

    forward_options = (
        has_keyword(getitem, "asarray")
        and has_keyword(getitem, "lock")
        and (not asarray or lock)
    )
    # Common case: drop the extra parameters entirely.
    kwargs = {"asarray": asarray, "lock": lock} if forward_options else {}

    def slicing_layer(source):
        # ``source`` is either the array-like itself or the key it is stored
        # under; the ``None`` index marks it as a literal, un-indexed argument.
        return core_blockwise(
            getitem,
            name,
            out_ind,
            source,
            None,
            ArraySliceDep(chunks),
            out_ind,
            numblocks={},
            **kwargs,
        )

    if inline_array:
        return HighLevelGraph.from_collections(name, slicing_layer(arr))

    original_name = "original-" + name
    layers = {
        original_name: MaterializedLayer({original_name: arr}),
        name: slicing_layer(original_name),
    }
    deps = {original_name: set(), name: {original_name}}
    return HighLevelGraph(layers, deps)
def dotmany(A, B, leftfunc=None, rightfunc=None, **kwargs):
    """Sum the pairwise dot products of two aligned sequences of chunks.

    >>> x = np.array([[1, 2], [1, 2]])
    >>> y = np.array([[10, 20], [10, 20]])
    >>> dotmany([x, x, x], [y, y, y])
    array([[ 90, 180],
           [ 90, 180]])

    ``leftfunc`` / ``rightfunc`` are optionally applied to every left /
    right chunk first; extra keyword arguments are passed to ``np.dot``.

    >>> dotmany([x, x, x], [y, y, y], rightfunc=np.transpose)
    array([[150, 150],
           [150, 150]])
    """
    lefts = map(leftfunc, A) if leftfunc else A
    rights = map(rightfunc, B) if rightfunc else B
    return sum(np.dot(left, right, **kwargs) for left, right in zip(lefts, rights))
def _concatenate2(arrays, axes=None):
    """Recursively concatenate nested lists of arrays along axes

    Each entry in axes corresponds to each level of the nested list. The
    length of axes should correspond to the level of nesting of arrays.
    If axes is an empty list or tuple (or ``None``), return arrays, or
    arrays[0] if arrays is a list.

    >>> x = np.array([[1, 2], [3, 4]])
    >>> _concatenate2([x, x], axes=[0])
    array([[1, 2],
           [3, 4],
           [1, 2],
           [3, 4]])

    >>> _concatenate2([x, x], axes=[1])
    array([[1, 2, 1, 2],
           [3, 4, 3, 4]])

    >>> _concatenate2([[x, x], [x, x]], axes=[0, 1])
    array([[1, 2, 1, 2],
           [3, 4, 3, 4],
           [1, 2, 1, 2],
           [3, 4, 3, 4]])

    Supports Iterators

    >>> _concatenate2(iter([x, x]), axes=[1])
    array([[1, 2, 1, 2],
           [3, 4, 3, 4]])

    Special Case

    >>> _concatenate2([x, x], axes=())
    array([[1, 2],
           [3, 4]])
    >>> _concatenate2([x, x], axes=[])
    array([[1, 2],
           [3, 4]])
    """
    if axes is None:
        axes = []

    # Nothing to concatenate along. Previously only the empty *tuple* was
    # recognised, so an empty list (including the ``None`` default) fell
    # through to ``axes[0]`` below and raised IndexError for list input,
    # contradicting the documented contract.
    if isinstance(axes, (list, tuple)) and not axes:
        return arrays[0] if isinstance(arrays, list) else arrays

    if isinstance(arrays, Iterator):
        arrays = list(arrays)
    if not isinstance(arrays, (list, tuple)):
        return arrays
    if len(axes) > 1:
        # Collapse the deeper nesting levels first; each level consumes one axis.
        arrays = [_concatenate2(a, axes=axes[1:]) for a in arrays]

    # Dispatch on the type with the highest ``__array_priority__`` among the
    # inputs so that mixed-type input uses the dominant library's concatenate.
    concatenate = concatenate_lookup.dispatch(
        type(max(arrays, key=lambda x: getattr(x, "__array_priority__", 0)))
    )

    if isinstance(arrays[0], dict):
        # Handle concatenation of `dict`s, used as a replacement for structured
        # arrays when that's not supported by the array library (e.g., CuPy).
        keys = list(arrays[0].keys())
        assert all(list(a.keys()) == keys for a in arrays)
        return {k: concatenate([a[k] for a in arrays], axis=axes[0]) for k in keys}

    return concatenate(arrays, axis=axes[0])
def apply_infer_dtype(func, args, kwargs, funcname, suggest_dtype="dtype", nout=None):
    """
    Infer the output dtype of ``func`` by calling it on tiny stand-in inputs.

    Parameters
    ----------
    func: Callable
        Function for which output dtype is to be determined
    args: List of array like
        Arguments to the function, which would usually be used. Only attributes
        ``ndim`` and ``dtype`` are used.
    kwargs: dict
        Additional ``kwargs`` to the ``func``
    funcname: String
        Name of calling function to improve potential error messages
    suggest_dtype: None/False or String
        If not ``None`` adds suggestion to potential error message to specify a dtype
        via the specified kwarg. Defaults to ``'dtype'``.
    nout: None or Int
        ``None`` if function returns single output, integer if many.
        Defaults to ``None``.

    Returns
    -------
    : dtype or List of dtype
        One or many dtypes (depending on ``nout``). For a single output
        without a ``dtype`` attribute, the type of the output is returned.

    Raises
    ------
    ValueError
        If calling ``func`` on the stand-in inputs raises; the message embeds
        the original error and its traceback.
    """
    from dask.array.utils import meta_from_array

    # Swap every array-like argument for an all-ones array of the same meta
    # type, dtype and number of dimensions, with length 1 along each axis.
    sample_args = []
    for x in args:
        if is_arraylike(x):
            x = np.ones_like(meta_from_array(x), shape=((1,) * x.ndim), dtype=x.dtype)
        sample_args.append(x)

    failure = None
    try:
        with np.errstate(all="ignore"):
            o = func(*sample_args, **kwargs)
    except Exception as e:
        tb = "".join(traceback.format_tb(e.__traceback__))
        if suggest_dtype:
            suggest = (
                "Please specify the dtype explicitly using the "
                f"`{suggest_dtype}` kwarg.\n\n"
            )
        else:
            suggest = ""
        failure = (
            f"`dtype` inference failed in `{funcname}`.\n\n"
            f"{suggest}"
            "Original error is below:\n"
            "------------------------\n"
            f"{e!r}\n\n"
            "Traceback:\n"
            "---------\n"
            f"{tb}"
        )

    # Raised after the ``except`` block has finished, so the ValueError is
    # not raised while the original exception is still being handled.
    if failure is not None:
        raise ValueError(failure)

    if nout is None:
        return getattr(o, "dtype", type(o))
    return tuple(out.dtype for out in o)
def normalize_arg(x):
    """Normalize a user-provided argument to blockwise or map_blocks.

    Dask collections are returned untouched. Otherwise the argument is
    wrapped with ``dask.delayed`` when any of the following holds:

    1. It is a string matching ``_<digits>`` at its start, which might
       collide with a blockwise token.
    2. It is a list with ten or more elements.
    3. It is large: ``sizeof`` reports more than 1e6.

    Everything else is returned as-is.
    """
    if is_dask_collection(x):
        return x

    token_like = isinstance(x, str) and re.match(r"_\d+", x)
    long_list = isinstance(x, list) and len(x) >= 10
    # ``sizeof`` is only consulted when the cheaper checks did not trigger.
    if token_like or long_list or sizeof(x) > 1e6:
        return delayed(x)
    return x
def _pass_extra_kwargs(func, keys, *args, **kwargs):
    """Helper for :func:`dask.array.map_blocks` to pass `block_info` or `block_id`.

    The first ``len(keys)`` positional arguments are paired with ``keys``
    and turned into keyword arguments (overriding any same-named entry in
    ``kwargs``); the remaining positional arguments and all keyword
    arguments are then passed on to ``func``.
    """
    promoted = dict(zip(keys, args))
    remaining = args[len(keys) :]
    return func(*remaining, **{**kwargs, **promoted})
def map_blocks(
func,
*args,
name=None,
token=None,
dtype=None,
chunks=None,
drop_axis=None,
new_axis=None,
enforce_ndim=False,
meta=None,
**kwargs,
):
"""Map a function across all blocks of a dask array.
Note that ``map_blocks`` will attempt to automatically determine the output
array type by calling ``func`` on 0-d versions of the inputs. Please refer to
the ``meta`` keyword argument below if you expect that the function will not
succeed when operating on 0-d arrays.
Parameters
----------
func : callable
Function to apply to every block in the array.
If ``func`` accepts ``block_info=`` or ``block_id=``
as keyword arguments, these will be passed dictionaries
containing information about input and output chunks/arrays
during computation. See examples for details.
args : dask arrays or other objects
dtype : np.dtype, optional
The ``dtype`` of the output array. It is recommended to provide this.
If not provided, will be inferred by applying the function to a small
set of fake data.
chunks : tuple, optional
Chunk shape of resulting blocks if the function does not preserve
shape. If not provided, the resulting array is assumed to have the same
block structure as the first input array.
drop_axis : number or iterable, optional
Dimensions lost by the function.
new_axis : number or iterable, optional
New dimensions created by the function. Note that these are applied
after ``drop_axis`` (if present). The size of each chunk along this
dimension will be set to 1. Please specify ``chunks`` if the individual
chunks have a different size.
enforce_ndim : bool, default False
Whether to enforce at runtime that the dimensionality of the array
produced by ``func`` actually matches that of the array returned by
``map_blocks``.
If True, this will raise an error when there is a mismatch.
token : string, optional
The key prefix to use for the output array. If not provided, will be
determined from the function name.
name : string, optional
The key name to use for the output array. Note that this fully
specifies the output key name, and must be unique. If not provided,
will be determined by a hash of the arguments.
meta : array-like, optional
The ``meta`` of the output array, when specified is expected to be an
array of the same type and dtype of that returned when calling ``.compute()``
on the array returned by this function. When not provided, ``meta`` will be
inferred by applying the function to a small set of fake data, usually a
0-d array. It's important to ensure that ``func`` can successfully complete
computation without raising exceptions when 0-d is passed to it, providing
``meta`` will be required otherwise. If the output type is known beforehand
(e.g., ``np.ndarray``, ``cupy.ndarray``), an empty array of such type dtype
can be passed, for example: ``meta=np.array((), dtype=np.int32)``.
**kwargs :
Other keyword arguments to pass to function. Values must be constants
(not dask.arrays)
See Also
--------
dask.array.map_overlap : Generalized operation with overlap between neighbors.
dask.array.blockwise : Generalized operation with control over block alignment.
Examples
--------
>>> import dask.array as da
>>> x = da.arange(6, chunks=3)
>>> x.map_blocks(lambda x: x * 2).compute()
array([ 0, 2, 4, 6, 8, 10])
The ``da.map_blocks`` function can also accept multiple arrays.
>>> d = da.arange(5, chunks=2)
>>> e = da.arange(5, chunks=2)
>>> f = da.map_blocks(lambda a, b: a + b**2, d, e)
>>> f.compute()
array([ 0, 2, 6, 12, 20])
If the function changes shape of the blocks then you must provide chunks
explicitly.
>>> y = x.map_blocks(lambda x: x[::2], chunks=((2, 2),))
You have a bit of freedom in specifying chunks. If all of the output chunk
sizes are the same, you can provide just that chunk size as a single tuple.
>>> a = da.arange(18, chunks=(6,))
>>> b = a.map_blocks(lambda x: x[:3], chunks=(3,))
If the function changes the dimension of the blocks you must specify the
created or destroyed dimensions.
>>> b = a.map_blocks(lambda x: x[None, :, None], chunks=(1, 6, 1),
... new_axis=[0, 2])
If ``chunks`` is specified but ``new_axis`` is not, then it is inferred to
add the necessary number of axes on the left.
Note that ``map_blocks()`` will concatenate chunks along axes specified by
the keyword parameter ``drop_axis`` prior to applying the function.
This is illustrated in the figure below:
.. image:: /images/map_blocks_drop_axis.png
Due to memory-size-constraints, it is often not advisable to use ``drop_axis``
on an axis that is chunked. In that case, it is better not to use
``map_blocks`` but rather
``dask.array.reduction(..., axis=dropped_axes, concatenate=False)`` which
maintains a leaner memory footprint while it drops any axis.
Map_blocks aligns blocks by block positions without regard to shape. In the
following example we have two arrays with the same number of blocks but
with different shape and chunk sizes.
>>> x = da.arange(1000, chunks=(100,))
>>> y = da.arange(100, chunks=(10,))
The relevant attribute to match is numblocks.
>>> x.numblocks
(10,)
>>> y.numblocks
(10,)
If these match (up to broadcasting rules) then we can map arbitrary
functions across blocks
>>> def func(a, b):
... return np.array([a.max(), b.max()])
>>> da.map_blocks(func, x, y, chunks=(2,), dtype='i8')
dask.array<func, shape=(20,), dtype=int64, chunksize=(2,), chunktype=numpy.ndarray>
>>> _.compute()
array([ 99, 9, 199, 19, 299, 29, 399, 39, 499, 49, 599, 59, 699,
69, 799, 79, 899, 89, 999, 99])
Your block function can get information about where it is in the array by
accepting a special ``block_info`` or ``block_id`` keyword argument.
During computation, they will contain information about each of the input
and output chunks (and dask arrays) relevant to each call of ``func``.
>>> def func(block_info=None):
... pass
This will receive the following information:
>>> block_info # doctest: +SKIP
{0: {'shape': (1000,),
'num-chunks': (10,),
'chunk-location': (4,),
'array-location': [(400, 500)]},
None: {'shape': (1000,),
'num-chunks': (10,),
'chunk-location': (4,),
'array-location': [(400, 500)],
'chunk-shape': (100,),
'dtype': dtype('float64')}}
The keys to the ``block_info`` dictionary indicate which is the input and
output Dask array:
- **Input Dask array(s):** ``block_info[0]`` refers to the first input Dask array.
The dictionary key is ``0`` because that is the argument index corresponding
to the first input Dask array.
In cases where multiple Dask arrays have been passed as input to the function,
you can access them with the number corresponding to the input argument,
eg: ``block_info[1]``, ``block_info[2]``, etc.
(Note that if you pass multiple Dask arrays as input to map_blocks,
the arrays must match each other by having matching numbers of chunks,
along corresponding dimensions up to broadcasting rules.)
- **Output Dask array:** ``block_info[None]`` refers to the output Dask array,
and contains information about the output chunks.
The output chunk shape and dtype may may be different than the input chunks.
For each dask array, ``block_info`` describes:
- ``shape``: the shape of the full Dask array,
- ``num-chunks``: the number of chunks of the full array in each dimension,
- ``chunk-location``: the chunk location (for example the fourth chunk over
in the first dimension), and
- ``array-location``: the array location within the full Dask array
(for example the slice corresponding to ``40:50``).
In addition to these, there are two extra parameters described by
``block_info`` for the output array (in ``block_info[None]``):
- ``chunk-shape``: the output chunk shape, and
- ``dtype``: the output dtype.
These features can be combined to synthesize an array from scratch, for
example:
>>> def func(block_info=None):
... loc = block_info[None]['array-location'][0]
... return np.arange(loc[0], loc[1])
>>> da.map_blocks(func, chunks=((4, 4),), dtype=np.float64)
dask.array<func, shape=(8,), dtype=float64, chunksize=(4,), chunktype=numpy.ndarray>
>>> _.compute()
array([0, 1, 2, 3, 4, 5, 6, 7])
``block_id`` is similar to ``block_info`` but contains only the ``chunk_location``:
>>> def func(block_id=None):
... pass
This will receive the following information:
>>> block_id # doctest: +SKIP
(4, 3)
You may specify the key name prefix of the resulting task in the graph with
the optional ``token`` keyword argument.
>>> x.map_blocks(lambda x: x + 1, name='increment')
dask.array<increment, shape=(1000,), dtype=int64, chunksize=(100,), chunktype=numpy.ndarray>
For functions that may not handle 0-d arrays, it's also possible to specify
``meta`` with an empty array matching the type of the expected result. In
the example below, ``func`` will result in an ``IndexError`` when computing
``meta``:
>>> rng = da.random.default_rng()
>>> da.map_blocks(lambda x: x[2], rng.random(5), meta=np.array(()))
dask.array<lambda, shape=(5,), dtype=float64, chunksize=(5,), chunktype=numpy.ndarray>
Similarly, it's possible to specify a non-NumPy array to ``meta``, and provide
a ``dtype``:
>>> import cupy # doctest: +SKIP
>>> rng = da.random.default_rng(cupy.random.default_rng()) # doctest: +SKIP
>>> dt = np.float32
>>> da.map_blocks(lambda x: x[2], rng.random(5, dtype=dt), meta=cupy.array((), dtype=dt)) # doctest: +SKIP
dask.array<lambda, shape=(5,), dtype=float32, chunksize=(5,), chunktype=cupy.ndarray>
"""
if drop_axis is None:
drop_axis = []
if not callable(func):
msg = (
"First argument must be callable function, not %s\n"
"Usage: da.map_blocks(function, x)\n"
" or: da.map_blocks(function, x, y, z)"
)
raise TypeError(msg % type(func).__name__)
if token:
warnings.warn(
"The `token=` keyword to `map_blocks` has been moved to `name=`. "
"Please use `name=` instead as the `token=` keyword will be removed "
"in a future release.",
category=FutureWarning,
)
name = token
name = f"{name or funcname(func)}-{tokenize(func, dtype, chunks, drop_axis, new_axis, *args, **kwargs)}"
new_axes = {}
if isinstance(drop_axis, Number):
drop_axis = [drop_axis]
if isinstance(new_axis, Number):
new_axis = [new_axis] # TODO: handle new_axis
arrs = [a for a in args if isinstance(a, Array)]
argpa…tack
"""
from dask.array import wrap
seq = [asarray(a, allow_unknown_chunksizes=allow_unknown_chunksizes) for a in seq]
if not seq:
raise ValueError("Need array(s) to concatenate")
if axis is None:
seq = [a.flatten() for a in seq]
axis = 0
seq_metas = [meta_from_array(s) for s in seq]
_concatenate = concatenate_lookup.dispatch(
type(max(seq_metas, key=lambda x: getattr(x, "__array_priority__", 0)))
)
meta = _concatenate(seq_metas, axis=axis)
# Promote types to match meta
seq = [a.astype(meta.dtype) for a in seq]
# Find output array shape
ndim = len(seq[0].shape)
shape = tuple(
sum(a.shape[i] for a in seq) if i == axis else seq[0].shape[i]
for i in range(ndim)
)
# Drop empty arrays
seq2 = [a for a in seq if a.size]
if not seq2:
seq2 = seq
if axis < 0:
axis = ndim + axis
if axis >= ndim:
msg = (
"Axis must be less than than number of dimensions"
"\nData has %d dimensions, but got axis=%d"
)
raise ValueError(msg % (ndim, axis))
n = len(seq2)
if n == 0:
try:
return wrap.empty_like(meta, shape=shape, chunks=shape, dtype=meta.dtype)
except TypeError:
return wrap.empty(shape, chunks=shape, dtype=meta.dtype)
elif n == 1:
return seq2[0]
if not allow_unknown_chunksizes and not all(
i == axis or all(x.shape[i] == seq2[0].shape[i] for x in seq2)
for i in range(ndim)
):
if any(map(np.isnan, seq2[0].shape)):
raise ValueError(
"Tried to concatenate arrays with unknown"
" shape %s.\n\nTwo solutions:\n"
" 1. Force concatenation pass"
" allow_unknown_chunksizes=True.\n"
" 2. Compute shapes with "
"[x.compute_chunk_sizes() for x in seq]" % str(seq2[0].shape)
)
raise ValueError("Shapes do not align: %s", [x.shape for x in seq2])
inds = [list(range(ndim)) for i in range(n)]
for i, ind in enumerate(inds):
ind[axis] = -(i + 1)
uc_args = list(concat(zip(seq2, inds)))
_, seq2 = unify_chunks(*uc_args, warn=False)
bds = [a.chunks for a in seq2]
chunks = (
seq2[0].chunks[:axis]
+ (sum((bd[axis] for bd in bds), ()),)
+ seq2[0].chunks[axis + 1 :]
)
cum_dims = [0] + list(accumulate(add, [len(a.chunks[axis]) for a in seq2]))
names = [a.name for a in seq2]
name = "concatenate-" + tokenize(names, axis)
keys = list(product([name], *[range(len(bd)) for bd in chunks]))
values = [
(names[bisect(cum_dims, key[axis + 1]) - 1],)
+ key[1 : axis + 1]
+ (key[axis + 1] - cum_dims[bisect(cum_dims, key[axis + 1]) - 1],)
+ key[axis + 2 :]
for key in keys
]
dsk = dict(zip(keys, values))
graph = HighLevelGraph.from_collections(name, dsk, dependencies=seq2)
return Array(graph, name, chunks, meta=meta)
def load_store_chunk(
    x: Any,
    out: Any,
    index: slice,
    lock: Any,
    return_stored: bool,
    load_stored: bool,
):
    """
    A function inserted in a Dask graph for storing a chunk.

    Parameters
    ----------
    x: array-like
        An array (potentially a NumPy one). Nothing is written when it is
        ``None`` or has ``size == 0``.
    out: array-like
        Where to store results.
    index: slice-like
        Where to store result from ``x`` in ``out``.
    lock: Lock-like or False
        Lock to use before writing to ``out``; released again on exit, even
        if an error occurs.
    return_stored: bool
        Whether to return ``out``.
    load_stored: bool
        Whether to return the array stored in ``out``.
        Ignored if ``return_stored`` is not ``True``.

    Returns
    -------
    If return_stored=True and load_stored=False
        out
    If return_stored=True and load_stored=True
        out[index]
    If return_stored=False
        None

    Examples
    --------
    >>> a = np.ones((5, 6))
    >>> b = np.empty(a.shape)
    >>> load_store_chunk(a, b, (slice(None), slice(None)), False, False, False)
    """
    if lock:
        lock.acquire()
    try:
        has_data = x is not None and x.size != 0
        if has_data:
            # Values that are not array-like are converted before assignment.
            out[index] = x if is_arraylike(x) else np.asanyarray(x)

        if not return_stored:
            return None
        return out[index] if load_stored else out
    finally:
        if lock:
            lock.release()
def store_chunk(
    x: ArrayLike, out: ArrayLike, index: slice, lock: Any, return_stored: bool
):
    """Store ``x`` into ``out[index]`` without loading the stored data back.

    Thin wrapper around ``load_store_chunk`` with ``load_stored`` fixed to
    ``False``; see that function for the meaning of the arguments and the
    return value.
    """
    return load_store_chunk(x, out, index, lock, return_stored, load_stored=False)
# Type variable tying the return type of ``load_chunk`` to the type of ``out``.
A = TypeVar("A", bound=ArrayLike)


def load_chunk(out: A, index: slice, lock: Any) -> A:
    """Read ``out[index]`` back, holding ``lock`` if one is given.

    Thin wrapper around ``load_store_chunk`` that stores nothing (``x`` is
    ``None``) and asks for the indexed data to be returned.
    """
    return load_store_chunk(
        None, out, index, lock, return_stored=True, load_stored=True
    )
def insert_to_ooc(
    keys: list,
    chunks: tuple[tuple[int, ...], ...],
    out: ArrayLike,
    name: str,
    *,
    lock: Lock | bool = True,
    region: tuple[slice, ...] | slice | None = None,
    return_stored: bool = False,
    load_stored: bool = False,
) -> dict:
    """
    Creates a Dask graph for storing chunks from ``arr`` in ``out``.

    Parameters
    ----------
    keys: list
        Dask keys of the input array
    chunks: tuple
        Dask chunks of the input array
    out: array-like
        Where to store results to
    name: str
        First element of dask keys
    lock: Lock-like or bool, optional
        Whether to lock or with what (default is ``True``,
        which means a :class:`threading.Lock` instance).
    region: slice-like, optional
        Where in ``out`` to store ``arr``'s results
        (default is ``None``, meaning all of ``out``).
    return_stored: bool, optional
        Whether to return ``out``
        (default is ``False``, meaning ``None`` is returned).
    load_stored: bool, optional
        Whether to handle loading from ``out`` at the same time.
        Ignored if ``return_stored`` is not ``True``.
        (default is ``False``, meaning defer to ``return_stored``).

    Returns
    -------
    dask graph of store operation

    Examples
    --------
    >>> import dask.array as da
    >>> d = da.ones((5, 6), chunks=(2, 3))
    >>> a = np.empty(d.shape)
    >>> insert_to_ooc(d.__dask_keys__(), d.chunks, a, "store-123")  # doctest: +SKIP
    """
    if lock is True:
        lock = Lock()

    slices = slices_from_chunks(chunks)
    if region:
        # Express each chunk's slice relative to the requested region of ``out``.
        slices = [fuse_slice(region, slc) for slc in slices]

    # ``load_store_chunk`` takes one more trailing argument than ``store_chunk``.
    if return_stored and load_stored:
        func, extra = load_store_chunk, (load_stored,)
    else:
        func, extra = store_chunk, ()

    dsk = {}
    for key, slc in zip(core.flatten(keys), slices):
        # Output keys reuse the input key's block indices under the new name.
        dsk[(name,) + key[1:]] = (func, key, out, slc, lock, return_stored) + extra
    return dsk
def retrieve_from_ooc(
    keys: Collection[Key], dsk_pre: Graph, dsk_post: Graph
) -> dict[tuple, Any]:
    """
    Creates a Dask graph for loading stored ``keys`` from ``dsk``.

    Each key ``(name, *idx)`` yields a ``("load-" + name, *idx)`` task that
    calls ``load_chunk`` on the value found in ``dsk_post``, followed by the
    elements of the matching ``dsk_pre`` task from position 3 up to, but not
    including, the last one.

    Parameters
    ----------
    keys: Collection
        A sequence containing Dask graph keys to load
    dsk_pre: Mapping
        A Dask graph corresponding to a Dask Array before computation
    dsk_post: Mapping
        A Dask graph corresponding to a Dask Array after computation

    Examples
    --------
    >>> import dask.array as da
    >>> d = da.ones((5, 6), chunks=(2, 3))
    >>> a = np.empty(d.shape)
    >>> g = insert_to_ooc(d.__dask_keys__(), d.chunks, a, "store-123")
    >>> retrieve_from_ooc(g.keys(), g, {k: k for k in g.keys()})  # doctest: +SKIP
    """
    load_dsk = {}
    for k in keys:
        load_key = ("load-" + k[0],) + k[1:]
        # For tasks built by ``insert_to_ooc`` with ``store_chunk``, the
        # slice [3:-1] is ``(index, lock)``.
        load_dsk[load_key] = (load_chunk, dsk_post[k]) + dsk_pre[k][3:-1]  # type: ignore
    return load_dsk
def _as_dtype(a, dtype):
    """Return ``a`` unchanged when ``dtype`` is None, else ``a.astype(dtype)``."""
    return a if dtype is None else a.astype(dtype)
def asarray(
    a, allow_unknown_chunksizes=False, dtype=None, order=None, *, like=None, **kwargs
):
    """Convert the input to a dask array.

    Parameters
    ----------
    a : array-like
        Input data, in any form that can be converted to a dask array. This
        includes lists, lists of tuples, tuples, tuples of tuples, tuples of
        lists and ndarrays.
    allow_unknown_chunksizes: bool
        Allow unknown chunksizes, such as come from converting from dask
        dataframes. Dask.array is unable to verify that chunks line up. If
        data comes from differently aligned sources then this can cause
        unexpected results.
    dtype : data-type, optional
        By default, the data-type is inferred from the input data. When
        given, the returned array is cast to it on every code path.
    order : {'C', 'F', 'A', 'K'}, optional
        Memory layout. 'A' and 'K' depend on the order of input array a.
        'C' row-major (C-style), 'F' column-major (Fortran-style) memory
        representation. 'A' (any) means 'F' if a is Fortran contiguous, 'C'
        otherwise 'K' (keep) preserve input order. Defaults to 'C'.
    like: array-like
        Reference object to allow the creation of Dask arrays with chunks
        that are not NumPy arrays. If an array-like passed in as ``like``
        supports the ``__array_function__`` protocol, the chunk type of the
        resulting array will be defined by it. In this case, it ensures the
        creation of a Dask array compatible with that passed in via this
        argument. If ``like`` is a Dask array, the chunk type of the
        resulting array will be defined by the chunk type of ``like``.
        Requires NumPy 1.20.0 or higher.
    **kwargs :
        Forwarded to ``from_array`` when the input has to be wrapped.

    Returns
    -------
    out : dask array
        Dask array interpretation of a.

    Examples
    --------
    >>> import dask.array as da
    >>> import numpy as np
    >>> x = np.arange(3)
    >>> da.asarray(x)
    dask.array<array, shape=(3,), dtype=int64, chunksize=(3,), chunktype=numpy.ndarray>

    >>> y = [[1, 2, 3], [4, 5, 6]]
    >>> da.asarray(y)
    dask.array<array, shape=(2, 3), dtype=int64, chunksize=(2, 3), chunktype=numpy.ndarray>

    .. warning::
        `order` is ignored if `a` is an `Array`, has the attribute ``to_dask_array``,
        or is a list or tuple of `Array`'s. It is also not applied when
        ``like`` is None and `a` already exposes an iterable ``shape``,
        because such input is wrapped without going through ``np.asarray``.
    """
    if like is None:
        if isinstance(a, Array):
            return _as_dtype(a, dtype)
        elif hasattr(a, "to_dask_array"):
            return _as_dtype(a.to_dask_array(), dtype)
        elif type(a).__module__.split(".")[0] == "xarray" and hasattr(a, "data"):
            return _as_dtype(asarray(a.data, order=order), dtype)
        elif isinstance(a, (list, tuple)) and any(isinstance(i, Array) for i in a):
            return _as_dtype(
                stack(a, allow_unknown_chunksizes=allow_unknown_chunksizes), dtype
            )
        elif not isinstance(getattr(a, "shape", None), Iterable):
            # Plain Python data (scalars, nested lists, ...): let NumPy build
            # the array, applying dtype and order directly.
            a = np.asarray(a, dtype=dtype, order=order)
    else:
        like_meta = meta_from_array(like)
        if isinstance(a, Array):
            return a.map_blocks(np.asarray, like=like_meta, dtype=dtype, order=order)
        else:
            a = np.asarray(a, like=like_meta, dtype=dtype, order=order)

    # Input that already had an iterable ``shape`` (e.g. a NumPy array) skips
    # the ``np.asarray`` conversion above, so without this final cast a
    # requested ``dtype`` would be silently ignored on that path. When the
    # data was already converted the cast targets the dtype it already has.
    return _as_dtype(from_array(a, getitem=getter_inline, **kwargs), dtype)
def asanyarray(a, dtype=None, order=None, *, like=None, inline_array=False):
    """Convert the input to a dask array.

    Subclasses of ``np.ndarray`` will be passed through as chunks unchanged.

    Parameters
    ----------
    a : array-like
        Input data, in any form that can be converted to a dask array. This
        includes lists, lists of tuples, tuples, tuples of tuples, tuples of
        lists and ndarrays.
    dtype : data-type, optional
        By default, the data-type is inferred from the input data.
    order : {'C', 'F', 'A', 'K'}, optional
        Memory layout. 'A' and 'K' depend on the order of input array a.
        'C' row-major (C-style), 'F' column-major (Fortran-style) memory
        representation. 'A' (any) means 'F' if a is Fortran contiguous, 'C'
        otherwise. 'K' (keep) preserves input order. The value (including
        the default ``None``) is forwarded unchanged to ``np.asanyarray``.
    like : array-like, optional
        Reference object to allow the creation of Dask arrays with chunks
        that are not NumPy arrays. If an array-like passed in as ``like``
        supports the ``__array_function__`` protocol, the chunk type of the
        resulting array will be defined by it. In this case, it ensures the
        creation of a Dask array compatible with that passed in via this
        argument. If ``like`` is a Dask array, the chunk type of the
        resulting array will be defined by the chunk type of ``like``.
        Requires NumPy 1.20.0 or higher.
    inline_array : bool, optional
        Whether to inline the array in the resulting dask graph. For more
        information, see the documentation for ``dask.array.from_array()``.

    Returns
    -------
    out : dask array
        Dask array interpretation of a.

    Examples
    --------
    >>> import dask.array as da
    >>> import numpy as np
    >>> x = np.arange(3)
    >>> da.asanyarray(x)
    dask.array<array, shape=(3,), dtype=int64, chunksize=(3,), chunktype=numpy.ndarray>

    >>> y = [[1, 2, 3], [4, 5, 6]]
    >>> da.asanyarray(y)
    dask.array<array, shape=(2, 3), dtype=int64, chunksize=(2, 3), chunktype=numpy.ndarray>

    .. warning::
        `order` is ignored if `a` is an `Array`, has the attribute
        ``to_dask_array``, or is a list or tuple of `Array`'s.
    """
    if like is not None:
        like_meta = meta_from_array(like)
        if isinstance(a, Array):
            # Convert chunk by chunk so no data is materialised here.
            return a.map_blocks(
                np.asanyarray, like=like_meta, dtype=dtype, order=order
            )
        a = np.asanyarray(a, like=like_meta, dtype=dtype, order=order)
    elif isinstance(a, Array):
        return _as_dtype(a, dtype)
    elif hasattr(a, "to_dask_array"):
        return _as_dtype(a.to_dask_array(), dtype)
    elif type(a).__module__.split(".")[0] == "xarray" and hasattr(a, "data"):
        # NOTE(review): this branch delegates to ``asarray`` rather than
        # ``asanyarray`` -- confirm that is intentional.
        return _as_dtype(asarray(a.data, order=order), dtype)
    elif isinstance(a, (list, tuple)) and any(isinstance(i, Array) for i in a):
        return _as_dtype(stack(a), dtype)
    elif not isinstance(getattr(a, "shape", None), Iterable):
        # Objects without an iterable ``shape`` are converted through NumPy;
        # anything that already exposes one is handed to ``from_array`` as is.
        a = np.asanyarray(a, dtype=dtype, order=order)

    return from_array(
        a,
        chunks=a.shape,
        getitem=getter_inline,
        asarray=False,
        inline_array=inline_array,
    )
def is_scalar_for_elemwise(arg):
    """Return True if ``arg`` should be treated as a scalar by ``elemwise``.

    >>> is_scalar_for_elemwise(42)
    True
    >>> is_scalar_for_elemwise('foo')
    True
    >>> is_scalar_for_elemwise(True)
    True
    >>> is_scalar_for_elemwise(np.array(42))
    True
    >>> is_scalar_for_elemwise([1, 2, 3])
    True
    >>> is_scalar_for_elemwise(np.array([1, 2, 3]))
    False
    >>> is_scalar_for_elemwise(from_array(np.array(0), chunks=()))
    False
    >>> is_scalar_for_elemwise(np.dtype('i4'))
    True
    """
    shape = getattr(arg, "shape", None)
    # An object counts as shapeless when it has no iterable ``shape``. The
    # dask-collection test on the shape entries is essentially there so that
    # dask series / frames are treated as scalars in elemwise.
    shapeless = not isinstance(shape, Iterable) or any(
        is_dask_collection(dim) for dim in shape
    )
    if np.isscalar(arg) or shapeless or isinstance(arg, np.dtype):
        return True
    # Zero-dimensional NumPy arrays behave like scalars as well.
    return isinstance(arg, np.ndarray) and arg.ndim == 0
def broadcast_shapes(*shapes):
    """Determine the output shape from broadcasting arrays.

    Parameters
    ----------
    shapes : tuples
        The shapes of the arguments.

    Returns
    -------
    output_shape : tuple

    Raises
    ------
    ValueError
        If the input shapes cannot be successfully broadcast together.
    """
    if len(shapes) == 1:
        return shapes[0]

    trailing_first = []
    # Walk the axes from the last one backwards; shapes that run out of
    # dimensions contribute the placeholder -1.
    for axis_sizes in zip_longest(*(reversed(s) for s in shapes), fillvalue=-1):
        if np.isnan(axis_sizes).any():
            dim = np.nan
        elif 0 in axis_sizes:
            dim = 0
        else:
            dim = np.max(axis_sizes).item()

        compatible = [-1, 0, 1, dim]
        for size in axis_sizes:
            if size not in compatible and not np.isnan(size):
                shapes_text = " ".join(map(str, shapes))
                raise ValueError(
                    f"operands could not be broadcast together with shapes {shapes_text}"
                )
        trailing_first.append(dim)

    return tuple(reversed(trailing_first))
def elemwise(op, *args, out=None, where=True, dtype=None, name=None, **kwargs):
    """Apply an elementwise ufunc-like function blockwise across arguments.

    Like numpy ufuncs, broadcasting rules are respected.

    Parameters
    ----------
    op : callable
        The function to apply. Should be numpy ufunc-like in the parameters
        that it accepts.
    *args : Any
        Arguments to pass to `op`. Non-dask array-like objects are first
        converted to dask arrays, then all arrays are broadcast together before
        applying the function blockwise across all arguments. Any scalar
        arguments are passed as-is following normal numpy ufunc behavior.
    out : dask array, optional
        If out is a dask.array then this overwrites the contents of that array
        with the result.
    where : array_like, optional
        An optional boolean mask marking locations where the ufunc should be
        applied. Can be a scalar, dask array, or any other array-like object.
        Mirrors the ``where`` argument to numpy ufuncs, see e.g. ``numpy.add``
        for more information.
    dtype : dtype, optional
        If provided, overrides the output array dtype.
    name : str, optional
        A unique key name to use when building the backing dask graph. If not
        provided, one will be automatically generated based on the input
        arguments.

    Returns
    -------
    The result of ``handle_out(out, result)``, or ``NotImplemented`` when the
    output dtype cannot be inferred.

    Raises
    ------
    TypeError
        If any extra keyword arguments are supplied.

    Examples
    --------
    >>> elemwise(add, x, y)  # doctest: +SKIP
    >>> elemwise(sin, x)  # doctest: +SKIP
    >>> elemwise(sin, x, out=dask_array)  # doctest: +SKIP

    See Also
    --------
    blockwise
    """
    if kwargs:
        raise TypeError(
            f"{op.__name__} does not take the following keyword arguments "
            f"{sorted(kwargs)}"
        )

    out = _elemwise_normalize_out(out)
    where = _elemwise_normalize_where(where)

    args = [np.asarray(a) if isinstance(a, (list, tuple)) else a for a in args]

    def _shape_of(obj):
        shape = getattr(obj, "shape", ())
        # Want to exclude Delayed shapes and dd.Scalar
        if any(is_dask_collection(dim) for dim in shape):
            return ()
        return shape

    shapes = [_shape_of(arg) for arg in args]
    if isinstance(where, Array):
        shapes.append(where.shape)
    if isinstance(out, Array):
        shapes.append(out.shape)
    shapes = [s if isinstance(s, Iterable) else () for s in shapes]

    # Raises ValueError if dimensions mismatch
    out_ndim = len(broadcast_shapes(*shapes))
    expr_inds = tuple(reversed(range(out_ndim)))

    need_enforce_dtype = dtype is not None
    if dtype is None:
        # We follow NumPy's rules for dtype promotion, which special cases
        # scalars and 0d ndarrays (which it considers equivalent) by using
        # their values to compute the result dtype:
        # https://github.com/numpy/numpy/issues/6240
        # We don't inspect the values of 0d dask arrays, because these could
        # hold potentially very expensive calculations. Instead, we treat
        # them just like other arrays, and if necessary cast the result of op
        # to match.
        vals = [
            a
            if is_scalar_for_elemwise(a)
            else np.empty((1,) * max(1, a.ndim), dtype=a.dtype)
            for a in args
        ]
        try:
            dtype = apply_infer_dtype(op, vals, {}, "elemwise", suggest_dtype=False)
        except Exception:
            return NotImplemented
        need_enforce_dtype = any(
            not is_scalar_for_elemwise(a) and a.ndim == 0 for a in args
        )

    if not name:
        name = f"{funcname(op)}-{tokenize(op, dtype, *args, where)}"

    blockwise_kwargs = {
        "dtype": dtype,
        "name": name,
        "token": funcname(op).strip("_"),
    }

    if where is not True:
        # Route the call through a wrapper that forwards ``where`` and ``out``
        # to the real function; both travel as trailing positional arguments.
        blockwise_kwargs["elemwise_where_function"] = op
        op = _elemwise_handle_where
        args.extend([where, out])

    if need_enforce_dtype:
        # The dtype wrapper goes on the outside so it sees the final result.
        blockwise_kwargs["enforce_dtype"] = dtype
        blockwise_kwargs["enforce_dtype_function"] = op
        op = _enforce_dtype

    # Interleave each argument with its index tuple (``None`` for scalars).
    arg_index_pairs = []
    for a in args:
        index = None if is_scalar_for_elemwise(a) else tuple(range(a.ndim)[::-1])
        arg_index_pairs.extend((a, index))

    result = blockwise(op, expr_inds, *arg_index_pairs, **blockwise_kwargs)
    return handle_out(out, result)
def _elemwise_normalize_where(where):
if where is True:
return True
elif where is False or where is None:
return False
return asarray(where)
def _elemwise_handle_where(*args, **kwargs):
function = kwargs.pop("elemwise_where_function")
*args, where, out = args
if hasattr(out, "copy"):
out = out.copy()
return function(*args, where=where, out=out, **kwargs)
def _elemwise_normalize_out(out):
if isinstance(out, tuple):
if len(out) == 1:
out = out[0]
elif len(out) > 1:
raise NotImplementedError("The out parameter is not fully supported")
else:
out = None
if not (out is None or isinstance(out, Array)):
raise NotImplementedError(
f"The out parameter is not fully supported."
f" Received type {type(out).__name__}, expected Dask Array"
)
return out
def handle_out(out, result):
    """Handle out parameters.

    If out is a dask.array then this overwrites the contents of that array
    with the result (its chunks, graph, meta and name) and returns ``out``;
    otherwise ``result`` is returned unchanged.

    Raises
    ------
    ValueError
        If ``out`` is an Array whose shape differs from ``result``'s.
    """
    out = _elemwise_normalize_out(out)
    if not isinstance(out, Array):
        return result

    if out.shape != result.shape:
        raise ValueError(
            "Mismatched shapes between result and out parameter. "
            f"out={out.shape}, result={result.shape}"
        )
    # Re-point ``out`` at the freshly computed collection in place.
    out._chunks = result.chunks
    out.dask = result.dask
    out._meta = result._meta
    out._name = result.name
    return out
def _enforce_dtype(*args, **kwargs):
"""Calls a function and converts its result to the given dtype.
The parameters have deliberately been given unwieldy names to avoid
clashes with keyword arguments consumed by blockwise
A dtype of `object` is treated as a special case and not enforced,
because it is used as a dummy value in some places when the result will
not be a block in an Array.
Parameters
----------
enforce_dtype : dtype
Result dtype
enforce_dtype_function : callable
The wrapped function, which will be passed the remaining arguments
"""
dtype = kwargs.pop("enforce_dtype")
function = kwargs.pop("enforce_dtype_function")
result = function(*args, **kwargs)
if hasattr(result, "dtype") and dtype != result.dtype and dtype != object:
if not np.can_cast(result, dtype, casting="same_kind"):
raise ValueError(
"Inferred dtype from function %r was %r "
"but got %r, which can't be cast using "
"casting='same_kind'"
% (funcname(function), str(dtype), str(result.dtype))
)
if np.isscalar(result):
# scalar astype method doesn't take the keyword arguments, so
# have to convert via 0-dimensional array and back.
result = result.astype(dtype)
else:
try:
result = result.astype(dtype, copy=False)
except TypeError:
# Missing copy kwarg
result = result.astype(dtype)
return result
def broadcast_to(x, shape, chunks=None, meta=None):
    """Broadcast an array to a new shape.

    Parameters
    ----------
    x : array_like
        The array to broadcast.
    shape : tuple
        The shape of the desired array.
    chunks : tuple, optional
        If provided, then the result will use these chunks instead of the same
        chunks as the source array. Setting chunks explicitly as part of
        broadcast_to is more efficient than rechunking afterwards. Chunks are
        only allowed to differ from the original shape along dimensions that
        are new on the result or have size 1 the input array.
    meta : empty ndarray
        empty ndarray created with same NumPy backend, ndim and dtype as the
        Dask Array being created (overrides dtype)

    Returns
    -------
    broadcast : dask array

    Raises
    ------
    ValueError
        If the shape or the requested chunks are incompatible with ``x``.

    See Also
    --------
    :func:`numpy.broadcast_to`
    """
    x = asarray(x)
    shape = tuple(shape)
    if meta is None:
        meta = meta_from_array(x)

    # Nothing to do when neither the shape nor the chunking changes.
    if x.shape == shape and (chunks is None or chunks == x.chunks):
        return x

    ndim_new = len(shape) - x.ndim
    if ndim_new < 0 or any(
        new != old for new, old in zip(shape[ndim_new:], x.shape) if old != 1
    ):
        raise ValueError(f"cannot broadcast shape {x.shape} to shape {shape}")

    if chunks is None:
        # New leading axes get a single chunk; existing axes keep their
        # chunking unless they are stretched from length <= 1.
        leading = tuple((size,) for size in shape[:ndim_new])
        existing = tuple(
            bd if old > 1 else (new,)
            for bd, old, new in zip(x.chunks, x.shape, shape[ndim_new:])
        )
        chunks = leading + existing
    else:
        chunks = normalize_chunks(
            chunks, shape, dtype=x.dtype, previous_chunks=x.chunks
        )
        for old_bd, new_bd in zip(x.chunks, chunks[ndim_new:]):
            if old_bd != new_bd and old_bd != (1,):
                raise ValueError(
                    f"cannot broadcast chunks {x.chunks} to chunks {chunks}: "
                    "new chunks must either be along a new "
                    "dimension or a dimension of size 1"
                )

    name = "broadcast_to-" + tokenize(x, shape, chunks)

    dsk = {}
    for block in product(*(enumerate(bds) for bds in chunks)):
        # ``block`` pairs a chunk index with a chunk size for every axis.
        new_index, chunk_shape = zip(*block)
        old_index = tuple(
            0 if bd == (1,) else i for bd, i in zip(x.chunks, new_index[ndim_new:])
        )
        dsk[(name,) + new_index] = (
            np.broadcast_to,
            (x.name,) + old_index,
            quote(chunk_shape),
        )

    graph = HighLevelGraph.from_collections(name, dsk, dependencies=[x])
    return Array(graph, name, chunks, dtype=x.dtype, meta=meta)
@derived_from(np)
def broadcast_arrays(*args, subok=False):
    # ``subok`` selects the converter: asanyarray when true, asarray otherwise.
    convert = asanyarray if bool(subok) else asarray
    arrays = tuple(convert(arg) for arg in args)

    # Unify uneven chunking: pass each array together with its reversed axes.
    unify_inputs = []
    for arr in arrays:
        unify_inputs.extend((arr, list(reversed(range(arr.ndim)))))
    _, arrays = unify_chunks(*unify_inputs, warn=False)

    shape = broadcast_shapes(*(arr.shape for arr in arrays))
    chunks = broadcast_chunks(*(arr.chunks for arr in arrays))

    broadcasted = (broadcast_to(arr, shape=shape, chunks=chunks) for arr in arrays)
    # The container type depends on the NUMPY_GE_200 flag: tuple when set,
    # list otherwise.
    return tuple(broadcasted) if NUMPY_GE_200 else list(broadcasted)
def offset_func(func, offset, *args):
    """Offsets inputs by offset

    Returns a wrapper that adds ``offset`` element-wise to its positional
    arguments before calling ``func``. The extra ``*args`` accepted here are
    not used.

    >>> double = lambda x: x * 2
    >>> f = offset_func(double, (10,))
    >>> f(1)
    22
    >>> f(300)
    620
    """

    def _offset(*args):
        shifted = [add(value, delta) for value, delta in zip(args, offset)]
        return func(*shifted)

    # Best effort only: not every callable exposes ``__name__``.
    with contextlib.suppress(Exception):
        _offset.__name__ = "offset_" + func.__name__

    return _offset
def chunks_from_arrays(arrays):
"""Chunks tuple from nested list of arrays
>>> x = np.array([1, 2])
>>> chunks_from_arrays([x, x])
((2, 2),)
>>> x = np.array([[1, 2]])
>>> chunks_from_arrays([[x], [x]])
((1, 1), (2,))
>>> x = np.array([[1, 2]])
>>> chunks_from_arrays([[x, x]])
((1,), (2, 2))
>>> chunks_from_arrays([1, 1])
((1, 1),)
"""
if not arrays:
return ()
result = []
dim = 0
def shape(x):
try:
return x.shape if x.shape else (1,)
except AttributeError:
return (1,)
while isinstance(arrays, (list, tuple)):
> result.append(tuple(shape(deepfirst(a))[dim] for a in arrays))
E IndexError: tuple index out of range
../../../miniconda3/envs/dask-distributed/lib/python3.10/site-packages/dask/array/core.py:5281: IndexError
Check warning on line 0 in distributed.tests.test_stress
github-actions / Unit Test Results
11 out of 12 runs failed: test_stress_creation_and_deletion (distributed.tests.test_stress)
artifacts/macos-latest-3.12-default-ci1/pytest.xml [took 7s]
artifacts/ubuntu-latest-3.10-default-ci1/pytest.xml [took 7s]
artifacts/ubuntu-latest-3.10-no_expr-ci1/pytest.xml [took 7s]
artifacts/ubuntu-latest-3.10-no_queue-ci1/pytest.xml [took 8s]
artifacts/ubuntu-latest-3.11-default-ci1/pytest.xml [took 7s]
artifacts/ubuntu-latest-3.12-default-ci1/pytest.xml [took 7s]
artifacts/ubuntu-latest-mindeps-numpy-ci1/pytest.xml [took 7s]
artifacts/ubuntu-latest-mindeps-pandas-ci1/pytest.xml [took 12s]
artifacts/windows-latest-3.10-default-ci1/pytest.xml [took 7s]
artifacts/windows-latest-3.11-default-ci1/pytest.xml [took 8s]
artifacts/windows-latest-3.12-default-ci1/pytest.xml [took 8s]
Raw output
AssertionError: assert 'round-bb724227e9bb17988ea1ad61e539e2b8' == 8000884.93
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:33099', workers: 0, cores: 0, tasks: 0>
@pytest.mark.slow
@gen_cluster(
nthreads=[],
client=True,
scheduler_kwargs={"allowed_failures": 100_000},
)
async def test_stress_creation_and_deletion(c, s):
# Assertions are handled by the validate mechanism in the scheduler
pytest.importorskip("numpy")
da = pytest.importorskip("dask.array")
rng = da.random.RandomState(0)
x = rng.random(size=(2000, 2000), chunks=(100, 100))
y = ((x + 1).T + (x * 2) - x.mean(axis=1)).sum().round(2)
z = c.persist(y)
async def create_and_destroy_worker(delay):
start = time()
while time() < start + 5:
async with Worker(s.address, nthreads=2) as n:
await asyncio.sleep(delay)
await asyncio.gather(*(create_and_destroy_worker(0.1 * i) for i in range(20)))
async with Worker(s.address, nthreads=2):
> assert await c.compute(z) == 8000884.93
E AssertionError: assert 'round-bb724227e9bb17988ea1ad61e539e2b8' == 8000884.93
distributed/tests/test_stress.py:125: AssertionError
Check warning on line 0 in distributed.cli.tests.test_dask_spec
github-actions / Unit Test Results
1 out of 12 runs failed: test_errors (distributed.cli.tests.test_dask_spec)
artifacts/ubuntu-latest-mindeps-numpy-notci1/pytest.xml [took 10s]
Raw output
subprocess.TimeoutExpired: Command '['/home/runner/miniconda3/envs/dask-distributed/bin/dask', 'spec', '--spec', '{"foo": "bar"}', '--spec-file', 'foo.yaml']' timed out after 10 seconds
def test_errors():
> with popen(
[
"dask",
"spec",
"--spec",
'{"foo": "bar"}',
"--spec-file",
"foo.yaml",
],
capture_output=True,
) as proc:
distributed/cli/tests/test_dask_spec.py:73:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../../../miniconda3/envs/dask-distributed/lib/python3.10/contextlib.py:142: in __exit__
next(self.gen)
distributed/utils_test.py:1204: in popen
_terminate_process(proc, terminate_timeout)
distributed/utils_test.py:1130: in _terminate_process
proc.communicate(timeout=terminate_timeout)
../../../miniconda3/envs/dask-distributed/lib/python3.10/subprocess.py:1154: in communicate
stdout, stderr = self._communicate(input, endtime, timeout)
../../../miniconda3/envs/dask-distributed/lib/python3.10/subprocess.py:2022: in _communicate
self._check_timeout(endtime, orig_timeout, stdout, stderr)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <Popen: returncode: -9 args: ['/home/runner/miniconda3/envs/dask-distributed...>
endtime = 338.631460995, orig_timeout = 10
stdout_seq = [b'Exception ignored in atexit callback: <function _close_global_client at 0x7f23fc83a4d0>\nTraceback (most recent cal...ackages/coverage/collector.py", line 252, in lock_data\n', b' self.data_lock.acquire()\nKeyboardInterrupt: \n', b'']
stderr_seq = None, skip_check_and_raise = False
def _check_timeout(self, endtime, orig_timeout, stdout_seq, stderr_seq,
skip_check_and_raise=False):
"""Convenience for checking if a timeout has expired."""
if endtime is None:
return
if skip_check_and_raise or _time() > endtime:
> raise TimeoutExpired(
self.args, orig_timeout,
output=b''.join(stdout_seq) if stdout_seq else None,
stderr=b''.join(stderr_seq) if stderr_seq else None)
E subprocess.TimeoutExpired: Command '['/home/runner/miniconda3/envs/dask-distributed/bin/dask', 'spec', '--spec', '{"foo": "bar"}', '--spec-file', 'foo.yaml']' timed out after 10 seconds
../../../miniconda3/envs/dask-distributed/lib/python3.10/subprocess.py:1198: TimeoutExpired
Check warning on line 0 in distributed.cli.tests.test_dask_worker
github-actions / Unit Test Results
1 out of 12 runs failed: test_nanny_worker_port_range_too_many_workers_raises (distributed.cli.tests.test_dask_worker)
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 10s]
Raw output
subprocess.TimeoutExpired: Command '['/home/runner/miniconda3/envs/dask-distributed/bin/dask', 'worker', 'tcp://127.0.0.1:34913', '--nworkers', '3', '--host', '127.0.0.1', '--worker-port', '9684:9685', '--nanny-port', '9686:9687', '--no-dashboard']' timed out after 10 seconds
s = <Scheduler 'tcp://127.0.0.1:34913', workers: 0, cores: 0, tasks: 0>
@gen_cluster(nthreads=[])
async def test_nanny_worker_port_range_too_many_workers_raises(s):
> with popen(
[
"dask",
"worker",
s.address,
"--nworkers",
"3",
"--host",
"127.0.0.1",
"--worker-port",
"9684:9685",
"--nanny-port",
"9686:9687",
"--no-dashboard",
],
capture_output=True,
) as worker:
distributed/cli/tests/test_dask_worker.py:244:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../../../miniconda3/envs/dask-distributed/lib/python3.11/contextlib.py:144: in __exit__
next(self.gen)
distributed/utils_test.py:1204: in popen
_terminate_process(proc, terminate_timeout)
distributed/utils_test.py:1130: in _terminate_process
proc.communicate(timeout=terminate_timeout)
../../../miniconda3/envs/dask-distributed/lib/python3.11/subprocess.py:1209: in communicate
stdout, stderr = self._communicate(input, endtime, timeout)
../../../miniconda3/envs/dask-distributed/lib/python3.11/subprocess.py:2116: in _communicate
self._check_timeout(endtime, orig_timeout, stdout, stderr)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <Popen: returncode: -9 args: ['/home/runner/miniconda3/envs/dask-distributed...>
endtime = 533.565681173, orig_timeout = 10
stdout_seq = [b'Exception ignored in atexit callback: <function close_clusters at 0x7f704b422200>\nTraceback (most recent call last...verage/collector.py", line 254, in unlock_data\n', b' def unlock_data(self) -> None:\n\nKeyboardInterrupt: \n', b'']
stderr_seq = None, skip_check_and_raise = False
def _check_timeout(self, endtime, orig_timeout, stdout_seq, stderr_seq,
skip_check_and_raise=False):
"""Convenience for checking if a timeout has expired."""
if endtime is None:
return
if skip_check_and_raise or _time() > endtime:
> raise TimeoutExpired(
self.args, orig_timeout,
output=b''.join(stdout_seq) if stdout_seq else None,
stderr=b''.join(stderr_seq) if stderr_seq else None)
E subprocess.TimeoutExpired: Command '['/home/runner/miniconda3/envs/dask-distributed/bin/dask', 'worker', 'tcp://127.0.0.1:34913', '--nworkers', '3', '--host', '127.0.0.1', '--worker-port', '9684:9685', '--nanny-port', '9686:9687', '--no-dashboard']' timed out after 10 seconds
../../../miniconda3/envs/dask-distributed/lib/python3.11/subprocess.py:1253: TimeoutExpired
Check warning on line 0 in distributed.comm.tests.test_comms
github-actions / Unit Test Results
3 out of 12 runs failed: test_tls_comm_closed_implicit[tornado] (distributed.comm.tests.test_comms)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
ssl.SSLError: [SYS] unknown error (_ssl.c:2580)
tcp = <module 'distributed.comm.tcp' from '/home/runner/work/distributed/distributed/distributed/comm/tcp.py'>
@gen_test()
async def test_tls_comm_closed_implicit(tcp):
> await check_comm_closed_implicit("tls://127.0.0.1", **tls_kwargs)
distributed/comm/tests/test_comms.py:777:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/comm/tests/test_comms.py:763: in check_comm_closed_implicit
await comm.read()
distributed/comm/tcp.py:225: in read
frames_nosplit_nbytes_bin = await stream.read_bytes(fmt_size)
../../../miniconda3/envs/dask-distributed/lib/python3.11/site-packages/tornado/iostream.py:422: in read_bytes
self._try_inline_read()
../../../miniconda3/envs/dask-distributed/lib/python3.11/site-packages/tornado/iostream.py:836: in _try_inline_read
pos = self._read_to_buffer_loop()
../../../miniconda3/envs/dask-distributed/lib/python3.11/site-packages/tornado/iostream.py:750: in _read_to_buffer_loop
if self._read_to_buffer() == 0:
../../../miniconda3/envs/dask-distributed/lib/python3.11/site-packages/tornado/iostream.py:861: in _read_to_buffer
bytes_read = self.read_from_fd(buf)
../../../miniconda3/envs/dask-distributed/lib/python3.11/site-packages/tornado/iostream.py:1552: in read_from_fd
return self.socket.recv_into(buf, len(buf))
../../../miniconda3/envs/dask-distributed/lib/python3.11/ssl.py:1314: in recv_into
return self.read(nbytes, buffer)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <ssl.SSLSocket [closed] fd=-1, family=2, type=1, proto=0>, len = 65536
buffer = bytearray(b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x...0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')
def read(self, len=1024, buffer=None):
"""Read up to LEN bytes and return them.
Return zero-length string on EOF."""
self._checkClosed()
if self._sslobj is None:
raise ValueError("Read on closed or unwrapped SSL socket.")
try:
if buffer is not None:
> return self._sslobj.read(len, buffer)
E ssl.SSLError: [SYS] unknown error (_ssl.c:2580)
../../../miniconda3/envs/dask-distributed/lib/python3.11/ssl.py:1166: SSLError
Check warning on line 0 in distributed.shuffle.tests.test_rechunk
github-actions / Unit Test Results
All 11 runs failed: test_rechunk_configuration[p2p-tasks] (distributed.shuffle.tests.test_rechunk)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 5eac77ca35612a3e8c9f0e275e3e0020 failed during transfer phase
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
    """Register a ``ShuffleWorkerPlugin`` under the name ``"shuffle"``.

    The registration goes through ``self.scheduler``; the ``scheduler``
    argument is not used.
    """
    pickled_plugin = dumps(ShuffleWorkerPlugin())
    await self.scheduler.register_worker_plugin(
        None, pickled_plugin, name="shuffle", idempotent=False
    )
def shuffle_ids(self) -> set[ShuffleId]:
    """Return a new set holding the ids of all active shuffles."""
    return {*self.active_shuffles}
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
    """Broadcast ``shuffle_inputs_done`` to all participating workers.

    Parameters
    ----------
    id : ShuffleId
        Id of the active shuffle.
    run_id : int
        Run the caller believes to be active; must match the stored run.
    consistent : bool
        When false, the shuffle is restarted instead of broadcasting.

    Raises
    ------
    KeyError
        If ``id`` is not an active shuffle.
    ValueError
        If ``run_id`` does not match the active run.
    P2PIllegalStateError
        If a failing worker is unknown to the scheduler while the shuffle is
        not archived.
    RuntimeError
        If a worker answers with an unexpected error, a worker left, or the
        broadcast makes no progress for three consecutive retries.
    """
    shuffle = self.active_shuffles[id]
    if shuffle.run_id != run_id:
        raise ValueError(f"{run_id=} does not match {shuffle}")
    if not consistent:
        logger.warning(
            "Shuffle %s restarted due to data inconsistency during barrier",
            shuffle.id,
        )
        return self._restart_shuffle(
            shuffle.id,
            self.scheduler,
            stimulus_id=f"p2p-barrier-inconsistent-{time()}",
        )
    msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
    workers = list(shuffle.participating_workers)
    no_progress = 0
    while workers:
        res = await self.scheduler.broadcast(
            msg=msg,
            workers=workers,
            on_error="return",
        )
        before = len(workers)
        # Keep only workers whose broadcast failed with an OSError; a ``None``
        # response means success and anything else is fatal. The original
        # rebuilt this list a second time right after the loop; the loop
        # already yields exactly that list, so the recomputation is dropped.
        workers = []
        for w, r in res.items():
            if r is None:
                continue
            if isinstance(r, OSError):
                workers.append(w)
            else:
                raise RuntimeError(
                    f"Unexpected error encountered during P2P barrier: {r!r}"
                )
        if workers:
            logger.warning(
                "Failure during broadcast of %s, retrying.",
                shuffle.id,
            )
            if any(w not in self.scheduler.workers for w in workers):
                if not shuffle.archived:
                    # If the shuffle is not yet archived, this could mean that
                    # the barrier task fails before the P2P restarting
                    # mechanism can kick in.
                    raise P2PIllegalStateError(
                        "Expected shuffle to be archived if participating worker is not known by scheduler"
                    )
                raise RuntimeError(
                    f"Worker {workers} left during shuffle {shuffle}"
                )
            await asyncio.sleep(0.1)
            if len(workers) == before:
                no_progress += 1
                if no_progress >= 3:
                    raise RuntimeError(
                        f"Broadcast not making progress for {shuffle}.\n"
                        "Aborting. This is possibly due to overloaded\n"
                        "workers. Increasing config\n"
                        "`distributed.comm.timeouts.connect` timeout may\n"
                        "help."
                    )
def restrict_task(
    self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
    """Restrict task ``key`` to ``worker`` for the active run of a shuffle.

    Returns ``{"status": "OK"}`` on success. When ``run_id`` does not match
    the active run, the resulting ``P2PConsistencyError`` is converted with
    ``error_message`` and returned instead of being raised.
    """
    try:
        shuffle = self.active_shuffles[id]
        active_run = shuffle.run_id
        if active_run > run_id:
            # The active run is newer than the one in the request.
            raise P2PConsistencyError(
                f"Request stale, expected {run_id=} for {shuffle}"
            )
        if active_run < run_id:
            # The request refers to a run newer than the active one.
            raise P2PConsistencyError(
                f"Request invalid, expected {run_id=} for {shuffle}"
            )
        task_state = self.scheduler.tasks[key]
        self._set_restriction(task_state, worker)
        return {"status": "OK"}
    except P2PConsistencyError as e:
        return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
    """Merge heartbeat payloads reported by a worker.

    Parameters
    ----------
    ws : WorkerState
        The reporting worker; its ``address`` keys the stored payload.
    data : dict
        Mapping of shuffle id to a dict of values to merge. Entries whose id
        is not returned by ``shuffle_ids()`` are ignored.
    """
    # ``shuffle_ids()`` builds a new set on every call, so take one snapshot
    # instead of rebuilding it for each reported shuffle.
    active_ids = self.shuffle_ids()
    for shuffle_id, d in data.items():
        if shuffle_id in active_ids:
            self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
    """Return the run spec of an active shuffle wrapped in a status message.

    A ``KeyError`` raised while building the reply is turned into a
    ``P2PConsistencyError``; any ``P2PConsistencyError`` is converted with
    ``error_message`` and returned instead of being raised.
    """
    try:
        try:
            spec = self._get(id, worker)
            reply = {"status": "OK", "run_spec": ToPickle(spec)}
            return reply
        except KeyError as e:
            not_found = f"No active shuffle with {id=!r} found"
            raise P2PConsistencyError(not_found) from e
    except P2PConsistencyError as e:
        return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
    """Return the ``spec`` stored on the run spec of the shuffle's barrier task.

    The barrier task is looked up in ``self.scheduler.tasks`` under
    ``barrier_key(shuffle_id)``; its run spec is asserted to be a
    ``P2PBarrierTask``.
    """
    barrier_task = self.scheduler.tasks[barrier_key(shuffle_id)]
    barrier_task_spec = barrier_task.run_spec
    assert isinstance(barrier_task_spec, P2PBarrierTask)
    return barrier_task_spec.spec
def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
    """Create a new run for ``shuffle_id``, mark it active and return its run spec.

    ``worker`` is recorded as the first participating worker of the new run.
    """
    # FIXME: The current implementation relies on the barrier task to be
    # known by its name. If the name has been mangled, we cannot guarantee
    # that the shuffle works as intended and should fail instead.
    self._raise_if_barrier_unknown(shuffle_id)
    self._raise_if_task_not_processing(key)

    spec = self._retrieve_spec(shuffle_id)
    worker_for = self._calculate_worker_for(spec)
    self._ensure_output_tasks_are_non_rootish(spec)

    span_id = self.scheduler.tasks[key].group.span_id
    new_run = spec.create_new_run(worker_for=worker_for, span_id=span_id)

    # Make the run the active one and record it under the shuffle id.
    self.active_shuffles[shuffle_id] = new_run
    self._shuffles[shuffle_id].add(new_run)
    new_run.participating_workers.add(worker)

    logger.warning(
        "Shuffle %s initialized by task %r executed on worker %s",
        shuffle_id,
        key,
        worker,
    )
    return new_run.run_spec
def get_or_create(
self,
shuffle_id: ShuffleId,
key: Key,
worker: str,
) -> RunSpecMessage | ErrorMessage:
try:
> run_spec = self._get(shuffle_id, worker)
distributed\shuffle\_scheduler_plugin.py:229:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
> state = self.active_shuffles[id]
E KeyError: '5eac77ca35612a3e8c9f0e275e3e0020'
distributed\shuffle\_scheduler_plugin.py:190: KeyError
During handling of the above exception, another exception occurred:
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and … await self._refresh(
distributed\shuffle\_worker_plugin.py:222: in _refresh
result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed\shuffle\_worker_plugin.py:190: in _fetch
response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed\core.py:1259: in send_recv_from_rpc
return await send_recv(comm=comm, op=key, **kwargs)
distributed\core.py:1043: in send_recv
raise exc.with_traceback(tb)
distributed\core.py:832: in _handle_comm
result = handler(**msg)
distributed\shuffle\_scheduler_plugin.py:234: in get_or_create
run_spec = self._create(shuffle_id, key, worker)
distributed\shuffle\_scheduler_plugin.py:205: in _create
spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
> assert isinstance(barrier_task_spec, P2PBarrierTask)
E AssertionError
distributed\shuffle\_scheduler_plugin.py:196: AssertionError
The above exception was the direct cause of the following exception:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:64012', workers: 0, cores: 0, tasks: 0>
config_value = 'tasks', keyword = 'p2p'
ws = (<Worker 'tcp://127.0.0.1:64013', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:64016', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
a = array([[0.25269457, 0.21923049, 0.09456582, 0.41318158, 0.26615119,
0.53192806, 0.44881814, 0.80442171, 0.7545..., 0.03274378, 0.17641201, 0.04475437, 0.56749567,
0.69504577, 0.59694585, 0.23471839, 0.63257509, 0.62788706]])
@pytest.mark.parametrize("config_value", ["tasks", "p2p", None])
@pytest.mark.parametrize("keyword", ["tasks", "p2p", None])
@gen_cluster(client=True)
async def test_rechunk_configuration(c, s, *ws, config_value, keyword):
"""Try rechunking a random 1d matrix
See Also
--------
dask.array.tests.test_rechunk.test_rechunk_1d
"""
a = np.random.default_rng().uniform(0, 1, 100).reshape((10, 10))
x = da.from_array(a, chunks=(10, 1))
new = ((1,) * 10, (10,))
config = {"array.rechunk.method": config_value} if config_value is not None else {}
with dask.config.set(config):
x2 = rechunk(x, chunks=new, method=keyword)
expected_algorithm = keyword if keyword is not None else config_value
if expected_algorithm == "p2p":
assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
elif expected_algorithm == "tasks":
assert not any(
key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__()
)
# Neither is specified, so we choose the best one (see test_rechunk_heuristic for a full test of the heuristic)
else:
assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
assert x2.chunks == new
> assert np.all(await c.compute(x2) == a)
distributed\shuffle\tests\test_rechunk.py:210:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed\client.py:410: in _result
raise exc.with_traceback(tb)
distributed\shuffle\_rechunk.py:169: in rechunk_transfer
with handle_transfer_errors(id):
C:\Users\runneradmin\miniconda3\envs\dask-distributed\Lib\contextlib.py:158: in __exit__
self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
    self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
    """Write ``data`` to the comm buffer, provided the run is still open."""
    self.raise_if_closed()
    comm_buffer = self._comm_buffer
    await comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
    """Write ``data`` to the disk buffer under string keys.

    Each index key is turned into a string by joining ``str()`` of its
    elements with ``"_"``, the same scheme ``_read_from_disk`` uses.
    """
    self.raise_if_closed()
    keyed: dict[str, Any] = {}
    for index, shards in data.items():
        keyed["_".join(map(str, index))] = shards
    await self._disk_buffer.write(keyed)
def raise_if_closed(self) -> None:
    """Raise if the run has been closed; otherwise do nothing.

    Raises
    ------
    The stored ``self._exception`` when one is set, else
    ``ShuffleClosedError``.
    """
    if not self.closed:
        return
    stored = self._exception
    if stored:
        raise stored
    raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
    """Mark the transfer phase as finished and flush outgoing shards.

    Sets ``self.transferred`` and flushes the comm buffer. If the comm
    buffer reports an exception, that exception is stored on
    ``self._exception`` and re-raised.
    """
    self.raise_if_closed()
    self.transferred = True
    await self._flush_comm()
    try:
        self._comm_buffer.raise_on_exception()
    except Exception as comm_error:
        # Keep the failure so later raise_if_closed() calls can surface it.
        self._exception = comm_error
        raise
async def _flush_comm(self) -> None:
    """Flush the comm buffer, provided the run is still open."""
    self.raise_if_closed()
    comm_buffer = self._comm_buffer
    await comm_buffer.flush()
async def flush_receive(self) -> None:
    """Flush the disk buffer, provided the run is still open."""
    self.raise_if_closed()
    disk_buffer = self._disk_buffer
    await disk_buffer.flush()
async def close(self) -> None:
    """Close the comm and disk buffers of this run exactly once.

    The first caller sets ``self.closed`` and closes both buffers. Any
    later caller only waits on ``self._closed_event`` until that teardown
    has finished.

    The event is set in a ``finally`` block so that a buffer failing to
    close cannot leave later callers waiting forever. The disk buffer is
    also still closed when closing the comm buffer raises. The exception
    itself propagates to the first caller.
    """
    if self.closed:  # pragma: no cover
        await self._closed_event.wait()
        return

    self.closed = True
    try:
        try:
            await self._comm_buffer.close()
        finally:
            await self._disk_buffer.close()
    finally:
        self._closed_event.set()
def fail(self, exception: Exception) -> None:
    """Store ``exception`` on the run unless the run is already closed."""
    if self.closed:
        return
    self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
    """Read the shards stored for index ``id`` from the disk buffer.

    The buffer key is ``str()`` of every element of ``id`` joined by ``"_"``.
    """
    self.raise_if_closed()
    disk_key = "_".join(map(str, id))
    return self._disk_buffer.read(disk_key)
async def receive(
    self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
    """Accept a batch of shards and hand it to ``self._receive``.

    ``data`` may be the opaque bytes blob produced by ``send``; in that case
    it is unpickled first.

    Returns ``{"status": "OK"}`` on success. A ``P2PConsistencyError`` is
    not raised but returned as ``error_message(e)``.
    """
    try:
        if isinstance(data, bytes):
            # Unpack opaque blob. See send()
            # NOTE(review): pickle.loads can execute arbitrary code; this
            # presumably only ever sees blobs produced by a peer's send() -
            # confirm the transport is trusted.
            data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
        await self._receive(data)
    except P2PConsistencyError as e:
        return error_message(e)
    return {"status": "OK"}
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
    """Make sure output partition ``i`` is produced on its assigned worker.

    Returns normally when the assigned worker is this worker. Otherwise the
    scheduler is asked to restrict task ``key`` to the assigned worker and
    the current execution is aborted with ``Reschedule``.

    Raises
    ------
    RuntimeError
        If the scheduler answers with status ``"error"``.
    Reschedule
        After the restriction was accepted.
    """
    assigned_worker = self._get_assigned_worker(i)
    if assigned_worker == self.local_address:
        return

    result = await self.scheduler.shuffle_restrict_task(
        id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
    )
    status = result["status"]
    if status == "error":
        raise RuntimeError(result["message"])
    assert status == "OK"
    raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
    """Get the address of the worker assigned to the output partition

    Abstract hook without a default implementation.

    Parameters
    ----------
    i
        Id of the output partition.
    """
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
    """Receive shards belonging to output partitions of this shuffle run

    Abstract hook without a default implementation. ``receive`` calls it
    with the already unpickled list of ``(partition id, shard)`` tuples.
    """
def add_partition(
    self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
    """Shard one input partition and queue the shards on the comm buffer.

    Returns ``self.run_id``.

    Raises
    ------
    RuntimeError
        If ``self.transferred`` is already set, i.e. no more partitions may
        be added to this run.
    """
    # NOTE(review): the nesting of the ``with`` blocks below was reconstructed
    # from a dump that lost its indentation - confirm against the real file.
    self.raise_if_closed()
    if self.transferred:
        raise RuntimeError(f"Cannot add more partitions to {self}")
    # Log metrics both in the "execute" and in the "p2p" contexts
    self.validate_data(data)
    with self._capture_metrics("foreground"):
        with (
            context_meter.meter("p2p-shard-partition-noncpu"),
            context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
        ):
            shards = self._shard_partition(data, partition_id)
        # Blocks the calling thread until the write coroutine has run on self._loop
        sync(self._loop, self._write_to_comm, shards)
    return self.run_id
@abc.abstractmethod
def _shard_partition(
    self, data: _T_partition_type, partition_id: _T_partition_id
) -> dict[str, tuple[_T_partition_id, Any]]:
    """Shard an input partition by the assigned output workers

    Abstract hook without a default implementation. ``add_partition``
    passes the returned mapping on to ``_write_to_comm``.
    """
def get_output_partition(
    self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
    """Return output partition ``partition_id`` of this run.

    Before reading, the method makes sure the task runs on the assigned
    worker (``_ensure_output_worker``) and flushes received shards
    (``flush_receive``), both through ``sync`` on ``self._loop``. Extra
    keyword arguments are forwarded to ``_get_output_partition``.

    Raises
    ------
    RuntimeError
        If called before ``self.transferred`` was set.
    """
    self.raise_if_closed()
    sync(self._loop, self._ensure_output_worker, partition_id, key)
    if not self.transferred:
        raise RuntimeError("`get_output_partition` called before barrier task")
    sync(self._loop, self.flush_receive)
    with (
        # Log metrics both in the "execute" and in the "p2p" contexts
        self._capture_metrics("foreground"),
        context_meter.meter("p2p-get-output-noncpu"),
        context_meter.meter("p2p-get-output-cpu", func=thread_time),
    ):
        output = self._get_output_partition(partition_id, key, **kwargs)
    return output
@abc.abstractmethod
def _get_output_partition(
    self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
    """Get an output partition to the shuffle run

    Abstract hook without a default implementation. It is called by
    ``get_output_partition`` with the same arguments.
    """
@abc.abstractmethod
def read(self, path: Path) -> tuple[Any, int]:
    """Read shards from disk

    Abstract hook without a default implementation.
    NOTE(review): the ``int`` in the return tuple is presumably a byte
    count - confirm against the implementations.
    """
@abc.abstractmethod
def deserialize(self, buffer: Any) -> Any:
    """Deserialize shards

    Abstract hook without a default implementation.
    """
def validate_data(self, data: Any) -> None:
    """Validate payload data before shuffling

    The default implementation accepts everything. ``add_partition`` calls
    this hook before sharding ``data``.
    """
def get_worker_plugin() -> ShuffleWorkerPlugin:
    """Return the ``"shuffle"`` plugin of the worker running the current task.

    Raises
    ------
    RuntimeError
        If ``get_worker()`` raises ``ValueError``, or if the worker has no
        plugin registered under ``"shuffle"``.
    """
    from distributed import get_worker

    try:
        worker = get_worker()
    except ValueError as no_worker:
        raise RuntimeError(
            "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
            "please confirm that you've created a distributed Client and are submitting this computation through it."
        ) from no_worker

    try:
        return worker.plugins["shuffle"]  # type: ignore
    except KeyError as no_plugin:
        raise RuntimeError(
            f"The worker {worker.address} does not have a P2P shuffle plugin."
        ) from no_plugin
# Prefix shared by the keys of all shuffle barrier tasks
_BARRIER_PREFIX = "shuffle-barrier-"


def barrier_key(shuffle_id: ShuffleId) -> str:
    """Return the key of the barrier task belonging to ``shuffle_id``."""
    return _BARRIER_PREFIX + shuffle_id


def id_from_key(key: Key) -> ShuffleId | None:
    """Extract the shuffle id from a barrier task key.

    Returns ``None`` when ``key`` is not a string or does not start with the
    barrier prefix.
    """
    if isinstance(key, str) and key.startswith(_BARRIER_PREFIX):
        return ShuffleId(key.removeprefix(_BARRIER_PREFIX))
    return None
class ShuffleType(Enum):
    """Kinds of P2P shuffle, identified by their string value."""

    # Shuffle of dataframe partitions
    DATAFRAME = "DataFrameShuffle"
    # Rechunking of an array
    ARRAY_RECHUNK = "ArrayRechunk"
@dataclass(frozen=True)
class ShuffleRunSpec(Generic[_T_partition_id]):
    """Immutable description of one run of a shuffle.

    ``run_id`` is not a constructor argument. Every new instance draws the
    next value from a single counter that starts at 1.
    """

    # Drawn from the shared counter at construction time (init=False)
    run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
    # The shuffle this run belongs to
    spec: ShuffleSpec
    # Maps each output partition id to the address of its assigned worker
    worker_for: dict[_T_partition_id, str]
    # Span the run is attributed to, if any
    span_id: str | None

    @property
    def id(self) -> ShuffleId:
        """Id of the shuffle, taken from ``self.spec``."""
        return self.spec.id
@dataclass(frozen=True)
class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
    """Abstract, immutable description of a shuffle.

    Subclasses define the output partitions, how a partition is assigned to
    a worker, and how a run is instantiated on a worker.
    """

    # Identifier of the shuffle
    id: ShuffleId
    # NOTE(review): presumably selects disk-backed vs. in-memory buffering of
    # received shards - confirm against the ShuffleRun constructor
    disk: bool

    @property
    @abc.abstractmethod
    def output_partitions(self) -> Generator[_T_partition_id]:
        """Output partitions"""

    @abc.abstractmethod
    def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
        """Pick a worker for a partition"""

    def create_new_run(
        self,
        worker_for: dict[_T_partition_id, str],
        span_id: str | None,
    ) -> SchedulerShuffleState:
        """Create the scheduler-side state for a new run of this shuffle.

        The state wraps a fresh ``ShuffleRunSpec`` for this spec. Its
        participating workers start out as the set of all addresses that
        appear as values in ``worker_for``.
        """
        return SchedulerShuffleState(
            run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
            participating_workers=set(worker_for.values()),
        )

    @abc.abstractmethod
    def create_run_on_worker(
        self,
        run_id: int,
        span_id: str | None,
        worker_for: dict[_T_partition_id, str],
        plugin: ShuffleWorkerPlugin,
    ) -> ShuffleRun:
        """Create the new shuffle run on the worker."""
@dataclass(eq=False)
class SchedulerShuffleState(Generic[_T_partition_id]):
    """Mutable scheduler-side state of one shuffle run.

    ``eq=False`` keeps the default identity-based equality; the hash is
    derived from the run id.
    """

    # Immutable spec of the run this state tracks
    run_spec: ShuffleRunSpec
    # Addresses of the workers taking part in the run
    participating_workers: set[str]
    # Not a constructor argument; the run counts as archived once this is set
    _archived_by: str | None = field(default=None, init=False)
    _failed: bool = False

    @property
    def id(self) -> ShuffleId:
        """Id of the shuffle, taken from ``run_spec``."""
        return self.run_spec.id

    @property
    def run_id(self) -> int:
        """Id of the run, taken from ``run_spec``."""
        return self.run_spec.run_id

    @property
    def archived(self) -> bool:
        """Whether ``_archived_by`` has been set."""
        return self._archived_by is not None

    def __str__(self) -> str:
        return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"

    def __hash__(self) -> int:
        return hash(self.run_id)
@contextlib.contextmanager
def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
try:
yield
except ShuffleClosedError:
raise Reschedule()
except P2PConsistencyError:
raise
except P2POutOfDiskError:
raise
except Exception as e:
> raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E RuntimeError: P2P 5eac77ca35612a3e8c9f0e275e3e0020 failed during transfer phase
distributed\shuffle\_core.py:531: RuntimeError
Check warning on line 0 in distributed.shuffle.tests.test_rechunk
github-actions / Unit Test Results
All 11 runs failed: test_rechunk_configuration[p2p-p2p] (distributed.shuffle.tests.test_rechunk)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 7ed88ee600fafb6fa95cf3f00f42eaf5 failed during transfer phase
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
    """Attach the plugin to ``scheduler``.

    Registers the four ``shuffle_*`` handlers on the scheduler, adds the
    plugin itself under the name ``"shuffle"`` and initialises the
    bookkeeping containers.
    """
    self.scheduler = scheduler
    shuffle_handlers = {
        "shuffle_barrier": self.barrier,
        "shuffle_get": self.get,
        "shuffle_get_or_create": self.get_or_create,
        "shuffle_restrict_task": self.restrict_task,
    }
    self.scheduler.handlers.update(shuffle_handlers)
    # shuffle id -> worker address -> latest heartbeat data
    self.heartbeats = defaultdict(lambda: defaultdict(dict))
    self.active_shuffles = {}
    self.scheduler.add_plugin(self, name="shuffle")
    self._shuffles = defaultdict(set)
    self._archived_by_stimulus = defaultdict(set)
    self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
    """Register a new ``ShuffleWorkerPlugin`` with the scheduler.

    The plugin is pickled with ``dumps`` and registered on
    ``self.scheduler`` under the name ``"shuffle"`` with
    ``idempotent=False``.
    """
    serialized_plugin = dumps(ShuffleWorkerPlugin())
    await self.scheduler.register_worker_plugin(
        None, serialized_plugin, name="shuffle", idempotent=False
    )
def shuffle_ids(self) -> set[ShuffleId]:
    """Return the ids of all active shuffles as a new set."""
    return {*self.active_shuffles}
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
    """Handle the barrier of shuffle ``id``.

    When the run is inconsistent the shuffle is restarted. Otherwise
    ``shuffle_inputs_done`` is broadcast to all participating workers.
    Workers whose broadcast failed with an ``OSError`` are retried after a
    short sleep.

    Raises
    ------
    ValueError
        If ``run_id`` does not match the active run of the shuffle.
    P2PIllegalStateError
        If a failing worker is unknown to the scheduler while the shuffle
        is not archived.
    RuntimeError
        If a worker returns anything other than ``None`` or an ``OSError``,
        if a failing worker has left, or if the number of failing workers
        has not shrunk three times.
    """
    shuffle = self.active_shuffles[id]
    if shuffle.run_id != run_id:
        raise ValueError(f"{run_id=} does not match {shuffle}")
    if not consistent:
        logger.warning(
            "Shuffle %s restarted due to data inconsistency during barrier",
            shuffle.id,
        )
        return self._restart_shuffle(
            shuffle.id,
            self.scheduler,
            stimulus_id=f"p2p-barrier-inconsistent-{time()}",
        )
    msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
    workers = list(shuffle.participating_workers)
    no_progress = 0
    while workers:
        res = await self.scheduler.broadcast(
            msg=msg,
            workers=workers,
            on_error="return",
        )
        before = len(workers)
        # Collect the workers to retry. ``None`` means success, an OSError
        # is retried, and anything else aborts the barrier. The loop alone
        # yields every worker with a non-None result, so the list is not
        # rebuilt a second time afterwards.
        workers = []
        for w, r in res.items():
            if r is None:
                continue
            if isinstance(r, OSError):
                workers.append(w)
            else:
                raise RuntimeError(
                    f"Unexpected error encountered during P2P barrier: {r!r}"
                )
        if workers:
            logger.warning(
                "Failure during broadcast of %s, retrying.",
                shuffle.id,
            )
            if any(w not in self.scheduler.workers for w in workers):
                if not shuffle.archived:
                    # If the shuffle is not yet archived, this could mean that the barrier task fails
                    # before the P2P restarting mechanism can kick in.
                    raise P2PIllegalStateError(
                        "Expected shuffle to be archived if participating worker is not known by scheduler"
                    )
                raise RuntimeError(
                    f"Worker {workers} left during shuffle {shuffle}"
                )
            await asyncio.sleep(0.1)
            if len(workers) == before:
                no_progress += 1
                if no_progress >= 3:
                    raise RuntimeError(
                        f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
                    )
def restrict_task(
    self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
    """Restrict task ``key`` to ``worker`` on behalf of a shuffle run.

    Returns ``{"status": "OK"}`` when ``run_id`` matches the active run of
    shuffle ``id``. A run id that is older or newer than the active one is
    reported as ``error_message`` of a ``P2PConsistencyError``.
    """
    try:
        shuffle = self.active_shuffles[id]
        active_run_id = shuffle.run_id
        if active_run_id > run_id:
            raise P2PConsistencyError(
                f"Request stale, expected {run_id=} for {shuffle}"
            )
        if active_run_id < run_id:
            raise P2PConsistencyError(
                f"Request invalid, expected {run_id=} for {shuffle}"
            )
        task_state = self.scheduler.tasks[key]
        self._set_restriction(task_state, worker)
        return {"status": "OK"}
    except P2PConsistencyError as e:
        return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
    """Merge a worker's shuffle heartbeat data into ``self.heartbeats``.

    ``data`` maps shuffle ids to heartbeat dicts. Entries for shuffles that
    are not active are ignored; the rest update
    ``self.heartbeats[shuffle_id][ws.address]``.
    """
    # Take the snapshot of active ids once instead of rebuilding the set
    # for every entry in ``data``.
    active_ids = self.shuffle_ids()
    for shuffle_id, d in data.items():
        if shuffle_id in active_ids:
            self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
    """Return the run spec of the active shuffle ``id`` to ``worker``.

    On success the answer is ``{"status": "OK", "run_spec": ...}`` with the
    spec wrapped in ``ToPickle``. A missing shuffle or any other
    ``P2PConsistencyError`` is returned as ``error_message``, not raised.
    """
    try:
        try:
            run_spec = self._get(id, worker)
            return {"status": "OK", "run_spec": ToPickle(run_spec)}
        except KeyError as e:
            # Translate the lookup failure so the outer handler reports it
            raise P2PConsistencyError(
                f"No active shuffle with {id=!r} found"
            ) from e
    except P2PConsistencyError as e:
        return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
    """Look up the active run of shuffle ``id`` and record ``worker`` as participant.

    Raises
    ------
    P2PConsistencyError
        If ``worker`` is not known to the scheduler.
    KeyError
        If there is no active shuffle with this id.
    """
    known_workers = self.scheduler.workers
    if worker not in known_workers:
        # This should never happen
        raise P2PConsistencyError(
            f"Scheduler is unaware of this worker {worker!r}"
        )  # pragma: nocover
    shuffle_state = self.active_shuffles[id]
    shuffle_state.participating_workers.add(worker)
    return shuffle_state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
    """Fetch the ``ShuffleSpec`` stored on the barrier task of ``shuffle_id``.

    The barrier task is looked up on the scheduler by ``barrier_key`` and
    its ``run_spec`` is asserted to be a ``P2PBarrierTask``.
    """
    barrier_task = self.scheduler.tasks[barrier_key(shuffle_id)]
    barrier_task_spec = barrier_task.run_spec
    assert isinstance(barrier_task_spec, P2PBarrierTask)
    return barrier_task_spec.spec
def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
    """Start a new run of ``shuffle_id`` and make it the active one.

    The run is created from the spec on the barrier task, registered in
    ``active_shuffles`` and ``_shuffles``, and ``worker`` is recorded as a
    participant. Returns the run spec of the new run.
    """
    # FIXME: The current implementation relies on the barrier task to be
    # known by its name. If the name has been mangled, we cannot guarantee
    # that the shuffle works as intended and should fail instead.
    self._raise_if_barrier_unknown(shuffle_id)
    self._raise_if_task_not_processing(key)
    spec = self._retrieve_spec(shuffle_id)
    worker_for = self._calculate_worker_for(spec)
    self._ensure_output_tasks_are_non_rootish(spec)
    # The new run inherits the span of the task that triggered it
    span_id = self.scheduler.tasks[key].group.span_id
    new_state = spec.create_new_run(worker_for=worker_for, span_id=span_id)
    self.active_shuffles[shuffle_id] = new_state
    self._shuffles[shuffle_id].add(new_state)
    new_state.participating_workers.add(worker)
    logger.warning(
        "Shuffle %s initialized by task %r executed on worker %s",
        shuffle_id,
        key,
        worker,
    )
    return new_state.run_spec
def get_or_create(
self,
shuffle_id: ShuffleId,
key: Key,
worker: str,
) -> RunSpecMessage | ErrorMessage:
try:
> run_spec = self._get(shuffle_id, worker)
distributed\shuffle\_scheduler_plugin.py:229:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
    """Handle the barrier of shuffle ``id``.

    When the run is inconsistent the shuffle is restarted. Otherwise
    ``shuffle_inputs_done`` is broadcast to all participating workers.
    Workers whose broadcast failed with an ``OSError`` are retried after a
    short sleep.

    Raises
    ------
    ValueError
        If ``run_id`` does not match the active run of the shuffle.
    P2PIllegalStateError
        If a failing worker is unknown to the scheduler while the shuffle
        is not archived.
    RuntimeError
        If a worker returns anything other than ``None`` or an ``OSError``,
        if a failing worker has left, or if the number of failing workers
        has not shrunk three times.
    """
    shuffle = self.active_shuffles[id]
    if shuffle.run_id != run_id:
        raise ValueError(f"{run_id=} does not match {shuffle}")
    if not consistent:
        logger.warning(
            "Shuffle %s restarted due to data inconsistency during barrier",
            shuffle.id,
        )
        return self._restart_shuffle(
            shuffle.id,
            self.scheduler,
            stimulus_id=f"p2p-barrier-inconsistent-{time()}",
        )
    msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
    workers = list(shuffle.participating_workers)
    no_progress = 0
    while workers:
        res = await self.scheduler.broadcast(
            msg=msg,
            workers=workers,
            on_error="return",
        )
        before = len(workers)
        # Collect the workers to retry. ``None`` means success, an OSError
        # is retried, and anything else aborts the barrier. The loop alone
        # yields every worker with a non-None result, so the list is not
        # rebuilt a second time afterwards.
        workers = []
        for w, r in res.items():
            if r is None:
                continue
            if isinstance(r, OSError):
                workers.append(w)
            else:
                raise RuntimeError(
                    f"Unexpected error encountered during P2P barrier: {r!r}"
                )
        if workers:
            logger.warning(
                "Failure during broadcast of %s, retrying.",
                shuffle.id,
            )
            if any(w not in self.scheduler.workers for w in workers):
                if not shuffle.archived:
                    # If the shuffle is not yet archived, this could mean that the barrier task fails
                    # before the P2P restarting mechanism can kick in.
                    raise P2PIllegalStateError(
                        "Expected shuffle to be archived if participating worker is not known by scheduler"
                    )
                raise RuntimeError(
                    f"Worker {workers} left during shuffle {shuffle}"
                )
            await asyncio.sleep(0.1)
            if len(workers) == before:
                no_progress += 1
                if no_progress >= 3:
                    raise RuntimeError(
                        f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
                    )
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
    """Merge a worker's shuffle heartbeat data into ``self.heartbeats``.

    ``data`` maps shuffle ids to heartbeat dicts. Entries for shuffles that
    are not active are ignored; the rest update
    ``self.heartbeats[shuffle_id][ws.address]``.
    """
    # Take the snapshot of active ids once instead of rebuilding the set
    # for every entry in ``data``.
    active_ids = self.shuffle_ids()
    for shuffle_id, d in data.items():
        if shuffle_id in active_ids:
            self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
> state = self.active_shuffles[id]
E KeyError: '7ed88ee600fafb6fa95cf3f00f42eaf5'
distributed\shuffle\_scheduler_plugin.py:190: KeyError
During handling of the above exception, another exception occurred:
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
    """``OKMessage`` that additionally carries a shuffle run spec."""

    # Either the spec itself or the spec wrapped in ToPickle
    run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
    self,
    id: ShuffleId,
    run_id: int,
    span_id: str | None,
    local_address: str,
    directory: str,
    executor: ThreadPoolExecutor,
    rpc: Callable[[str], PooledRPCCall],
    digest_metric: Callable[[Hashable, float], None],
    scheduler: PooledRPCCall,
    memory_limiter_disk: ResourceLimiter,
    memory_limiter_comms: ResourceLimiter,
    disk: bool,
    loop: IOLoop,
):
    """Set up one shuffle run on a worker.

    Stores the identifying and communication attributes, creates the
    receive buffer (disk-backed when ``disk`` is true, in-memory otherwise)
    and the comm buffer for outgoing shards, and reads the retry settings
    from the ``distributed.p2p.comm.retry.*`` config.
    """
    # NOTE(review): the nesting of the ``with`` blocks below was reconstructed
    # from a dump that lost its indentation - confirm against the real file.
    self.id = id
    self.run_id = run_id
    self.span_id = span_id
    self.local_address = local_address
    self.executor = executor
    self.rpc = rpc
    self.digest_metric = digest_metric
    self.scheduler = scheduler
    self.closed = False

    # Initialize buffers and start background tasks
    # Don't log metrics issued by the background tasks onto the dask task that
    # spawned this object
    with context_meter.clear_callbacks():
        with self._capture_metrics("background-disk"):
            if disk:
                self._disk_buffer = DiskShardsBuffer(
                    directory=directory,
                    read=self.read,
                    memory_limiter=memory_limiter_disk,
                )
            else:
                self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)

        with self._capture_metrics("background-comms"):
            max_message_size = parse_bytes(
                dask.config.get("distributed.p2p.comm.message-bytes-limit")
            )
            concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
            self._comm_buffer = CommShardsBuffer(
                send=self.send,
                max_message_size=max_message_size,
                memory_limiter=memory_limiter_comms,
                concurrency_limit=concurrency_limit,
            )

    # TODO: reduce number of connections to number of workers
    # MultiComm.max_connections = min(10, n_workers)

    self.transferred = False
    self.received = set()
    self.total_recvd = 0
    self.start_time = time.time()
    self._exception = None
    self._closed_event = asyncio.Event()
    self._loop = loop

    # Retry settings used by send()
    self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
    self.RETRY_DELAY_MIN = parse_timedelta(
        dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
    )
    self.RETRY_DELAY_MAX = parse_timedelta(
        dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
    )
def __repr__(self) -> str:
    """Return a debug representation listing id, run id, address and state flags."""
    class_name = self.__class__.__name__
    details = ", ".join(
        (
            f"id={self.id!r}",
            f"run_id={self.run_id!r}",
            f"local_address={self.local_address!r}",
            f"closed={self.closed!r}",
            f"transferred={self.transferred!r}",
        )
    )
    return f"<{class_name}: {details}>"
def __str__(self) -> str:
    """Return ``ClassName<id[run_id]> on local_address``."""
    class_name = self.__class__.__name__
    run_label = f"{self.id}[{self.run_id}]"
    return f"{class_name}<{run_label}> on {self.local_address}"
def __hash__(self) -> int:
    """Use the run id as the hash of the run."""
    run_id = self.run_id
    return run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
    """Capture context_meter metrics as

        {('p2p', <span id>, 'foreground|background...', label, unit): value}

    **Note 1:** When the metric is not logged by a background task
    (where='foreground'), this produces a duplicated metric under

        {('execute', <span id>, <task prefix>, label, unit): value}

    This is by design so that one can have a holistic view of the whole shuffle
    process.

    **Note 2:** We're immediately writing to Worker.digests.
    We don't temporarily store metrics under ShuffleRun as we would lose those
    recorded between the heartbeat and when the ShuffleRun object is deleted at the
    end of a run.
    """
    prefix = "p2p-"

    def callback(label: Hashable, value: float, unit: str) -> None:
        # Normalise the label to a tuple and drop a leading "p2p-" from its
        # first element; the metric name already starts with "p2p".
        parts = label if isinstance(label, tuple) else (label,)
        head = parts[0]
        if isinstance(head, str) and head.startswith(prefix):
            parts = (head[len(prefix) :], *parts[1:])
        self.digest_metric(("p2p", self.span_id, where, *parts, unit), value)

    with context_meter.add_callback(callback, allow_offload="background" in where):
        yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
    """Close the comm and disk buffers of this run exactly once.

    The first caller sets ``self.closed`` and closes both buffers. Any
    later caller only waits on ``self._closed_event`` until that teardown
    has finished.

    The event is set in a ``finally`` block so that a buffer failing to
    close cannot leave later callers waiting forever. The disk buffer is
    also still closed when closing the comm buffer raises. The exception
    itself propagates to the first caller.
    """
    if self.closed:  # pragma: no cover
        await self._closed_event.wait()
        return

    self.closed = True
    try:
        try:
            await self._comm_buffer.close()
        finally:
            await self._disk_buffer.close()
    finally:
        self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and … = await self._refresh(
distributed\shuffle\_worker_plugin.py:222: in _refresh
result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed\shuffle\_worker_plugin.py:190: in _fetch
response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed\core.py:1259: in send_recv_from_rpc
return await send_recv(comm=comm, op=key, **kwargs)
distributed\core.py:1043: in send_recv
raise exc.with_traceback(tb)
distributed\core.py:832: in _handle_comm
result = handler(**msg)
distributed\shuffle\_scheduler_plugin.py:234: in get_or_create
run_spec = self._create(shuffle_id, key, worker)
distributed\shuffle\_scheduler_plugin.py:205: in _create
spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
> assert isinstance(barrier_task_spec, P2PBarrierTask)
E AssertionError
distributed\shuffle\_scheduler_plugin.py:196: AssertionError
The above exception was the direct cause of the following exception:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:64026', workers: 0, cores: 0, tasks: 0>
config_value = 'p2p', keyword = 'p2p'
ws = (<Worker 'tcp://127.0.0.1:64027', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:64030', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
a = array([[0.70050953, 0.59988581, 0.2873489 , 0.02919635, 0.75044654,
0.34847269, 0.00570634, 0.79608935, 0.0933..., 0.19892204, 0.22475293, 0.7469204 , 0.13692289,
0.71068667, 0.03039515, 0.84743146, 0.23023854, 0.91232044]])
@pytest.mark.parametrize("config_value", ["tasks", "p2p", None])
@pytest.mark.parametrize("keyword", ["tasks", "p2p", None])
@gen_cluster(client=True)
async def test_rechunk_configuration(c, s, *ws, config_value, keyword):
"""Try rechunking a random 1d matrix
See Also
--------
dask.array.tests.test_rechunk.test_rechunk_1d
"""
a = np.random.default_rng().uniform(0, 1, 100).reshape((10, 10))
x = da.from_array(a, chunks=(10, 1))
new = ((1,) * 10, (10,))
config = {"array.rechunk.method": config_value} if config_value is not None else {}
with dask.config.set(config):
x2 = rechunk(x, chunks=new, method=keyword)
expected_algorithm = keyword if keyword is not None else config_value
if expected_algorithm == "p2p":
assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
elif expected_algorithm == "tasks":
assert not any(
key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__()
)
# Neither is specified, so we choose the best one (see test_rechunk_heuristic for a full test of the heuristic)
else:
assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
assert x2.chunks == new
> assert np.all(await c.compute(x2) == a)
distributed\shuffle\tests\test_rechunk.py:210:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed\client.py:410: in _result
raise exc.with_traceback(tb)
distributed\shuffle\_rechunk.py:169: in rechunk_transfer
with handle_transfer_errors(id):
C:\Users\runneradmin\miniconda3\envs\dask-distributed\Lib\contextlib.py:158: in __exit__
self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and in the "p2p" contexts
self.validate_data(data)
with self._capture_metrics("foreground"):
with (
context_meter.meter("p2p-shard-partition-noncpu"),
context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
):
shards = self._shard_partition(data, partition_id)
sync(self._loop, self._write_to_comm, shards)
return self.run_id
@abc.abstractmethod
def _shard_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> dict[str, tuple[_T_partition_id, Any]]:
"""Shard an input partition by the assigned output workers"""
def get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
self.raise_if_closed()
sync(self._loop, self._ensure_output_worker, partition_id, key)
if not self.transferred:
raise RuntimeError("`get_output_partition` called before barrier task")
sync(self._loop, self.flush_receive)
with (
# Log metrics both in the "execute" and in the "p2p" contexts
self._capture_metrics("foreground"),
context_meter.meter("p2p-get-output-noncpu"),
context_meter.meter("p2p-get-output-cpu", func=thread_time),
):
return self._get_output_partition(partition_id, key, **kwargs)
@abc.abstractmethod
def _get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
"""Get an output partition to the shuffle run"""
@abc.abstractmethod
def read(self, path: Path) -> tuple[Any, int]:
"""Read shards from disk"""
@abc.abstractmethod
def deserialize(self, buffer: Any) -> Any:
"""Deserialize shards"""
def validate_data(self, data: Any) -> None:
"""Validate payload data before shuffling"""
def get_worker_plugin() -> ShuffleWorkerPlugin:
from distributed import get_worker
try:
worker = get_worker()
except ValueError as e:
raise RuntimeError(
"`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
"please confirm that you've created a distributed Client and are submitting this computation through it."
) from e
try:
return worker.plugins["shuffle"] # type: ignore
except KeyError as e:
raise RuntimeError(
f"The worker {worker.address} does not have a P2P shuffle plugin."
) from e
_BARRIER_PREFIX = "shuffle-barrier-"
def barrier_key(shuffle_id: ShuffleId) -> str:
return _BARRIER_PREFIX + shuffle_id
def id_from_key(key: Key) -> ShuffleId | None:
if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
return None
return ShuffleId(key[len(_BARRIER_PREFIX) :])
class ShuffleType(Enum):
DATAFRAME = "DataFrameShuffle"
ARRAY_RECHUNK = "ArrayRechunk"
@dataclass(frozen=True)
class ShuffleRunSpec(Generic[_T_partition_id]):
run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
spec: ShuffleSpec
worker_for: dict[_T_partition_id, str]
span_id: str | None
@property
def id(self) -> ShuffleId:
return self.spec.id
@dataclass(frozen=True)
class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
id: ShuffleId
disk: bool
@property
@abc.abstractmethod
def output_partitions(self) -> Generator[_T_partition_id]:
"""Output partitions"""
@abc.abstractmethod
def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
"""Pick a worker for a partition"""
def create_new_run(
self,
worker_for: dict[_T_partition_id, str],
span_id: str | None,
) -> SchedulerShuffleState:
return SchedulerShuffleState(
run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
participating_workers=set(worker_for.values()),
)
@abc.abstractmethod
def create_run_on_worker(
self,
run_id: int,
span_id: str | None,
worker_for: dict[_T_partition_id, str],
plugin: ShuffleWorkerPlugin,
) -> ShuffleRun:
"""Create the new shuffle run on the worker."""
@dataclass(eq=False)
class SchedulerShuffleState(Generic[_T_partition_id]):
run_spec: ShuffleRunSpec
participating_workers: set[str]
_archived_by: str | None = field(default=None, init=False)
_failed: bool = False
@property
def id(self) -> ShuffleId:
return self.run_spec.id
@property
def run_id(self) -> int:
return self.run_spec.run_id
@property
def archived(self) -> bool:
return self._archived_by is not None
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
def __hash__(self) -> int:
return hash(self.run_id)
@contextlib.contextmanager
def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
try:
yield
except ShuffleClosedError:
raise Reschedule()
except P2PConsistencyError:
raise
except P2POutOfDiskError:
raise
except Exception as e:
> raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E RuntimeError: P2P 7ed88ee600fafb6fa95cf3f00f42eaf5 failed during transfer phase
distributed\shuffle\_core.py:531: RuntimeError
Check warning on line 0 in distributed.shuffle.tests.test_rechunk
github-actions / Unit Test Results
All 11 runs failed: test_rechunk_configuration[p2p-None] (distributed.shuffle.tests.test_rechunk)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P c27782cfc676cfba0abb18c164e6cb85 failed during transfer phase
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
assert isinstance(barrier_task_spec, P2PBarrierTask)
return barrier_task_spec.spec
def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
# FIXME: The current implementation relies on the barrier task to be
# known by its name. If the name has been mangled, we cannot guarantee
# that the shuffle works as intended and should fail instead.
self._raise_if_barrier_unknown(shuffle_id)
self._raise_if_task_not_processing(key)
spec = self._retrieve_spec(shuffle_id)
worker_for = self._calculate_worker_for(spec)
self._ensure_output_tasks_are_non_rootish(spec)
state = spec.create_new_run(
worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
)
self.active_shuffles[shuffle_id] = state
self._shuffles[shuffle_id].add(state)
state.participating_workers.add(worker)
logger.warning(
"Shuffle %s initialized by task %r executed on worker %s",
shuffle_id,
key,
worker,
)
return state.run_spec
def get_or_create(
self,
shuffle_id: ShuffleId,
key: Key,
worker: str,
) -> RunSpecMessage | ErrorMessage:
try:
> run_spec = self._get(shuffle_id, worker)
distributed\shuffle\_scheduler_plugin.py:229:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
> state = self.active_shuffles[id]
E KeyError: 'c27782cfc676cfba0abb18c164e6cb85'
distributed\shuffle\_scheduler_plugin.py:190: KeyError
During handling of the above exception, another exception occurred:
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and …n = await self._refresh(
distributed\shuffle\_worker_plugin.py:222: in _refresh
result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed\shuffle\_worker_plugin.py:190: in _fetch
response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed\core.py:1259: in send_recv_from_rpc
return await send_recv(comm=comm, op=key, **kwargs)
distributed\core.py:1043: in send_recv
raise exc.with_traceback(tb)
distributed\core.py:832: in _handle_comm
result = handler(**msg)
distributed\shuffle\_scheduler_plugin.py:234: in get_or_create
run_spec = self._create(shuffle_id, key, worker)
distributed\shuffle\_scheduler_plugin.py:205: in _create
spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
> assert isinstance(barrier_task_spec, P2PBarrierTask)
E AssertionError
distributed\shuffle\_scheduler_plugin.py:196: AssertionError
The above exception was the direct cause of the following exception:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:64040', workers: 0, cores: 0, tasks: 0>
config_value = None, keyword = 'p2p'
ws = (<Worker 'tcp://127.0.0.1:64041', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:64044', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
a = array([[0.40426865, 0.81855948, 0.50957116, 0.44965013, 0.25501785,
0.20329874, 0.75418044, 0.10925094, 0.2364..., 0.76864058, 0.805811 , 0.93629147, 0.94224933,
0.3873398 , 0.336608 , 0.01348985, 0.7104567 , 0.8208412 ]])
@pytest.mark.parametrize("config_value", ["tasks", "p2p", None])
@pytest.mark.parametrize("keyword", ["tasks", "p2p", None])
@gen_cluster(client=True)
async def test_rechunk_configuration(c, s, *ws, config_value, keyword):
"""Try rechunking a random 1d matrix
See Also
--------
dask.array.tests.test_rechunk.test_rechunk_1d
"""
a = np.random.default_rng().uniform(0, 1, 100).reshape((10, 10))
x = da.from_array(a, chunks=(10, 1))
new = ((1,) * 10, (10,))
config = {"array.rechunk.method": config_value} if config_value is not None else {}
with dask.config.set(config):
x2 = rechunk(x, chunks=new, method=keyword)
expected_algorithm = keyword if keyword is not None else config_value
if expected_algorithm == "p2p":
assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
elif expected_algorithm == "tasks":
assert not any(
key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__()
)
# Neither is specified, so we choose the best one (see test_rechunk_heuristic for a full test of the heuristic)
else:
assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
assert x2.chunks == new
> assert np.all(await c.compute(x2) == a)
distributed\shuffle\tests\test_rechunk.py:210:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed\client.py:410: in _result
raise exc.with_traceback(tb)
distributed\shuffle\_rechunk.py:169: in rechunk_transfer
with handle_transfer_errors(id):
C:\Users\runneradmin\miniconda3\envs\dask-distributed\Lib\contextlib.py:158: in __exit__
self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and in the "p2p" contexts
self.validate_data(data)
with self._capture_metrics("foreground"):
with (
context_meter.meter("p2p-shard-partition-noncpu"),
context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
):
shards = self._shard_partition(data, partition_id)
sync(self._loop, self._write_to_comm, shards)
return self.run_id
@abc.abstractmethod
def _shard_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> dict[str, tuple[_T_partition_id, Any]]:
"""Shard an input partition by the assigned output workers"""
def get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
self.raise_if_closed()
sync(self._loop, self._ensure_output_worker, partition_id, key)
if not self.transferred:
raise RuntimeError("`get_output_partition` called before barrier task")
sync(self._loop, self.flush_receive)
with (
# Log metrics both in the "execute" and in the "p2p" contexts
self._capture_metrics("foreground"),
context_meter.meter("p2p-get-output-noncpu"),
context_meter.meter("p2p-get-output-cpu", func=thread_time),
):
return self._get_output_partition(partition_id, key, **kwargs)
@abc.abstractmethod
def _get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
"""Get an output partition to the shuffle run"""
@abc.abstractmethod
def read(self, path: Path) -> tuple[Any, int]:
"""Read shards from disk"""
@abc.abstractmethod
def deserialize(self, buffer: Any) -> Any:
"""Deserialize shards"""
def validate_data(self, data: Any) -> None:
"""Validate payload data before shuffling"""
def get_worker_plugin() -> ShuffleWorkerPlugin:
from distributed import get_worker
try:
worker = get_worker()
except ValueError as e:
raise RuntimeError(
"`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
"please confirm that you've created a distributed Client and are submitting this computation through it."
) from e
try:
return worker.plugins["shuffle"] # type: ignore
except KeyError as e:
raise RuntimeError(
f"The worker {worker.address} does not have a P2P shuffle plugin."
) from e
_BARRIER_PREFIX = "shuffle-barrier-"
def barrier_key(shuffle_id: ShuffleId) -> str:
return _BARRIER_PREFIX + shuffle_id
def id_from_key(key: Key) -> ShuffleId | None:
if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
return None
return ShuffleId(key[len(_BARRIER_PREFIX) :])
class ShuffleType(Enum):
DATAFRAME = "DataFrameShuffle"
ARRAY_RECHUNK = "ArrayRechunk"
@dataclass(frozen=True)
class ShuffleRunSpec(Generic[_T_partition_id]):
run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
spec: ShuffleSpec
worker_for: dict[_T_partition_id, str]
span_id: str | None
@property
def id(self) -> ShuffleId:
return self.spec.id
@dataclass(frozen=True)
class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
id: ShuffleId
disk: bool
@property
@abc.abstractmethod
def output_partitions(self) -> Generator[_T_partition_id]:
"""Output partitions"""
@abc.abstractmethod
def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
"""Pick a worker for a partition"""
def create_new_run(
self,
worker_for: dict[_T_partition_id, str],
span_id: str | None,
) -> SchedulerShuffleState:
return SchedulerShuffleState(
run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
participating_workers=set(worker_for.values()),
)
@abc.abstractmethod
def create_run_on_worker(
self,
run_id: int,
span_id: str | None,
worker_for: dict[_T_partition_id, str],
plugin: ShuffleWorkerPlugin,
) -> ShuffleRun:
"""Create the new shuffle run on the worker."""
@dataclass(eq=False)
class SchedulerShuffleState(Generic[_T_partition_id]):
run_spec: ShuffleRunSpec
participating_workers: set[str]
_archived_by: str | None = field(default=None, init=False)
_failed: bool = False
@property
def id(self) -> ShuffleId:
return self.run_spec.id
@property
def run_id(self) -> int:
return self.run_spec.run_id
@property
def archived(self) -> bool:
return self._archived_by is not None
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
def __hash__(self) -> int:
return hash(self.run_id)
@contextlib.contextmanager
def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
try:
yield
except ShuffleClosedError:
raise Reschedule()
except P2PConsistencyError:
raise
except P2POutOfDiskError:
raise
except Exception as e:
> raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E RuntimeError: P2P c27782cfc676cfba0abb18c164e6cb85 failed during transfer phase
distributed\shuffle\_core.py:531: RuntimeError
Check warning on line 0 in distributed.shuffle.tests.test_rechunk
github-actions / Unit Test Results
All 11 runs failed: test_rechunk_configuration[None-p2p] (distributed.shuffle.tests.test_rechunk)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 35f3712b5c49a47f3324d11460012f87 failed during transfer phase
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
assert isinstance(barrier_task_spec, P2PBarrierTask)
return barrier_task_spec.spec
def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
# FIXME: The current implementation relies on the barrier task to be
# known by its name. If the name has been mangled, we cannot guarantee
# that the shuffle works as intended and should fail instead.
self._raise_if_barrier_unknown(shuffle_id)
self._raise_if_task_not_processing(key)
spec = self._retrieve_spec(shuffle_id)
worker_for = self._calculate_worker_for(spec)
self._ensure_output_tasks_are_non_rootish(spec)
state = spec.create_new_run(
worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
)
self.active_shuffles[shuffle_id] = state
self._shuffles[shuffle_id].add(state)
state.participating_workers.add(worker)
logger.warning(
"Shuffle %s initialized by task %r executed on worker %s",
shuffle_id,
key,
worker,
)
return state.run_spec
def get_or_create(
self,
shuffle_id: ShuffleId,
key: Key,
worker: str,
) -> RunSpecMessage | ErrorMessage:
try:
> run_spec = self._get(shuffle_id, worker)
distributed\shuffle\_scheduler_plugin.py:229:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
> state = self.active_shuffles[id]
E KeyError: '35f3712b5c49a47f3324d11460012f87'
distributed\shuffle\_scheduler_plugin.py:190: KeyError
During handling of the above exception, another exception occurred:
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and …n = await self._refresh(
distributed\shuffle\_worker_plugin.py:222: in _refresh
result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed\shuffle\_worker_plugin.py:190: in _fetch
response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed\core.py:1259: in send_recv_from_rpc
return await send_recv(comm=comm, op=key, **kwargs)
distributed\core.py:1043: in send_recv
raise exc.with_traceback(tb)
distributed\core.py:832: in _handle_comm
result = handler(**msg)
distributed\shuffle\_scheduler_plugin.py:234: in get_or_create
run_spec = self._create(shuffle_id, key, worker)
distributed\shuffle\_scheduler_plugin.py:205: in _create
spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
> assert isinstance(barrier_task_spec, P2PBarrierTask)
E AssertionError
distributed\shuffle\_scheduler_plugin.py:196: AssertionError
The above exception was the direct cause of the following exception:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:64069', workers: 0, cores: 0, tasks: 0>
config_value = 'p2p', keyword = None
ws = (<Worker 'tcp://127.0.0.1:64070', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:64073', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
a = array([[0.03641802, 0.79395381, 0.07845067, 0.75119799, 0.34884585,
0.86607563, 0.02201857, 0.38369563, 0.9759..., 0.26077311, 0.62562275, 0.61415159, 0.22621115,
0.02064385, 0.83236264, 0.008683 , 0.42213654, 0.03328439]])
@pytest.mark.parametrize("config_value", ["tasks", "p2p", None])
@pytest.mark.parametrize("keyword", ["tasks", "p2p", None])
@gen_cluster(client=True)
async def test_rechunk_configuration(c, s, *ws, config_value, keyword):
"""Try rechunking a random 1d matrix
See Also
--------
dask.array.tests.test_rechunk.test_rechunk_1d
"""
a = np.random.default_rng().uniform(0, 1, 100).reshape((10, 10))
x = da.from_array(a, chunks=(10, 1))
new = ((1,) * 10, (10,))
config = {"array.rechunk.method": config_value} if config_value is not None else {}
with dask.config.set(config):
x2 = rechunk(x, chunks=new, method=keyword)
expected_algorithm = keyword if keyword is not None else config_value
if expected_algorithm == "p2p":
assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
elif expected_algorithm == "tasks":
assert not any(
key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__()
)
# Neither is specified, so we choose the best one (see test_rechunk_heuristic for a full test of the heuristic)
else:
assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
assert x2.chunks == new
> assert np.all(await c.compute(x2) == a)
distributed\shuffle\tests\test_rechunk.py:210:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed\client.py:410: in _result
raise exc.with_traceback(tb)
distributed\shuffle\_rechunk.py:169: in rechunk_transfer
with handle_transfer_errors(id):
C:\Users\runneradmin\miniconda3\envs\dask-distributed\Lib\contextlib.py:158: in __exit__
self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and in the "p2p" contexts
self.validate_data(data)
with self._capture_metrics("foreground"):
with (
context_meter.meter("p2p-shard-partition-noncpu"),
context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
):
shards = self._shard_partition(data, partition_id)
sync(self._loop, self._write_to_comm, shards)
return self.run_id
@abc.abstractmethod
def _shard_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> dict[str, tuple[_T_partition_id, Any]]:
"""Shard an input partition by the assigned output workers"""
def get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
self.raise_if_closed()
sync(self._loop, self._ensure_output_worker, partition_id, key)
if not self.transferred:
raise RuntimeError("`get_output_partition` called before barrier task")
sync(self._loop, self.flush_receive)
with (
# Log metrics both in the "execute" and in the "p2p" contexts
self._capture_metrics("foreground"),
context_meter.meter("p2p-get-output-noncpu"),
context_meter.meter("p2p-get-output-cpu", func=thread_time),
):
return self._get_output_partition(partition_id, key, **kwargs)
@abc.abstractmethod
def _get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
"""Get an output partition to the shuffle run"""
@abc.abstractmethod
def read(self, path: Path) -> tuple[Any, int]:
"""Read shards from disk"""
@abc.abstractmethod
def deserialize(self, buffer: Any) -> Any:
"""Deserialize shards"""
def validate_data(self, data: Any) -> None:
"""Validate payload data before shuffling"""
def get_worker_plugin() -> ShuffleWorkerPlugin:
from distributed import get_worker
try:
worker = get_worker()
except ValueError as e:
raise RuntimeError(
"`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
"please confirm that you've created a distributed Client and are submitting this computation through it."
) from e
try:
return worker.plugins["shuffle"] # type: ignore
except KeyError as e:
raise RuntimeError(
f"The worker {worker.address} does not have a P2P shuffle plugin."
) from e
_BARRIER_PREFIX = "shuffle-barrier-"
def barrier_key(shuffle_id: ShuffleId) -> str:
return _BARRIER_PREFIX + shuffle_id
def id_from_key(key: Key) -> ShuffleId | None:
if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
return None
return ShuffleId(key[len(_BARRIER_PREFIX) :])
class ShuffleType(Enum):
DATAFRAME = "DataFrameShuffle"
ARRAY_RECHUNK = "ArrayRechunk"
@dataclass(frozen=True)
class ShuffleRunSpec(Generic[_T_partition_id]):
run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
spec: ShuffleSpec
worker_for: dict[_T_partition_id, str]
span_id: str | None
@property
def id(self) -> ShuffleId:
return self.spec.id
@dataclass(frozen=True)
class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
id: ShuffleId
disk: bool
@property
@abc.abstractmethod
def output_partitions(self) -> Generator[_T_partition_id]:
"""Output partitions"""
@abc.abstractmethod
def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
"""Pick a worker for a partition"""
def create_new_run(
self,
worker_for: dict[_T_partition_id, str],
span_id: str | None,
) -> SchedulerShuffleState:
return SchedulerShuffleState(
run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
participating_workers=set(worker_for.values()),
)
@abc.abstractmethod
def create_run_on_worker(
self,
run_id: int,
span_id: str | None,
worker_for: dict[_T_partition_id, str],
plugin: ShuffleWorkerPlugin,
) -> ShuffleRun:
"""Create the new shuffle run on the worker."""
@dataclass(eq=False)
class SchedulerShuffleState(Generic[_T_partition_id]):
run_spec: ShuffleRunSpec
participating_workers: set[str]
_archived_by: str | None = field(default=None, init=False)
_failed: bool = False
@property
def id(self) -> ShuffleId:
return self.run_spec.id
@property
def run_id(self) -> int:
return self.run_spec.run_id
@property
def archived(self) -> bool:
return self._archived_by is not None
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
def __hash__(self) -> int:
return hash(self.run_id)
@contextlib.contextmanager
def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
try:
yield
except ShuffleClosedError:
raise Reschedule()
except P2PConsistencyError:
raise
except P2POutOfDiskError:
raise
except Exception as e:
> raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E RuntimeError: P2P 35f3712b5c49a47f3324d11460012f87 failed during transfer phase
distributed\shuffle\_core.py:531: RuntimeError
Check warning on line 0 in distributed.shuffle.tests.test_rechunk
github-actions / Unit Test Results
All 11 runs failed: test_rechunk_configuration[None-None] (distributed.shuffle.tests.test_rechunk)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P e95e40ae5ccb7f71758613ef093cea03 failed during transfer phase
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
assert isinstance(barrier_task_spec, P2PBarrierTask)
return barrier_task_spec.spec
def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
# FIXME: The current implementation relies on the barrier task to be
# known by its name. If the name has been mangled, we cannot guarantee
# that the shuffle works as intended and should fail instead.
self._raise_if_barrier_unknown(shuffle_id)
self._raise_if_task_not_processing(key)
spec = self._retrieve_spec(shuffle_id)
worker_for = self._calculate_worker_for(spec)
self._ensure_output_tasks_are_non_rootish(spec)
state = spec.create_new_run(
worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
)
self.active_shuffles[shuffle_id] = state
self._shuffles[shuffle_id].add(state)
state.participating_workers.add(worker)
logger.warning(
"Shuffle %s initialized by task %r executed on worker %s",
shuffle_id,
key,
worker,
)
return state.run_spec
def get_or_create(
self,
shuffle_id: ShuffleId,
key: Key,
worker: str,
) -> RunSpecMessage | ErrorMessage:
try:
> run_spec = self._get(shuffle_id, worker)
distributed\shuffle\_scheduler_plugin.py:229:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
> state = self.active_shuffles[id]
E KeyError: 'e95e40ae5ccb7f71758613ef093cea03'
distributed\shuffle\_scheduler_plugin.py:190: KeyError
During handling of the above exception, another exception occurred:
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and …un = await self._refresh(
distributed\shuffle\_worker_plugin.py:222: in _refresh
result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed\shuffle\_worker_plugin.py:190: in _fetch
response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed\core.py:1259: in send_recv_from_rpc
return await send_recv(comm=comm, op=key, **kwargs)
distributed\core.py:1043: in send_recv
raise exc.with_traceback(tb)
distributed\core.py:832: in _handle_comm
result = handler(**msg)
distributed\shuffle\_scheduler_plugin.py:234: in get_or_create
run_spec = self._create(shuffle_id, key, worker)
distributed\shuffle\_scheduler_plugin.py:205: in _create
spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
> assert isinstance(barrier_task_spec, P2PBarrierTask)
E AssertionError
distributed\shuffle\_scheduler_plugin.py:196: AssertionError
The above exception was the direct cause of the following exception:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:64083', workers: 0, cores: 0, tasks: 0>
config_value = None, keyword = None
ws = (<Worker 'tcp://127.0.0.1:64084', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:64087', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
a = array([[0.3038855 , 0.72204286, 0.59516751, 0.62391027, 0.57077081,
0.18230923, 0.15963029, 0.88508962, 0.5943..., 0.7712572 , 0.77847078, 0.52208647, 0.62879176,
0.45921446, 0.22814207, 0.31759614, 0.60266988, 0.17790265]])
@pytest.mark.parametrize("config_value", ["tasks", "p2p", None])
@pytest.mark.parametrize("keyword", ["tasks", "p2p", None])
@gen_cluster(client=True)
async def test_rechunk_configuration(c, s, *ws, config_value, keyword):
"""Try rechunking a random 1d matrix
See Also
--------
dask.array.tests.test_rechunk.test_rechunk_1d
"""
a = np.random.default_rng().uniform(0, 1, 100).reshape((10, 10))
x = da.from_array(a, chunks=(10, 1))
new = ((1,) * 10, (10,))
config = {"array.rechunk.method": config_value} if config_value is not None else {}
with dask.config.set(config):
x2 = rechunk(x, chunks=new, method=keyword)
expected_algorithm = keyword if keyword is not None else config_value
if expected_algorithm == "p2p":
assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
elif expected_algorithm == "tasks":
assert not any(
key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__()
)
# Neither is specified, so we choose the best one (see test_rechunk_heuristic for a full test of the heuristic)
else:
assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
assert x2.chunks == new
> assert np.all(await c.compute(x2) == a)
distributed\shuffle\tests\test_rechunk.py:210:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed\client.py:410: in _result
raise exc.with_traceback(tb)
distributed\shuffle\_rechunk.py:169: in rechunk_transfer
with handle_transfer_errors(id):
C:\Users\runneradmin\miniconda3\envs\dask-distributed\Lib\contextlib.py:158: in __exit__
self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and in the "p2p" contexts
self.validate_data(data)
with self._capture_metrics("foreground"):
with (
context_meter.meter("p2p-shard-partition-noncpu"),
context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
):
shards = self._shard_partition(data, partition_id)
sync(self._loop, self._write_to_comm, shards)
return self.run_id
@abc.abstractmethod
def _shard_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> dict[str, tuple[_T_partition_id, Any]]:
"""Shard an input partition by the assigned output workers"""
def get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
self.raise_if_closed()
sync(self._loop, self._ensure_output_worker, partition_id, key)
if not self.transferred:
raise RuntimeError("`get_output_partition` called before barrier task")
sync(self._loop, self.flush_receive)
with (
# Log metrics both in the "execute" and in the "p2p" contexts
self._capture_metrics("foreground"),
context_meter.meter("p2p-get-output-noncpu"),
context_meter.meter("p2p-get-output-cpu", func=thread_time),
):
return self._get_output_partition(partition_id, key, **kwargs)
@abc.abstractmethod
def _get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
"""Get an output partition to the shuffle run"""
@abc.abstractmethod
def read(self, path: Path) -> tuple[Any, int]:
"""Read shards from disk"""
@abc.abstractmethod
def deserialize(self, buffer: Any) -> Any:
"""Deserialize shards"""
def validate_data(self, data: Any) -> None:
"""Validate payload data before shuffling"""
def get_worker_plugin() -> ShuffleWorkerPlugin:
from distributed import get_worker
try:
worker = get_worker()
except ValueError as e:
raise RuntimeError(
"`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
"please confirm that you've created a distributed Client and are submitting this computation through it."
) from e
try:
return worker.plugins["shuffle"] # type: ignore
except KeyError as e:
raise RuntimeError(
f"The worker {worker.address} does not have a P2P shuffle plugin."
) from e
_BARRIER_PREFIX = "shuffle-barrier-"
def barrier_key(shuffle_id: ShuffleId) -> str:
return _BARRIER_PREFIX + shuffle_id
def id_from_key(key: Key) -> ShuffleId | None:
if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
return None
return ShuffleId(key[len(_BARRIER_PREFIX) :])
class ShuffleType(Enum):
DATAFRAME = "DataFrameShuffle"
ARRAY_RECHUNK = "ArrayRechunk"
@dataclass(frozen=True)
class ShuffleRunSpec(Generic[_T_partition_id]):
run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
spec: ShuffleSpec
worker_for: dict[_T_partition_id, str]
span_id: str | None
@property
def id(self) -> ShuffleId:
return self.spec.id
@dataclass(frozen=True)
class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
id: ShuffleId
disk: bool
@property
@abc.abstractmethod
def output_partitions(self) -> Generator[_T_partition_id]:
"""Output partitions"""
@abc.abstractmethod
def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
"""Pick a worker for a partition"""
def create_new_run(
self,
worker_for: dict[_T_partition_id, str],
span_id: str | None,
) -> SchedulerShuffleState:
return SchedulerShuffleState(
run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
participating_workers=set(worker_for.values()),
)
@abc.abstractmethod
def create_run_on_worker(
self,
run_id: int,
span_id: str | None,
worker_for: dict[_T_partition_id, str],
plugin: ShuffleWorkerPlugin,
) -> ShuffleRun:
"""Create the new shuffle run on the worker."""
@dataclass(eq=False)
class SchedulerShuffleState(Generic[_T_partition_id]):
run_spec: ShuffleRunSpec
participating_workers: set[str]
_archived_by: str | None = field(default=None, init=False)
_failed: bool = False
@property
def id(self) -> ShuffleId:
return self.run_spec.id
@property
def run_id(self) -> int:
return self.run_spec.run_id
@property
def archived(self) -> bool:
return self._archived_by is not None
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
def __hash__(self) -> int:
return hash(self.run_id)
@contextlib.contextmanager
def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
try:
yield
except ShuffleClosedError:
raise Reschedule()
except P2PConsistencyError:
raise
except P2POutOfDiskError:
raise
except Exception as e:
> raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E RuntimeError: P2P e95e40ae5ccb7f71758613ef093cea03 failed during transfer phase
distributed\shuffle\_core.py:531: RuntimeError
Check warning on line 0 in distributed.shuffle.tests.test_rechunk
github-actions / Unit Test Results
All 11 runs failed: test_rechunk_heuristic[new0-p2p] (distributed.shuffle.tests.test_rechunk)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 031e1a73eb9bc8ea0efa30aa81c3dd68 failed during transfer phase
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
assert isinstance(barrier_task_spec, P2PBarrierTask)
return barrier_task_spec.spec
def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
# FIXME: The current implementation relies on the barrier task to be
# known by its name. If the name has been mangled, we cannot guarantee
# that the shuffle works as intended and should fail instead.
self._raise_if_barrier_unknown(shuffle_id)
self._raise_if_task_not_processing(key)
spec = self._retrieve_spec(shuffle_id)
worker_for = self._calculate_worker_for(spec)
self._ensure_output_tasks_are_non_rootish(spec)
state = spec.create_new_run(
worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
)
self.active_shuffles[shuffle_id] = state
self._shuffles[shuffle_id].add(state)
state.participating_workers.add(worker)
logger.warning(
"Shuffle %s initialized by task %r executed on worker %s",
shuffle_id,
key,
worker,
)
return state.run_spec
def get_or_create(
self,
shuffle_id: ShuffleId,
key: Key,
worker: str,
) -> RunSpecMessage | ErrorMessage:
try:
> run_spec = self._get(shuffle_id, worker)
distributed\shuffle\_scheduler_plugin.py:229:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
> state = self.active_shuffles[id]
E KeyError: '031e1a73eb9bc8ea0efa30aa81c3dd68'
distributed\shuffle\_scheduler_plugin.py:190: KeyError
During handling of the above exception, another exception occurred:
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and …
distributed\shuffle\_worker_plugin.py:411: in get_or_create_shuffle
return sync(
distributed\utils.py:439: in sync
raise error
distributed\utils.py:413: in f
result = yield future
C:\Users\runneradmin\miniconda3\envs\dask-distributed\Lib\site-packages\tornado\gen.py:766: in run
value = future.result()
distributed\shuffle\_worker_plugin.py:145: in get_or_create
shuffle_run = await self._refresh(
distributed\shuffle\_worker_plugin.py:222: in _refresh
result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed\shuffle\_worker_plugin.py:190: in _fetch
response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed\core.py:1259: in send_recv_from_rpc
return await send_recv(comm=comm, op=key, **kwargs)
distributed\core.py:1043: in send_recv
raise exc.with_traceback(tb)
distributed\core.py:832: in _handle_comm
result = handler(**msg)
distributed\shuffle\_scheduler_plugin.py:234: in get_or_create
run_spec = self._create(shuffle_id, key, worker)
distributed\shuffle\_scheduler_plugin.py:205: in _create
spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
> assert isinstance(barrier_task_spec, P2PBarrierTask)
E AssertionError
distributed\shuffle\_scheduler_plugin.py:196: AssertionError
The above exception was the direct cause of the following exception:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:64097', workers: 0, cores: 0, tasks: 0>
a = array([[0.47181678, 0.47249156, 0.78796815, ..., 0.85820206, 0.9811494 ,
0.55305409],
[0.21465893, 0.37...21,
0.94585984],
[0.36259846, 0.45029292, 0.94920195, ..., 0.3275337 , 0.03679922,
0.90015213]])
b = <Worker 'tcp://127.0.0.1:64101', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>
new = ((1, 1, 1, 1, 1, 1, ...), (100,)), expected_algorithm = 'p2p'
@pytest.mark.parametrize(
["new", "expected_algorithm"],
[
# All-to-all rechunking defaults to P2P
(((1,) * 100, (100,)), "p2p"),
# Localized rechunking defaults to tasks
(((50, 50), (2,) * 50), "tasks"),
# Less local rechunking first defaults to tasks,
(((25, 25, 25, 25), (4,) * 25), "tasks"),
# then switches to p2p
(((10,) * 10, (10,) * 10), "p2p"),
],
)
@gen_cluster(client=True)
async def test_rechunk_heuristic(c, s, a, b, new, expected_algorithm):
a = np.random.default_rng().uniform(0, 1, 10000).reshape((100, 100))
x = da.from_array(a, chunks=(100, 1))
x2 = rechunk(x, chunks=new)
if expected_algorithm == "p2p":
assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
else:
assert not any(
key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__()
)
assert x2.chunks == new
> assert np.all(await c.compute(x2) == a)
distributed\shuffle\tests\test_rechunk.py:239:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed\client.py:410: in _result
raise exc.with_traceback(tb)
distributed\shuffle\_rechunk.py:169: in rechunk_transfer
with handle_transfer_errors(id):
C:\Users\runneradmin\miniconda3\envs\dask-distributed\Lib\contextlib.py:158: in __exit__
self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and in the "p2p" contexts
self.validate_data(data)
with self._capture_metrics("foreground"):
with (
context_meter.meter("p2p-shard-partition-noncpu"),
context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
):
shards = self._shard_partition(data, partition_id)
sync(self._loop, self._write_to_comm, shards)
return self.run_id
@abc.abstractmethod
def _shard_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> dict[str, tuple[_T_partition_id, Any]]:
"""Shard an input partition by the assigned output workers"""
def get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
self.raise_if_closed()
sync(self._loop, self._ensure_output_worker, partition_id, key)
if not self.transferred:
raise RuntimeError("`get_output_partition` called before barrier task")
sync(self._loop, self.flush_receive)
with (
# Log metrics both in the "execute" and in the "p2p" contexts
self._capture_metrics("foreground"),
context_meter.meter("p2p-get-output-noncpu"),
context_meter.meter("p2p-get-output-cpu", func=thread_time),
):
return self._get_output_partition(partition_id, key, **kwargs)
@abc.abstractmethod
def _get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
"""Get an output partition to the shuffle run"""
@abc.abstractmethod
def read(self, path: Path) -> tuple[Any, int]:
"""Read shards from disk"""
@abc.abstractmethod
def deserialize(self, buffer: Any) -> Any:
"""Deserialize shards"""
def validate_data(self, data: Any) -> None:
"""Validate payload data before shuffling"""
def get_worker_plugin() -> ShuffleWorkerPlugin:
from distributed import get_worker
try:
worker = get_worker()
except ValueError as e:
raise RuntimeError(
"`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
"please confirm that you've created a distributed Client and are submitting this computation through it."
) from e
try:
return worker.plugins["shuffle"] # type: ignore
except KeyError as e:
raise RuntimeError(
f"The worker {worker.address} does not have a P2P shuffle plugin."
) from e
_BARRIER_PREFIX = "shuffle-barrier-"
def barrier_key(shuffle_id: ShuffleId) -> str:
return _BARRIER_PREFIX + shuffle_id
def id_from_key(key: Key) -> ShuffleId | None:
if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
return None
return ShuffleId(key[len(_BARRIER_PREFIX) :])
class ShuffleType(Enum):
DATAFRAME = "DataFrameShuffle"
ARRAY_RECHUNK = "ArrayRechunk"
@dataclass(frozen=True)
class ShuffleRunSpec(Generic[_T_partition_id]):
run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
spec: ShuffleSpec
worker_for: dict[_T_partition_id, str]
span_id: str | None
@property
def id(self) -> ShuffleId:
return self.spec.id
@dataclass(frozen=True)
class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
id: ShuffleId
disk: bool
@property
@abc.abstractmethod
def output_partitions(self) -> Generator[_T_partition_id]:
"""Output partitions"""
@abc.abstractmethod
def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
"""Pick a worker for a partition"""
def create_new_run(
self,
worker_for: dict[_T_partition_id, str],
span_id: str | None,
) -> SchedulerShuffleState:
return SchedulerShuffleState(
run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
participating_workers=set(worker_for.values()),
)
@abc.abstractmethod
def create_run_on_worker(
self,
run_id: int,
span_id: str | None,
worker_for: dict[_T_partition_id, str],
plugin: ShuffleWorkerPlugin,
) -> ShuffleRun:
"""Create the new shuffle run on the worker."""
@dataclass(eq=False)
class SchedulerShuffleState(Generic[_T_partition_id]):
run_spec: ShuffleRunSpec
participating_workers: set[str]
_archived_by: str | None = field(default=None, init=False)
_failed: bool = False
@property
def id(self) -> ShuffleId:
return self.run_spec.id
@property
def run_id(self) -> int:
return self.run_spec.run_id
@property
def archived(self) -> bool:
return self._archived_by is not None
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
def __hash__(self) -> int:
return hash(self.run_id)
@contextlib.contextmanager
def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
try:
yield
except ShuffleClosedError:
raise Reschedule()
except P2PConsistencyError:
raise
except P2POutOfDiskError:
raise
except Exception as e:
> raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E RuntimeError: P2P 031e1a73eb9bc8ea0efa30aa81c3dd68 failed during transfer phase
distributed\shuffle\_core.py:531: RuntimeError
Check warning on line 0 in distributed.shuffle.tests.test_rechunk
github-actions / Unit Test Results
All 11 runs failed: test_rechunk_heuristic[new3-p2p] (distributed.shuffle.tests.test_rechunk)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 5b9d4830bd1782a20359d383dac64ddc failed during transfer phase
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
assert isinstance(barrier_task_spec, P2PBarrierTask)
return barrier_task_spec.spec
def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
# FIXME: The current implementation relies on the barrier task to be
# known by its name. If the name has been mangled, we cannot guarantee
# that the shuffle works as intended and should fail instead.
self._raise_if_barrier_unknown(shuffle_id)
self._raise_if_task_not_processing(key)
spec = self._retrieve_spec(shuffle_id)
worker_for = self._calculate_worker_for(spec)
self._ensure_output_tasks_are_non_rootish(spec)
state = spec.create_new_run(
worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
)
self.active_shuffles[shuffle_id] = state
self._shuffles[shuffle_id].add(state)
state.participating_workers.add(worker)
logger.warning(
"Shuffle %s initialized by task %r executed on worker %s",
shuffle_id,
key,
worker,
)
return state.run_spec
def get_or_create(
self,
shuffle_id: ShuffleId,
key: Key,
worker: str,
) -> RunSpecMessage | ErrorMessage:
try:
> run_spec = self._get(shuffle_id, worker)
distributed\shuffle\_scheduler_plugin.py:229:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
> state = self.active_shuffles[id]
E KeyError: '5b9d4830bd1782a20359d383dac64ddc'
distributed\shuffle\_scheduler_plugin.py:190: KeyError
During handling of the above exception, another exception occurred:
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and …_plugin.py:411: in get_or_create_shuffle
return sync(
distributed\utils.py:439: in sync
raise error
distributed\utils.py:413: in f
result = yield future
C:\Users\runneradmin\miniconda3\envs\dask-distributed\Lib\site-packages\tornado\gen.py:766: in run
value = future.result()
distributed\shuffle\_worker_plugin.py:145: in get_or_create
shuffle_run = await self._refresh(
distributed\shuffle\_worker_plugin.py:222: in _refresh
result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed\shuffle\_worker_plugin.py:190: in _fetch
response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed\core.py:1259: in send_recv_from_rpc
return await send_recv(comm=comm, op=key, **kwargs)
distributed\core.py:1043: in send_recv
raise exc.with_traceback(tb)
distributed\core.py:832: in _handle_comm
result = handler(**msg)
distributed\shuffle\_scheduler_plugin.py:234: in get_or_create
run_spec = self._create(shuffle_id, key, worker)
distributed\shuffle\_scheduler_plugin.py:205: in _create
spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
> assert isinstance(barrier_task_spec, P2PBarrierTask)
E AssertionError
distributed\shuffle\_scheduler_plugin.py:196: AssertionError
The above exception was the direct cause of the following exception:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:64151', workers: 0, cores: 0, tasks: 0>
a = array([[0.02218022, 0.54394215, 0.75067916, ..., 0.52673533, 0.10159338,
0.32048149],
[0.64108086, 0.21...08,
0.18661381],
[0.77971001, 0.88125599, 0.40857319, ..., 0.48985381, 0.50061773,
0.82015301]])
b = <Worker 'tcp://127.0.0.1:64155', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>
new = ((10, 10, 10, 10, 10, 10, ...), (10, 10, 10, 10, 10, 10, ...))
expected_algorithm = 'p2p'
@pytest.mark.parametrize(
["new", "expected_algorithm"],
[
# All-to-all rechunking defaults to P2P
(((1,) * 100, (100,)), "p2p"),
# Localized rechunking defaults to tasks
(((50, 50), (2,) * 50), "tasks"),
# Less local rechunking first defaults to tasks,
(((25, 25, 25, 25), (4,) * 25), "tasks"),
# then switches to p2p
(((10,) * 10, (10,) * 10), "p2p"),
],
)
@gen_cluster(client=True)
async def test_rechunk_heuristic(c, s, a, b, new, expected_algorithm):
a = np.random.default_rng().uniform(0, 1, 10000).reshape((100, 100))
x = da.from_array(a, chunks=(100, 1))
x2 = rechunk(x, chunks=new)
if expected_algorithm == "p2p":
assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
else:
assert not any(
key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__()
)
assert x2.chunks == new
> assert np.all(await c.compute(x2) == a)
distributed\shuffle\tests\test_rechunk.py:239:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed\client.py:410: in _result
raise exc.with_traceback(tb)
distributed\shuffle\_rechunk.py:169: in rechunk_transfer
with handle_transfer_errors(id):
C:\Users\runneradmin\miniconda3\envs\dask-distributed\Lib\contextlib.py:158: in __exit__
self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and in the "p2p" contexts
self.validate_data(data)
with self._capture_metrics("foreground"):
with (
context_meter.meter("p2p-shard-partition-noncpu"),
context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
):
shards = self._shard_partition(data, partition_id)
sync(self._loop, self._write_to_comm, shards)
return self.run_id
@abc.abstractmethod
def _shard_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> dict[str, tuple[_T_partition_id, Any]]:
"""Shard an input partition by the assigned output workers"""
def get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
self.raise_if_closed()
sync(self._loop, self._ensure_output_worker, partition_id, key)
if not self.transferred:
raise RuntimeError("`get_output_partition` called before barrier task")
sync(self._loop, self.flush_receive)
with (
# Log metrics both in the "execute" and in the "p2p" contexts
self._capture_metrics("foreground"),
context_meter.meter("p2p-get-output-noncpu"),
context_meter.meter("p2p-get-output-cpu", func=thread_time),
):
return self._get_output_partition(partition_id, key, **kwargs)
@abc.abstractmethod
def _get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
"""Get an output partition to the shuffle run"""
@abc.abstractmethod
def read(self, path: Path) -> tuple[Any, int]:
"""Read shards from disk"""
@abc.abstractmethod
def deserialize(self, buffer: Any) -> Any:
"""Deserialize shards"""
def validate_data(self, data: Any) -> None:
"""Validate payload data before shuffling"""
def get_worker_plugin() -> ShuffleWorkerPlugin:
from distributed import get_worker
try:
worker = get_worker()
except ValueError as e:
raise RuntimeError(
"`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
"please confirm that you've created a distributed Client and are submitting this computation through it."
) from e
try:
return worker.plugins["shuffle"] # type: ignore
except KeyError as e:
raise RuntimeError(
f"The worker {worker.address} does not have a P2P shuffle plugin."
) from e
_BARRIER_PREFIX = "shuffle-barrier-"
def barrier_key(shuffle_id: ShuffleId) -> str:
return _BARRIER_PREFIX + shuffle_id
def id_from_key(key: Key) -> ShuffleId | None:
if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
return None
return ShuffleId(key[len(_BARRIER_PREFIX) :])
class ShuffleType(Enum):
DATAFRAME = "DataFrameShuffle"
ARRAY_RECHUNK = "ArrayRechunk"
@dataclass(frozen=True)
class ShuffleRunSpec(Generic[_T_partition_id]):
run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
spec: ShuffleSpec
worker_for: dict[_T_partition_id, str]
span_id: str | None
@property
def id(self) -> ShuffleId:
return self.spec.id
@dataclass(frozen=True)
class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
id: ShuffleId
disk: bool
@property
@abc.abstractmethod
def output_partitions(self) -> Generator[_T_partition_id]:
"""Output partitions"""
@abc.abstractmethod
def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
"""Pick a worker for a partition"""
def create_new_run(
self,
worker_for: dict[_T_partition_id, str],
span_id: str | None,
) -> SchedulerShuffleState:
return SchedulerShuffleState(
run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
participating_workers=set(worker_for.values()),
)
@abc.abstractmethod
def create_run_on_worker(
self,
run_id: int,
span_id: str | None,
worker_for: dict[_T_partition_id, str],
plugin: ShuffleWorkerPlugin,
) -> ShuffleRun:
"""Create the new shuffle run on the worker."""
@dataclass(eq=False)
class SchedulerShuffleState(Generic[_T_partition_id]):
run_spec: ShuffleRunSpec
participating_workers: set[str]
_archived_by: str | None = field(default=None, init=False)
_failed: bool = False
@property
def id(self) -> ShuffleId:
return self.run_spec.id
@property
def run_id(self) -> int:
return self.run_spec.run_id
@property
def archived(self) -> bool:
return self._archived_by is not None
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
def __hash__(self) -> int:
return hash(self.run_id)
@contextlib.contextmanager
def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
try:
yield
except ShuffleClosedError:
raise Reschedule()
except P2PConsistencyError:
raise
except P2POutOfDiskError:
raise
except Exception as e:
> raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E RuntimeError: P2P 5b9d4830bd1782a20359d383dac64ddc failed during transfer phase
distributed\shuffle\_core.py:531: RuntimeError
Check warning on line 0 in distributed.shuffle.tests.test_rechunk
github-actions / Unit Test Results
All 11 runs failed: test_cull_p2p_rechunk_independent_partitions (distributed.shuffle.tests.test_rechunk)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
assert 57 < (228 / 4)
+ where 57 = len(HighLevelGraph with 1 layers.\n<dask.highlevelgraph.HighLevelGraph object at 0x1aa05ebaf50>\n 0. getitem-3805afa155d01ec6f19a98fd58a271e4\n)
+ where HighLevelGraph with 1 layers.\n<dask.highlevelgraph.HighLevelGraph object at 0x1aa05ebaf50>\n 0. getitem-3805afa155d01ec6f19a98fd58a271e4\n = dask.array<getitem, shape=(5, 2, 10), dtype=float64, chunksize=(5, 1, 10), chunktype=numpy.ndarray>.dask
+ and 228 = len(HighLevelGraph with 1 layers.\n<dask.highlevelgraph.HighLevelGraph object at 0x1aa065b1050>\n 0. rechunk-p2p-c47ad159278aa77d47afe1fdcb337370\n)
+ where HighLevelGraph with 1 layers.\n<dask.highlevelgraph.HighLevelGraph object at 0x1aa065b1050>\n 0. rechunk-p2p-c47ad159278aa77d47afe1fdcb337370\n = dask.array<rechunk-p2p, shape=(10, 10, 10), dtype=float64, chunksize=(5, 1, 10), chunktype=numpy.ndarray>.dask
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:64166', workers: 0, cores: 0, tasks: 0>
ws = (<Worker 'tcp://127.0.0.1:64167', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:64170', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
a = array([[[4.01332956e-01, 5.05156722e-01, 2.80394807e-02, 5.91861613e-01,
6.20283724e-01, 3.06173020e-02, 5.19...1,
3.34170490e-01, 3.41256630e-01, 6.56954053e-01, 5.49743569e-01,
4.71146730e-01, 8.40362257e-01]]])
x = dask.array<array, shape=(10, 10, 10), dtype=float64, chunksize=(1, 5, 1), chunktype=numpy.ndarray>
new = (5, 1, -1)
@gen_cluster(client=True)
async def test_cull_p2p_rechunk_independent_partitions(c, s, *ws):
a = np.random.default_rng().uniform(0, 1, 1000).reshape((10, 10, 10))
x = da.from_array(a, chunks=(1, 5, 1))
new = (5, 1, -1)
rechunked = rechunk(x, chunks=new, method="p2p")
(dsk,) = dask.optimize(rechunked)
culled = rechunked[:5, :2]
(dsk_culled,) = dask.optimize(culled)
# The culled graph requires only 1/2 of the input tasks
n_inputs = len(
[1 for key in dsk.dask.get_all_dependencies() if key[0].startswith("array-")]
)
n_culled_inputs = len(
[
1
for key in dsk_culled.dask.get_all_dependencies()
if key[0].startswith("array-")
]
)
assert n_culled_inputs == n_inputs / 4
# The culled graph should also have less than 1/4 the tasks
> assert len(dsk_culled.dask) < len(dsk.dask) / 4
E assert 57 < (228 / 4)
E + where 57 = len(HighLevelGraph with 1 layers.\n<dask.highlevelgraph.HighLevelGraph object at 0x1aa05ebaf50>\n 0. getitem-3805afa155d01ec6f19a98fd58a271e4\n)
E + where HighLevelGraph with 1 layers.\n<dask.highlevelgraph.HighLevelGraph object at 0x1aa05ebaf50>\n 0. getitem-3805afa155d01ec6f19a98fd58a271e4\n = dask.array<getitem, shape=(5, 2, 10), dtype=float64, chunksize=(5, 1, 10), chunktype=numpy.ndarray>.dask
E + and 228 = len(HighLevelGraph with 1 layers.\n<dask.highlevelgraph.HighLevelGraph object at 0x1aa065b1050>\n 0. rechunk-p2p-c47ad159278aa77d47afe1fdcb337370\n)
E + where HighLevelGraph with 1 layers.\n<dask.highlevelgraph.HighLevelGraph object at 0x1aa065b1050>\n 0. rechunk-p2p-c47ad159278aa77d47afe1fdcb337370\n = dask.array<rechunk-p2p, shape=(10, 10, 10), dtype=float64, chunksize=(5, 1, 10), chunktype=numpy.ndarray>.dask
distributed\shuffle\tests\test_rechunk.py:265: AssertionError
Check warning on line 0 in distributed.shuffle.tests.test_rechunk
github-actions / Unit Test Results
All 11 runs failed: test_rechunk_expand (distributed.shuffle.tests.test_rechunk)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 7df54f410e79f8489793c71b35ec9a5d failed during transfer phase
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
assert isinstance(barrier_task_spec, P2PBarrierTask)
return barrier_task_spec.spec
def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
# FIXME: The current implementation relies on the barrier task to be
# known by its name. If the name has been mangled, we cannot guarantee
# that the shuffle works as intended and should fail instead.
self._raise_if_barrier_unknown(shuffle_id)
self._raise_if_task_not_processing(key)
spec = self._retrieve_spec(shuffle_id)
worker_for = self._calculate_worker_for(spec)
self._ensure_output_tasks_are_non_rootish(spec)
state = spec.create_new_run(
worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
)
self.active_shuffles[shuffle_id] = state
self._shuffles[shuffle_id].add(state)
state.participating_workers.add(worker)
logger.warning(
"Shuffle %s initialized by task %r executed on worker %s",
shuffle_id,
key,
worker,
)
return state.run_spec
def get_or_create(
self,
shuffle_id: ShuffleId,
key: Key,
worker: str,
) -> RunSpecMessage | ErrorMessage:
try:
> run_spec = self._get(shuffle_id, worker)
distributed\shuffle\_scheduler_plugin.py:229:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
> state = self.active_shuffles[id]
E KeyError: '7df54f410e79f8489793c71b35ec9a5d'
distributed\shuffle\_scheduler_plugin.py:190: KeyError
During handling of the above exception, another exception occurred:
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and …ger
def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
try:
> yield
distributed\shuffle\_core.py:523:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed\shuffle\_rechunk.py:170: in rechunk_transfer
return get_worker_plugin().add_partition(
distributed\shuffle\_worker_plugin.py:348: in add_partition
shuffle_run = self.get_or_create_shuffle(id)
distributed\shuffle\_worker_plugin.py:411: in get_or_create_shuffle
return sync(
distributed\utils.py:439: in sync
raise error
distributed\utils.py:413: in f
result = yield future
C:\Users\runneradmin\miniconda3\envs\dask-distributed\Lib\site-packages\tornado\gen.py:766: in run
value = future.result()
distributed\shuffle\_worker_plugin.py:145: in get_or_create
shuffle_run = await self._refresh(
distributed\shuffle\_worker_plugin.py:222: in _refresh
result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed\shuffle\_worker_plugin.py:190: in _fetch
response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed\core.py:1259: in send_recv_from_rpc
return await send_recv(comm=comm, op=key, **kwargs)
distributed\core.py:1043: in send_recv
raise exc.with_traceback(tb)
distributed\core.py:832: in _handle_comm
result = handler(**msg)
distributed\shuffle\_scheduler_plugin.py:234: in get_or_create
run_spec = self._create(shuffle_id, key, worker)
distributed\shuffle\_scheduler_plugin.py:205: in _create
spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
> assert isinstance(barrier_task_spec, P2PBarrierTask)
E AssertionError
distributed\shuffle\_scheduler_plugin.py:196: AssertionError
The above exception was the direct cause of the following exception:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:64264', workers: 0, cores: 0, tasks: 0>
ws = (<Worker 'tcp://127.0.0.1:64265', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:64268', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
a = array([[0.4588262 , 0.29882582, 0.71928598, 0.22034295, 0.98812948,
0.67941824, 0.25921056, 0.62241972, 0.3004..., 0.28284222, 0.16402205, 0.55920114, 0.6132267 ,
0.17587811, 0.09622778, 0.36399375, 0.04524015, 0.74855889]])
x = dask.array<array, shape=(10, 10), dtype=float64, chunksize=(5, 5), chunktype=numpy.ndarray>
y = dask.array<rechunk-p2p, shape=(10, 10), dtype=float64, chunksize=(3, 3), chunktype=numpy.ndarray>
@gen_cluster(client=True)
async def test_rechunk_expand(c, s, *ws):
"""
See Also
--------
dask.array.tests.test_rechunk.test_rechunk_expand
"""
a = np.random.default_rng().uniform(0, 1, 100).reshape((10, 10))
x = da.from_array(a, chunks=(5, 5))
y = x.rechunk(chunks=((3, 3, 3, 1), (3, 3, 3, 1)), method="p2p")
> assert np.all(await c.compute(y) == a)
distributed\shuffle\tests\test_rechunk.py:377:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed\client.py:410: in _result
raise exc.with_traceback(tb)
distributed\shuffle\_rechunk.py:169: in rechunk_transfer
with handle_transfer_errors(id):
C:\Users\runneradmin\miniconda3\envs\dask-distributed\Lib\contextlib.py:158: in __exit__
self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and in the "p2p" contexts
self.validate_data(data)
with self._capture_metrics("foreground"):
with (
context_meter.meter("p2p-shard-partition-noncpu"),
context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
):
shards = self._shard_partition(data, partition_id)
sync(self._loop, self._write_to_comm, shards)
return self.run_id
@abc.abstractmethod
def _shard_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> dict[str, tuple[_T_partition_id, Any]]:
"""Shard an input partition by the assigned output workers"""
def get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
self.raise_if_closed()
sync(self._loop, self._ensure_output_worker, partition_id, key)
if not self.transferred:
raise RuntimeError("`get_output_partition` called before barrier task")
sync(self._loop, self.flush_receive)
with (
# Log metrics both in the "execute" and in the "p2p" contexts
self._capture_metrics("foreground"),
context_meter.meter("p2p-get-output-noncpu"),
context_meter.meter("p2p-get-output-cpu", func=thread_time),
):
return self._get_output_partition(partition_id, key, **kwargs)
@abc.abstractmethod
def _get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
"""Get an output partition to the shuffle run"""
@abc.abstractmethod
def read(self, path: Path) -> tuple[Any, int]:
"""Read shards from disk"""
@abc.abstractmethod
def deserialize(self, buffer: Any) -> Any:
"""Deserialize shards"""
def validate_data(self, data: Any) -> None:
"""Validate payload data before shuffling"""
def get_worker_plugin() -> ShuffleWorkerPlugin:
from distributed import get_worker
try:
worker = get_worker()
except ValueError as e:
raise RuntimeError(
"`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
"please confirm that you've created a distributed Client and are submitting this computation through it."
) from e
try:
return worker.plugins["shuffle"] # type: ignore
except KeyError as e:
raise RuntimeError(
f"The worker {worker.address} does not have a P2P shuffle plugin."
) from e
_BARRIER_PREFIX = "shuffle-barrier-"
def barrier_key(shuffle_id: ShuffleId) -> str:
return _BARRIER_PREFIX + shuffle_id
def id_from_key(key: Key) -> ShuffleId | None:
if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
return None
return ShuffleId(key[len(_BARRIER_PREFIX) :])
class ShuffleType(Enum):
DATAFRAME = "DataFrameShuffle"
ARRAY_RECHUNK = "ArrayRechunk"
@dataclass(frozen=True)
class ShuffleRunSpec(Generic[_T_partition_id]):
run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
spec: ShuffleSpec
worker_for: dict[_T_partition_id, str]
span_id: str | None
@property
def id(self) -> ShuffleId:
return self.spec.id
@dataclass(frozen=True)
class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
id: ShuffleId
disk: bool
@property
@abc.abstractmethod
def output_partitions(self) -> Generator[_T_partition_id]:
"""Output partitions"""
@abc.abstractmethod
def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
"""Pick a worker for a partition"""
def create_new_run(
self,
worker_for: dict[_T_partition_id, str],
span_id: str | None,
) -> SchedulerShuffleState:
return SchedulerShuffleState(
run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
participating_workers=set(worker_for.values()),
)
@abc.abstractmethod
def create_run_on_worker(
self,
run_id: int,
span_id: str | None,
worker_for: dict[_T_partition_id, str],
plugin: ShuffleWorkerPlugin,
) -> ShuffleRun:
"""Create the new shuffle run on the worker."""
@dataclass(eq=False)
class SchedulerShuffleState(Generic[_T_partition_id]):
run_spec: ShuffleRunSpec
participating_workers: set[str]
_archived_by: str | None = field(default=None, init=False)
_failed: bool = False
@property
def id(self) -> ShuffleId:
return self.run_spec.id
@property
def run_id(self) -> int:
return self.run_spec.run_id
@property
def archived(self) -> bool:
return self._archived_by is not None
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
def __hash__(self) -> int:
return hash(self.run_id)
@contextlib.contextmanager
def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
try:
yield
except ShuffleClosedError:
raise Reschedule()
except P2PConsistencyError:
raise
except P2POutOfDiskError:
raise
except Exception as e:
> raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E RuntimeError: P2P 7df54f410e79f8489793c71b35ec9a5d failed during transfer phase
distributed\shuffle\_core.py:531: RuntimeError
Check warning on line 0 in distributed.shuffle.tests.test_rechunk
github-actions / Unit Test Results
All 11 runs failed: test_rechunk_expand2 (distributed.shuffle.tests.test_rechunk)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 6e5cf1da37a31a6c56bc50f1842218f5 failed during transfer phase
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
assert isinstance(barrier_task_spec, P2PBarrierTask)
return barrier_task_spec.spec
def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
# FIXME: The current implementation relies on the barrier task to be
# known by its name. If the name has been mangled, we cannot guarantee
# that the shuffle works as intended and should fail instead.
self._raise_if_barrier_unknown(shuffle_id)
self._raise_if_task_not_processing(key)
spec = self._retrieve_spec(shuffle_id)
worker_for = self._calculate_worker_for(spec)
self._ensure_output_tasks_are_non_rootish(spec)
state = spec.create_new_run(
worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
)
self.active_shuffles[shuffle_id] = state
self._shuffles[shuffle_id].add(state)
state.participating_workers.add(worker)
logger.warning(
"Shuffle %s initialized by task %r executed on worker %s",
shuffle_id,
key,
worker,
)
return state.run_spec
def get_or_create(
self,
shuffle_id: ShuffleId,
key: Key,
worker: str,
) -> RunSpecMessage | ErrorMessage:
try:
> run_spec = self._get(shuffle_id, worker)
distributed\shuffle\_scheduler_plugin.py:229:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
> state = self.active_shuffles[id]
E KeyError: '6e5cf1da37a31a6c56bc50f1842218f5'
distributed\shuffle\_scheduler_plugin.py:190: KeyError
During handling of the above exception, another exception occurred:
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and …rors(id: ShuffleId) -> Iterator[None]:
try:
> yield
distributed\shuffle\_core.py:523:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed\shuffle\_rechunk.py:170: in rechunk_transfer
return get_worker_plugin().add_partition(
distributed\shuffle\_worker_plugin.py:348: in add_partition
shuffle_run = self.get_or_create_shuffle(id)
distributed\shuffle\_worker_plugin.py:411: in get_or_create_shuffle
return sync(
distributed\utils.py:439: in sync
raise error
distributed\utils.py:413: in f
result = yield future
C:\Users\runneradmin\miniconda3\envs\dask-distributed\Lib\site-packages\tornado\gen.py:766: in run
value = future.result()
distributed\shuffle\_worker_plugin.py:145: in get_or_create
shuffle_run = await self._refresh(
distributed\shuffle\_worker_plugin.py:222: in _refresh
result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed\shuffle\_worker_plugin.py:190: in _fetch
response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed\core.py:1259: in send_recv_from_rpc
return await send_recv(comm=comm, op=key, **kwargs)
distributed\core.py:1043: in send_recv
raise exc.with_traceback(tb)
distributed\core.py:832: in _handle_comm
result = handler(**msg)
distributed\shuffle\_scheduler_plugin.py:234: in get_or_create
run_spec = self._create(shuffle_id, key, worker)
distributed\shuffle\_scheduler_plugin.py:205: in _create
spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
> assert isinstance(barrier_task_spec, P2PBarrierTask)
E AssertionError
distributed\shuffle\_scheduler_plugin.py:196: AssertionError
The above exception was the direct cause of the following exception:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:64279', workers: 0, cores: 0, tasks: 0>
ws = (<Worker 'tcp://127.0.0.1:64280', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:64283', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
a = 3, b = 2
orig = array([[0.57459479, 0.24261869, 0.6206557 ],
[0.97627947, 0.65759 , 0.41902488],
[0.08888248, 0.69644868, 0.30588096]])
@gen_cluster(client=True)
async def test_rechunk_expand2(c, s, *ws):
"""
See Also
--------
dask.array.tests.test_rechunk.test_rechunk_expand2
"""
(a, b) = (3, 2)
orig = np.random.default_rng().uniform(0, 1, a**b).reshape((a,) * b)
for off, off2 in product(range(1, a - 1), range(1, a - 1)):
old = ((a - off, off),) * b
x = da.from_array(orig, chunks=old)
new = ((a - off2, off2),) * b
assert np.all(await c.compute(x.rechunk(chunks=new, method="p2p")) == orig)
if a - off - off2 > 0:
new = ((off, a - off2 - off, off2),) * b
> y = await c.compute(x.rechunk(chunks=new, method="p2p"))
distributed\shuffle\tests\test_rechunk.py:396:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed\client.py:410: in _result
raise exc.with_traceback(tb)
distributed\shuffle\_rechunk.py:169: in rechunk_transfer
with handle_transfer_errors(id):
C:\Users\runneradmin\miniconda3\envs\dask-distributed\Lib\contextlib.py:158: in __exit__
self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and in the "p2p" contexts
self.validate_data(data)
with self._capture_metrics("foreground"):
with (
context_meter.meter("p2p-shard-partition-noncpu"),
context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
):
shards = self._shard_partition(data, partition_id)
sync(self._loop, self._write_to_comm, shards)
return self.run_id
@abc.abstractmethod
def _shard_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> dict[str, tuple[_T_partition_id, Any]]:
"""Shard an input partition by the assigned output workers"""
def get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
self.raise_if_closed()
sync(self._loop, self._ensure_output_worker, partition_id, key)
if not self.transferred:
raise RuntimeError("`get_output_partition` called before barrier task")
sync(self._loop, self.flush_receive)
with (
# Log metrics both in the "execute" and in the "p2p" contexts
self._capture_metrics("foreground"),
context_meter.meter("p2p-get-output-noncpu"),
context_meter.meter("p2p-get-output-cpu", func=thread_time),
):
return self._get_output_partition(partition_id, key, **kwargs)
@abc.abstractmethod
def _get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
"""Get an output partition to the shuffle run"""
@abc.abstractmethod
def read(self, path: Path) -> tuple[Any, int]:
"""Read shards from disk"""
@abc.abstractmethod
def deserialize(self, buffer: Any) -> Any:
"""Deserialize shards"""
def validate_data(self, data: Any) -> None:
"""Validate payload data before shuffling"""
def get_worker_plugin() -> ShuffleWorkerPlugin:
from distributed import get_worker
try:
worker = get_worker()
except ValueError as e:
raise RuntimeError(
"`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
"please confirm that you've created a distributed Client and are submitting this computation through it."
) from e
try:
return worker.plugins["shuffle"] # type: ignore
except KeyError as e:
raise RuntimeError(
f"The worker {worker.address} does not have a P2P shuffle plugin."
) from e
_BARRIER_PREFIX = "shuffle-barrier-"
def barrier_key(shuffle_id: ShuffleId) -> str:
return _BARRIER_PREFIX + shuffle_id
def id_from_key(key: Key) -> ShuffleId | None:
if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
return None
return ShuffleId(key[len(_BARRIER_PREFIX) :])
class ShuffleType(Enum):
DATAFRAME = "DataFrameShuffle"
ARRAY_RECHUNK = "ArrayRechunk"
@dataclass(frozen=True)
class ShuffleRunSpec(Generic[_T_partition_id]):
run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
spec: ShuffleSpec
worker_for: dict[_T_partition_id, str]
span_id: str | None
@property
def id(self) -> ShuffleId:
return self.spec.id
@dataclass(frozen=True)
class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
id: ShuffleId
disk: bool
@property
@abc.abstractmethod
def output_partitions(self) -> Generator[_T_partition_id]:
"""Output partitions"""
@abc.abstractmethod
def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
"""Pick a worker for a partition"""
def create_new_run(
self,
worker_for: dict[_T_partition_id, str],
span_id: str | None,
) -> SchedulerShuffleState:
return SchedulerShuffleState(
run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
participating_workers=set(worker_for.values()),
)
@abc.abstractmethod
def create_run_on_worker(
self,
run_id: int,
span_id: str | None,
worker_for: dict[_T_partition_id, str],
plugin: ShuffleWorkerPlugin,
) -> ShuffleRun:
"""Create the new shuffle run on the worker."""
@dataclass(eq=False)
class SchedulerShuffleState(Generic[_T_partition_id]):
run_spec: ShuffleRunSpec
participating_workers: set[str]
_archived_by: str | None = field(default=None, init=False)
_failed: bool = False
@property
def id(self) -> ShuffleId:
return self.run_spec.id
@property
def run_id(self) -> int:
return self.run_spec.run_id
@property
def archived(self) -> bool:
return self._archived_by is not None
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
def __hash__(self) -> int:
return hash(self.run_id)
@contextlib.contextmanager
def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
try:
yield
except ShuffleClosedError:
raise Reschedule()
except P2PConsistencyError:
raise
except P2POutOfDiskError:
raise
except Exception as e:
> raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E RuntimeError: P2P 6e5cf1da37a31a6c56bc50f1842218f5 failed during transfer phase
distributed\shuffle\_core.py:531: RuntimeError
Check warning on line 0 in distributed.shuffle.tests.test_rechunk
github-actions / Unit Test Results
10 out of 11 runs failed: test_rechunk_unknown_from_pandas (distributed.shuffle.tests.test_rechunk)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P db9c09cc2dbfe69e4bcd9524cb9be9d7 failed during transfer phase
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
assert isinstance(barrier_task_spec, P2PBarrierTask)
return barrier_task_spec.spec
def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
# FIXME: The current implementation relies on the barrier task to be
# known by its name. If the name has been mangled, we cannot guarantee
# that the shuffle works as intended and should fail instead.
self._raise_if_barrier_unknown(shuffle_id)
self._raise_if_task_not_processing(key)
spec = self._retrieve_spec(shuffle_id)
worker_for = self._calculate_worker_for(spec)
self._ensure_output_tasks_are_non_rootish(spec)
state = spec.create_new_run(
worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
)
self.active_shuffles[shuffle_id] = state
self._shuffles[shuffle_id].add(state)
state.participating_workers.add(worker)
logger.warning(
"Shuffle %s initialized by task %r executed on worker %s",
shuffle_id,
key,
worker,
)
return state.run_spec
def get_or_create(
self,
shuffle_id: ShuffleId,
key: Key,
worker: str,
) -> RunSpecMessage | ErrorMessage:
try:
> run_spec = self._get(shuffle_id, worker)
distributed\shuffle\_scheduler_plugin.py:229:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
> state = self.active_shuffles[id]
E KeyError: 'db9c09cc2dbfe69e4bcd9524cb9be9d7'
distributed\shuffle\_scheduler_plugin.py:190: KeyError
During handling of the above exception, another exception occurred:
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and … get_or_create_shuffle
return sync(
distributed\utils.py:439: in sync
raise error
distributed\utils.py:413: in f
result = yield future
C:\Users\runneradmin\miniconda3\envs\dask-distributed\Lib\site-packages\tornado\gen.py:766: in run
value = future.result()
distributed\shuffle\_worker_plugin.py:145: in get_or_create
shuffle_run = await self._refresh(
distributed\shuffle\_worker_plugin.py:222: in _refresh
result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed\shuffle\_worker_plugin.py:190: in _fetch
response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed\core.py:1259: in send_recv_from_rpc
return await send_recv(comm=comm, op=key, **kwargs)
distributed\core.py:1043: in send_recv
raise exc.with_traceback(tb)
distributed\core.py:832: in _handle_comm
result = handler(**msg)
distributed\shuffle\_scheduler_plugin.py:234: in get_or_create
run_spec = self._create(shuffle_id, key, worker)
distributed\shuffle\_scheduler_plugin.py:205: in _create
spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
    self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
    """Restrict the task ``key`` to ``worker`` on behalf of a shuffle run.

    The request is only honoured when ``run_id`` equals the run ID of the
    active shuffle ``id``; otherwise a ``P2PConsistencyError`` is turned
    into an error message via ``error_message``. On success the task state
    is looked up in ``self.scheduler.tasks`` and passed to
    ``self._set_restriction``, and ``{"status": "OK"}`` is returned.
    """
    try:
        shuffle = self.active_shuffles[id]
        active_run_id = shuffle.run_id
        if active_run_id > run_id:
            raise P2PConsistencyError(
                f"Request stale, expected {run_id=} for {shuffle}"
            )
        if active_run_id < run_id:
            raise P2PConsistencyError(
                f"Request invalid, expected {run_id=} for {shuffle}"
            )
        self._set_restriction(self.scheduler.tasks[key], worker)
    except P2PConsistencyError as e:
        return error_message(e)
    return {"status": "OK"}
def heartbeat(self, ws: WorkerState, data: dict) -> None:
    """Merge per-shuffle heartbeat data reported by a worker.

    ``data`` maps shuffle IDs to dicts. Entries whose shuffle ID is not in
    ``self.active_shuffles`` are ignored; all others are merged into
    ``self.heartbeats[shuffle_id][ws.address]``.
    """
    # Test membership against the dict itself rather than building a new set
    # of all active IDs (``shuffle_ids()``) for every heartbeat entry.
    active = self.active_shuffles
    for shuffle_id, d in data.items():
        if shuffle_id in active:
            self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
    """Return the run spec of the active shuffle ``id`` for ``worker``.

    On success the run spec from ``self._get`` is wrapped in ``ToPickle``
    and returned in an ``"OK"`` message. A ``KeyError`` is converted into
    a ``P2PConsistencyError``, and any ``P2PConsistencyError`` is returned
    as an error message built by ``error_message``.
    """
    try:
        try:
            spec = self._get(id, worker)
            response: RunSpecMessage = {"status": "OK", "run_spec": ToPickle(spec)}
            return response
        except KeyError as missing:
            raise P2PConsistencyError(
                f"No active shuffle with {id=!r} found"
            ) from missing
    except P2PConsistencyError as inconsistent:
        return error_message(inconsistent)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
    """Return the run spec of shuffle ``id`` and record ``worker`` as participant.

    Raises ``P2PConsistencyError`` if ``worker`` is not in
    ``self.scheduler.workers`` and ``KeyError`` if ``id`` is not an
    active shuffle.
    """
    if worker in self.scheduler.workers:
        state = self.active_shuffles[id]
        state.participating_workers.add(worker)
        return state.run_spec
    # This should never happen
    raise P2PConsistencyError(
        f"Scheduler is unaware of this worker {worker!r}"
    )  # pragma: nocover
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
> assert isinstance(barrier_task_spec, P2PBarrierTask)
E AssertionError
distributed\shuffle\_scheduler_plugin.py:196: AssertionError
The above exception was the direct cause of the following exception:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:64576', workers: 0, cores: 0, tasks: 0>
ws = (<Worker 'tcp://127.0.0.1:64577', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:64580', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
pd = <module 'pandas' from 'C:\\Users\\runneradmin\\miniconda3\\envs\\dask-distributed\\Lib\\site-packages\\pandas\\__init__.py'>
dd = <module 'dask.dataframe' from 'C:\\Users\\runneradmin\\miniconda3\\envs\\dask-distributed\\Lib\\site-packages\\dask\\dataframe\\__init__.py'>
arr = array([[-1.21493258e+00, 8.70655313e-01, 8.33382075e-01,
8.51173412e-01, 1.41451004e+00, -1.98203709e-01,
... 1.52552622e+00, -7.93934769e-01,
-3.30697106e+00, 6.96784483e-01, -1.59570089e-02,
-1.61997549e+00]])
@gen_cluster(client=True)
async def test_rechunk_unknown_from_pandas(c, s, *ws):
"""
See Also
--------
dask.array.tests.test_rechunk.test_rechunk_unknown_from_pandas
"""
pd = pytest.importorskip("pandas")
dd = pytest.importorskip("dask.dataframe")
arr = np.random.default_rng().standard_normal((50, 10))
x = dd.from_pandas(pd.DataFrame(arr), 2).values
result = x.rechunk((None, (5, 5)), method="p2p")
assert np.isnan(x.chunks[0]).all()
assert np.isnan(result.chunks[0]).all()
assert result.chunks[1] == (5, 5)
expected = da.from_array(arr, chunks=((25, 25), (10,))).rechunk(
(None, (5, 5)), method="p2p"
)
> assert_eq(await c.compute(result), await c.compute(expected))
distributed\shuffle\tests\test_rechunk.py:706:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed\client.py:410: in _result
raise exc.with_traceback(tb)
distributed\shuffle\_rechunk.py:169: in rechunk_transfer
with handle_transfer_errors(id):
C:\Users\runneradmin\miniconda3\envs\dask-distributed\Lib\contextlib.py:158: in __exit__
self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
    """``OKMessage`` that additionally carries the run spec of a shuffle."""

    # The run spec, either as-is or wrapped in ``ToPickle``.
    run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
    """State and shared machinery of one run of a P2P shuffle on a worker.

    Outgoing shards are buffered in a ``CommShardsBuffer``; received shards
    are buffered in a ``DiskShardsBuffer`` or, when ``disk`` is false, in a
    ``MemoryShardsBuffer``. Subclasses provide sharding, receiving and
    output assembly through the abstract methods.
    """

    id: ShuffleId
    run_id: int
    span_id: str | None
    local_address: str
    executor: ThreadPoolExecutor
    rpc: Callable[[str], PooledRPCCall]
    digest_metric: Callable[[Hashable, float], None]
    scheduler: PooledRPCCall
    closed: bool
    _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
    _comm_buffer: CommShardsBuffer
    # Set to True by ``inputs_done``; no partitions may be added afterwards.
    transferred: bool
    received: set[_T_partition_id]
    total_recvd: int
    start_time: float
    _exception: Exception | None
    _closed_event: asyncio.Event
    _loop: IOLoop

    RETRY_COUNT: int
    RETRY_DELAY_MIN: float
    RETRY_DELAY_MAX: float

    def __init__(
        self,
        id: ShuffleId,
        run_id: int,
        span_id: str | None,
        local_address: str,
        directory: str,
        executor: ThreadPoolExecutor,
        rpc: Callable[[str], PooledRPCCall],
        digest_metric: Callable[[Hashable, float], None],
        scheduler: PooledRPCCall,
        memory_limiter_disk: ResourceLimiter,
        memory_limiter_comms: ResourceLimiter,
        disk: bool,
        loop: IOLoop,
    ):
        """Store the run's identity and collaborators and create its buffers.

        ``directory`` and ``memory_limiter_disk`` are only used for the
        ``DiskShardsBuffer`` created when ``disk`` is true. Message size,
        concurrency and retry settings are read from the
        ``distributed.p2p.comm.*`` configuration.
        """
        self.id = id
        self.run_id = run_id
        self.span_id = span_id
        self.local_address = local_address
        self.executor = executor
        self.rpc = rpc
        self.digest_metric = digest_metric
        self.scheduler = scheduler
        self.closed = False

        # Initialize buffers and start background tasks
        # Don't log metrics issued by the background tasks onto the dask task that
        # spawned this object
        with context_meter.clear_callbacks():
            with self._capture_metrics("background-disk"):
                if disk:
                    self._disk_buffer = DiskShardsBuffer(
                        directory=directory,
                        read=self.read,
                        memory_limiter=memory_limiter_disk,
                    )
                else:
                    self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)

            with self._capture_metrics("background-comms"):
                max_message_size = parse_bytes(
                    dask.config.get("distributed.p2p.comm.message-bytes-limit")
                )
                concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                self._comm_buffer = CommShardsBuffer(
                    send=self.send,
                    max_message_size=max_message_size,
                    memory_limiter=memory_limiter_comms,
                    concurrency_limit=concurrency_limit,
                )

        # TODO: reduce number of connections to number of workers
        # MultiComm.max_connections = min(10, n_workers)

        self.transferred = False
        self.received = set()
        self.total_recvd = 0
        self.start_time = time.time()
        self._exception = None
        self._closed_event = asyncio.Event()
        self._loop = loop

        self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
        self.RETRY_DELAY_MIN = parse_timedelta(
            dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
        )
        self.RETRY_DELAY_MAX = parse_timedelta(
            dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
        )

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"

    def __str__(self) -> str:
        return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"

    def __hash__(self) -> int:
        # Runs are hashed by their run ID.
        return self.run_id

    @contextlib.contextmanager
    def _capture_metrics(self, where: str) -> Iterator[None]:
        """Capture context_meter metrics as

            {('p2p', <span id>, 'foreground|background...', label, unit): value}

        **Note 1:** When the metric is not logged by a background task
        (where='foreground'), this produces a duplicated metric under

            {('execute', <span id>, <task prefix>, label, unit): value}

        This is by design so that one can have a holistic view of the whole shuffle
        process.

        **Note 2:** We're immediately writing to Worker.digests.
        We don't temporarily store metrics under ShuffleRun as we would lose those
        recorded between the heartbeat and when the ShuffleRun object is deleted at the
        end of a run.
        """

        def callback(label: Hashable, value: float, unit: str) -> None:
            if not isinstance(label, tuple):
                label = (label,)
            # Strip a leading "p2p-" from the first label component; the
            # metric name already starts with "p2p".
            if isinstance(label[0], str) and label[0].startswith("p2p-"):
                label = (label[0][len("p2p-") :], *label[1:])
            name = ("p2p", self.span_id, where, *label, unit)

            self.digest_metric(name, value)

        with context_meter.add_callback(callback, allow_offload="background" in where):
            yield

    async def barrier(self, run_ids: Sequence[int]) -> int:
        """Call the scheduler's ``shuffle_barrier`` and return this run's ID.

        ``consistent`` is true only if every entry of ``run_ids`` equals
        ``self.run_id``.
        """
        self.raise_if_closed()
        consistent = all(run_id == self.run_id for run_id in run_ids)
        # TODO: Consider broadcast pinging once when the shuffle starts to warm
        # up the comm pool on scheduler side
        await self.scheduler.shuffle_barrier(
            id=self.id, run_id=self.run_id, consistent=consistent
        )
        return self.run_id

    async def _send(
        self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
    ) -> OKMessage | ErrorMessage:
        """Send ``shards`` to ``address`` through the ``shuffle_receive`` RPC."""
        self.raise_if_closed()
        return await self.rpc(address).shuffle_receive(
            data=to_serialize(shards),
            shuffle_id=self.id,
            run_id=self.run_id,
        )

    async def send(
        self, address: str, shards: list[tuple[_T_partition_id, Any]]
    ) -> OKMessage | ErrorMessage:
        """Send ``shards`` to ``address``, retrying according to the retry settings."""
        if _mean_shard_size(shards) < 65536:
            # Don't send buffers individually over the tcp comms.
            # Instead, merge everything into an opaque bytes blob, send it all at once,
            # and unpickle it on the other side.
            # Performance tests informing the size threshold:
            # https://github.com/dask/distributed/pull/8318
            shards_or_bytes: list | bytes = pickle.dumps(shards)
        else:
            shards_or_bytes = shards

        def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
            return self._send(address, shards_or_bytes)

        return await retry(
            _send,
            count=self.RETRY_COUNT,
            delay_min=self.RETRY_DELAY_MIN,
            delay_max=self.RETRY_DELAY_MAX,
        )

    async def offload(
        self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
    ) -> _T:
        """Run ``func`` in ``self.executor``, metered under ``"offload"``."""
        self.raise_if_closed()
        with context_meter.meter("offload"):
            return await run_in_executor_with_context(
                self.executor, func, *args, **kwargs
            )

    def heartbeat(self) -> dict[str, Any]:
        """Return the disk and comm buffer heartbeats plus the start time.

        The comm heartbeat additionally gets a ``"read"`` entry holding
        ``self.total_recvd``.
        """
        comm_heartbeat = self._comm_buffer.heartbeat()
        comm_heartbeat["read"] = self.total_recvd
        return {
            "disk": self._disk_buffer.heartbeat(),
            "comm": comm_heartbeat,
            "start": self.start_time,
        }

    async def _write_to_comm(
        self, data: dict[str, tuple[_T_partition_id, Any]]
    ) -> None:
        """Write ``data`` to the comm buffer."""
        self.raise_if_closed()
        await self._comm_buffer.write(data)

    async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
        """Write ``data`` to the disk buffer, keyed by the "_"-joined index."""
        self.raise_if_closed()
        await self._disk_buffer.write(
            {"_".join(str(i) for i in k): v for k, v in data.items()}
        )

    def raise_if_closed(self) -> None:
        """Raise if the run is closed.

        Raises the stored exception if there is one, otherwise a
        ``ShuffleClosedError``.
        """
        if self.closed:
            if self._exception:
                raise self._exception
            raise ShuffleClosedError(f"{self} has already been closed")

    async def inputs_done(self) -> None:
        """Mark the run as transferred and flush the comm buffer.

        An exception raised by the comm buffer's ``raise_on_exception`` is
        stored in ``self._exception`` and re-raised.
        """
        self.raise_if_closed()
        self.transferred = True
        await self._flush_comm()
        try:
            self._comm_buffer.raise_on_exception()
        except Exception as e:
            self._exception = e
            raise

    async def _flush_comm(self) -> None:
        """Flush the comm buffer."""
        self.raise_if_closed()
        await self._comm_buffer.flush()

    async def flush_receive(self) -> None:
        """Flush the disk buffer."""
        self.raise_if_closed()
        await self._disk_buffer.flush()

    async def close(self) -> None:
        """Close the comm and disk buffers.

        If the run is already marked closed, wait until the first ``close``
        call has finished. The disk buffer is closed and the closed event is
        set even if closing a buffer raises, so that such waiters are never
        blocked forever; the exception still propagates to the caller.
        """
        if self.closed:  # pragma: no cover
            await self._closed_event.wait()
            return

        self.closed = True
        try:
            try:
                await self._comm_buffer.close()
            finally:
                await self._disk_buffer.close()
        finally:
            self._closed_event.set()

    def fail(self, exception: Exception) -> None:
        """Store ``exception`` on the run unless it is already closed."""
        if not self.closed:
            self._exception = exception

    def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
        """Read the shards stored under the "_"-joined index ``id``."""
        self.raise_if_closed()
        return self._disk_buffer.read("_".join(str(i) for i in id))

    async def receive(
        self, data: list[tuple[_T_partition_id, Any]] | bytes
    ) -> OKMessage | ErrorMessage:
        """Receive shards, unpickling them first if they arrive as bytes.

        Returns ``{"status": "OK"}``, or an error message if ``_receive``
        raises a ``P2PConsistencyError``.
        """
        try:
            if isinstance(data, bytes):
                # Unpack opaque blob. See send()
                # NOTE(review): this unpickles bytes handed to the handler;
                # that is only safe if the sending peers are trusted.
                data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
            await self._receive(data)
            return {"status": "OK"}
        except P2PConsistencyError as e:
            return error_message(e)

    async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
        """Reschedule the task ``key`` if partition ``i`` belongs to another worker.

        If the assigned worker is not ``self.local_address``, the scheduler is
        asked to restrict the task to that worker and ``Reschedule`` is raised;
        an error reply is raised as ``RuntimeError``.
        """
        assigned_worker = self._get_assigned_worker(i)

        if assigned_worker != self.local_address:
            result = await self.scheduler.shuffle_restrict_task(
                id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
            )
            if result["status"] == "error":
                raise RuntimeError(result["message"])
            assert result["status"] == "OK"
            raise Reschedule()

    @abc.abstractmethod
    def _get_assigned_worker(self, i: _T_partition_id) -> str:
        """Get the address of the worker assigned to the output partition"""

    @abc.abstractmethod
    async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
        """Receive shards belonging to output partitions of this shuffle run"""

    def add_partition(
        self, data: _T_partition_type, partition_id: _T_partition_id
    ) -> int:
        """Shard an input partition, write the shards to the comm buffer, return the run ID.

        Raises ``RuntimeError`` if the run has already been marked as
        transferred.
        """
        self.raise_if_closed()
        if self.transferred:
            raise RuntimeError(f"Cannot add more partitions to {self}")
        # Log metrics both in the "execute" and in the "p2p" contexts
        self.validate_data(data)
        with self._capture_metrics("foreground"):
            with (
                context_meter.meter("p2p-shard-partition-noncpu"),
                context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
            ):
                shards = self._shard_partition(data, partition_id)
            sync(self._loop, self._write_to_comm, shards)
        return self.run_id

    @abc.abstractmethod
    def _shard_partition(
        self, data: _T_partition_type, partition_id: _T_partition_id
    ) -> dict[str, tuple[_T_partition_id, Any]]:
        """Shard an input partition by the assigned output workers"""

    def get_output_partition(
        self, partition_id: _T_partition_id, key: Key, **kwargs: Any
    ) -> _T_partition_type:
        """Return the output partition ``partition_id`` for the task ``key``.

        Raises ``RuntimeError`` if called before the run was marked as
        transferred; ``_ensure_output_worker`` may raise ``Reschedule``.
        """
        self.raise_if_closed()
        sync(self._loop, self._ensure_output_worker, partition_id, key)
        if not self.transferred:
            raise RuntimeError("`get_output_partition` called before barrier task")
        sync(self._loop, self.flush_receive)
        with (
            # Log metrics both in the "execute" and in the "p2p" contexts
            self._capture_metrics("foreground"),
            context_meter.meter("p2p-get-output-noncpu"),
            context_meter.meter("p2p-get-output-cpu", func=thread_time),
        ):
            return self._get_output_partition(partition_id, key, **kwargs)

    @abc.abstractmethod
    def _get_output_partition(
        self, partition_id: _T_partition_id, key: Key, **kwargs: Any
    ) -> _T_partition_type:
        """Get an output partition to the shuffle run"""

    @abc.abstractmethod
    def read(self, path: Path) -> tuple[Any, int]:
        """Read shards from disk"""

    @abc.abstractmethod
    def deserialize(self, buffer: Any) -> Any:
        """Deserialize shards"""

    def validate_data(self, data: Any) -> None:
        """Validate payload data before shuffling"""
def get_worker_plugin() -> ShuffleWorkerPlugin:
    """Return the ``"shuffle"`` plugin of the worker running the current task.

    Raises ``RuntimeError`` if ``get_worker`` fails with ``ValueError`` or
    if the worker has no plugin registered under ``"shuffle"``.
    """
    from distributed import get_worker

    try:
        worker = get_worker()
    except ValueError as e:
        raise RuntimeError(
            "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
            "please confirm that you've created a distributed Client and are submitting this computation through it."
        ) from e

    try:
        plugin = worker.plugins["shuffle"]
    except KeyError as e:
        raise RuntimeError(
            f"The worker {worker.address} does not have a P2P shuffle plugin."
        ) from e
    return plugin  # type: ignore
# Prefix shared by the keys of all shuffle barrier tasks.
_BARRIER_PREFIX = "shuffle-barrier-"


def barrier_key(shuffle_id: ShuffleId) -> str:
    """Return the barrier task key for ``shuffle_id``: prefix plus ID."""
    return "".join((_BARRIER_PREFIX, shuffle_id))


def id_from_key(key: Key) -> ShuffleId | None:
    """Return the shuffle ID encoded in a barrier key.

    Returns ``None`` if ``key`` is not a string starting with the barrier
    prefix.
    """
    if isinstance(key, str) and key.startswith(_BARRIER_PREFIX):
        return ShuffleId(key.removeprefix(_BARRIER_PREFIX))
    return None
class ShuffleType(Enum):
    """Enumeration of shuffle types: dataframe shuffle and array rechunk."""

    DATAFRAME = "DataFrameShuffle"
    ARRAY_RECHUNK = "ArrayRechunk"
@dataclass(frozen=True)
class ShuffleRunSpec(Generic[_T_partition_id]):
    """Immutable description of one run of a shuffle.

    Combines the shuffle's ``spec`` with the partition-to-worker mapping
    and span ID used for this particular run.
    """

    # Not an ``__init__`` argument: each new instance takes the next value
    # of a single counter starting at 1 that is created with the class.
    run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
    spec: ShuffleSpec
    # Maps each output partition ID to a worker address.
    worker_for: dict[_T_partition_id, str]
    span_id: str | None

    @property
    def id(self) -> ShuffleId:
        """ID of the shuffle, taken from ``spec``."""
        return self.spec.id
@dataclass(frozen=True)
class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
    """Abstract, immutable specification of a shuffle.

    Holds the shuffle ID and the ``disk`` flag; subclasses define the output
    partitions, worker selection and how a run is created on a worker.
    """

    id: ShuffleId
    disk: bool

    @property
    @abc.abstractmethod
    def output_partitions(self) -> Generator[_T_partition_id]:
        """Output partitions"""

    @abc.abstractmethod
    def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
        """Pick a worker for a partition"""

    def create_new_run(
        self,
        worker_for: dict[_T_partition_id, str],
        span_id: str | None,
    ) -> SchedulerShuffleState:
        """Create the scheduler-side state for a new run of this shuffle.

        The participating workers start out as the set of all workers that
        appear as values in ``worker_for``.
        """
        run_spec = ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id)
        return SchedulerShuffleState(
            run_spec=run_spec,
            participating_workers=set(worker_for.values()),
        )

    @abc.abstractmethod
    def create_run_on_worker(
        self,
        run_id: int,
        span_id: str | None,
        worker_for: dict[_T_partition_id, str],
        plugin: ShuffleWorkerPlugin,
    ) -> ShuffleRun:
        """Create the new shuffle run on the worker."""
@dataclass(eq=False)
class SchedulerShuffleState(Generic[_T_partition_id]):
    """Scheduler-side state of one shuffle run.

    Instances compare by identity (``eq=False``) and hash by run ID.
    """

    run_spec: ShuffleRunSpec
    participating_workers: set[str]
    # ``None`` until the state is archived; not an ``__init__`` argument.
    _archived_by: str | None = field(default=None, init=False)
    _failed: bool = False

    @property
    def id(self) -> ShuffleId:
        """ID of the shuffle, taken from the run spec."""
        return self.run_spec.id

    @property
    def run_id(self) -> int:
        """Run ID, taken from the run spec."""
        return self.run_spec.run_id

    @property
    def archived(self) -> bool:
        """Whether ``_archived_by`` has been set."""
        return self._archived_by is not None

    def __str__(self) -> str:
        return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"

    def __hash__(self) -> int:
        return hash(self.run_id)
@contextlib.contextmanager
def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
try:
yield
except ShuffleClosedError:
raise Reschedule()
except P2PConsistencyError:
raise
except P2POutOfDiskError:
raise
except Exception as e:
> raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E RuntimeError: P2P db9c09cc2dbfe69e4bcd9524cb9be9d7 failed during transfer phase
distributed\shuffle\_core.py:531: RuntimeError
Check warning on line 0 in distributed.shuffle.tests.test_rechunk
github-actions / Unit Test Results
10 out of 11 runs failed: test_rechunk_with_fully_unknown_dimension[x0-chunks0] (distributed.shuffle.tests.test_rechunk)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 53a5056f80ef44da75580c3c1aff19ce failed during transfer phase
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
    """Install the plugin on ``scheduler``.

    Registers the four ``shuffle_*`` handlers on ``scheduler.handlers``,
    creates the bookkeeping containers and adds the plugin to the scheduler
    under the name ``"shuffle"``.
    """
    self.scheduler = scheduler
    rpc_handlers = {
        "shuffle_barrier": self.barrier,
        "shuffle_get": self.get,
        "shuffle_get_or_create": self.get_or_create,
        "shuffle_restrict_task": self.restrict_task,
    }
    scheduler.handlers.update(rpc_handlers)
    # shuffle ID -> worker address -> heartbeat data
    self.heartbeats = defaultdict(lambda: defaultdict(dict))
    self.active_shuffles = {}
    scheduler.add_plugin(self, name="shuffle")
    self._shuffles = defaultdict(set)
    self._archived_by_stimulus = defaultdict(set)
    self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
    """Register a ``ShuffleWorkerPlugin`` under the name ``"shuffle"``.

    A fresh plugin instance is serialized with ``dumps`` and passed to
    ``self.scheduler.register_worker_plugin`` with ``idempotent=False``.
    The ``scheduler`` argument itself is not used; ``self.scheduler`` is.
    """
    serialized_plugin = dumps(ShuffleWorkerPlugin())
    await self.scheduler.register_worker_plugin(
        None, serialized_plugin, name="shuffle", idempotent=False
    )
def shuffle_ids(self) -> set[ShuffleId]:
    """Return the IDs of all currently active shuffles as a new set."""
    return {*self.active_shuffles}
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
    """Handle the barrier of a shuffle run and notify participating workers.

    Parameters
    ----------
    id
        ID of the shuffle; looked up in ``self.active_shuffles``.
    run_id
        Run ID reported by the caller; must equal the active run's ID.
    consistent
        If false, the shuffle is restarted via ``self._restart_shuffle``
        and no broadcast takes place.

    Raises
    ------
    KeyError
        If ``id`` is not an active shuffle.
    ValueError
        If ``run_id`` does not match the active run.
    P2PIllegalStateError
        If a worker that still has to be retried is unknown to the
        scheduler while the shuffle is not archived.
    RuntimeError
        If a worker answered with something other than ``None`` or an
        ``OSError``, if a worker left during the shuffle, or if three
        retry rounds did not reduce the number of pending workers.
    """
    shuffle = self.active_shuffles[id]
    if shuffle.run_id != run_id:
        raise ValueError(f"{run_id=} does not match {shuffle}")
    if not consistent:
        logger.warning(
            "Shuffle %s restarted due to data inconsistency during barrier",
            shuffle.id,
        )
        return self._restart_shuffle(
            shuffle.id,
            self.scheduler,
            stimulus_id=f"p2p-barrier-inconsistent-{time()}",
        )
    msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
    workers = list(shuffle.participating_workers)
    no_progress = 0
    while workers:
        res = await self.scheduler.broadcast(
            msg=msg,
            workers=workers,
            on_error="return",
        )
        before = len(workers)
        # Rebuild the retry list: a ``None`` result needs no retry, an
        # ``OSError`` is retried, anything else aborts the barrier. After this
        # loop ``workers`` holds exactly the workers with a non-``None``
        # result, so no second pass over ``res`` is needed.
        workers = []
        for w, r in res.items():
            if r is None:
                continue
            if isinstance(r, OSError):
                workers.append(w)
            else:
                raise RuntimeError(
                    f"Unexpected error encountered during P2P barrier: {r!r}"
                )
        if workers:
            logger.warning(
                "Failure during broadcast of %s, retrying.",
                shuffle.id,
            )
            if any(w not in self.scheduler.workers for w in workers):
                if not shuffle.archived:
                    # If the shuffle is not yet archived, this could mean that the barrier task fails
                    # before the P2P restarting mechanism can kick in.
                    raise P2PIllegalStateError(
                        "Expected shuffle to be archived if participating worker is not known by scheduler"
                    )
                raise RuntimeError(
                    f"Worker {workers} left during shuffle {shuffle}"
                )
            await asyncio.sleep(0.1)
            if len(workers) == before:
                no_progress += 1
                if no_progress >= 3:
                    raise RuntimeError(
                        f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
                    )
def restrict_task(
    self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
    """Restrict the task ``key`` to ``worker`` on behalf of a shuffle run.

    The request is only honoured when ``run_id`` equals the run ID of the
    active shuffle ``id``; otherwise a ``P2PConsistencyError`` is turned
    into an error message via ``error_message``. On success the task state
    is looked up in ``self.scheduler.tasks`` and passed to
    ``self._set_restriction``, and ``{"status": "OK"}`` is returned.
    """
    try:
        shuffle = self.active_shuffles[id]
        active_run_id = shuffle.run_id
        if active_run_id > run_id:
            raise P2PConsistencyError(
                f"Request stale, expected {run_id=} for {shuffle}"
            )
        if active_run_id < run_id:
            raise P2PConsistencyError(
                f"Request invalid, expected {run_id=} for {shuffle}"
            )
        self._set_restriction(self.scheduler.tasks[key], worker)
    except P2PConsistencyError as e:
        return error_message(e)
    return {"status": "OK"}
def heartbeat(self, ws: WorkerState, data: dict) -> None:
    """Merge per-shuffle heartbeat data reported by a worker.

    ``data`` maps shuffle IDs to dicts. Entries whose shuffle ID is not in
    ``self.active_shuffles`` are ignored; all others are merged into
    ``self.heartbeats[shuffle_id][ws.address]``.
    """
    # Test membership against the dict itself rather than building a new set
    # of all active IDs (``shuffle_ids()``) for every heartbeat entry.
    active = self.active_shuffles
    for shuffle_id, d in data.items():
        if shuffle_id in active:
            self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
    """Return the run spec of the active shuffle ``id`` for ``worker``.

    On success the run spec from ``self._get`` is wrapped in ``ToPickle``
    and returned in an ``"OK"`` message. A ``KeyError`` is converted into
    a ``P2PConsistencyError``, and any ``P2PConsistencyError`` is returned
    as an error message built by ``error_message``.
    """
    try:
        try:
            spec = self._get(id, worker)
            response: RunSpecMessage = {"status": "OK", "run_spec": ToPickle(spec)}
            return response
        except KeyError as missing:
            raise P2PConsistencyError(
                f"No active shuffle with {id=!r} found"
            ) from missing
    except P2PConsistencyError as inconsistent:
        return error_message(inconsistent)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
    """Return the run spec of shuffle ``id`` and record ``worker`` as participant.

    Raises ``P2PConsistencyError`` if ``worker`` is not in
    ``self.scheduler.workers`` and ``KeyError`` if ``id`` is not an
    active shuffle.
    """
    if worker in self.scheduler.workers:
        state = self.active_shuffles[id]
        state.participating_workers.add(worker)
        return state.run_spec
    # This should never happen
    raise P2PConsistencyError(
        f"Scheduler is unaware of this worker {worker!r}"
    )  # pragma: nocover
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
    """Return the ``ShuffleSpec`` carried by the shuffle's barrier task.

    Looks up the scheduler task stored under ``barrier_key(shuffle_id)``
    and returns the ``spec`` of its ``run_spec``, which is asserted to be a
    ``P2PBarrierTask``.
    """
    barrier_ts = self.scheduler.tasks[barrier_key(shuffle_id)]
    run_spec = barrier_ts.run_spec
    assert isinstance(run_spec, P2PBarrierTask)
    return run_spec.spec
def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
    """Create and register a new run of ``shuffle_id`` and return its run spec.

    The run is created from the spec retrieved for the shuffle, stored as
    the active shuffle and in ``self._shuffles``, and ``worker`` is added
    to its participating workers. The span ID comes from the task group of
    the scheduler task ``key``.
    """
    # FIXME: The current implementation relies on the barrier task to be
    # known by its name. If the name has been mangled, we cannot guarantee
    # that the shuffle works as intended and should fail instead.
    self._raise_if_barrier_unknown(shuffle_id)
    self._raise_if_task_not_processing(key)

    spec = self._retrieve_spec(shuffle_id)
    worker_for = self._calculate_worker_for(spec)
    self._ensure_output_tasks_are_non_rootish(spec)

    new_state = spec.create_new_run(
        worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
    )
    self.active_shuffles[shuffle_id] = new_state
    self._shuffles[shuffle_id].add(new_state)
    new_state.participating_workers.add(worker)

    logger.warning(
        "Shuffle %s initialized by task %r executed on worker %s",
        shuffle_id,
        key,
        worker,
    )
    return new_state.run_spec
def get_or_create(
self,
shuffle_id: ShuffleId,
key: Key,
worker: str,
) -> RunSpecMessage | ErrorMessage:
try:
> run_spec = self._get(shuffle_id, worker)
distributed\shuffle\_scheduler_plugin.py:229:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
    """Install the plugin on ``scheduler``.

    Registers the four ``shuffle_*`` handlers on ``scheduler.handlers``,
    creates the bookkeeping containers and adds the plugin to the scheduler
    under the name ``"shuffle"``.
    """
    self.scheduler = scheduler
    rpc_handlers = {
        "shuffle_barrier": self.barrier,
        "shuffle_get": self.get,
        "shuffle_get_or_create": self.get_or_create,
        "shuffle_restrict_task": self.restrict_task,
    }
    scheduler.handlers.update(rpc_handlers)
    # shuffle ID -> worker address -> heartbeat data
    self.heartbeats = defaultdict(lambda: defaultdict(dict))
    self.active_shuffles = {}
    scheduler.add_plugin(self, name="shuffle")
    self._shuffles = defaultdict(set)
    self._archived_by_stimulus = defaultdict(set)
    self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
    """Register a ``ShuffleWorkerPlugin`` under the name ``"shuffle"``.

    A fresh plugin instance is serialized with ``dumps`` and passed to
    ``self.scheduler.register_worker_plugin`` with ``idempotent=False``.
    The ``scheduler`` argument itself is not used; ``self.scheduler`` is.
    """
    serialized_plugin = dumps(ShuffleWorkerPlugin())
    await self.scheduler.register_worker_plugin(
        None, serialized_plugin, name="shuffle", idempotent=False
    )
def shuffle_ids(self) -> set[ShuffleId]:
    """Return the IDs of all currently active shuffles as a new set."""
    return {*self.active_shuffles}
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
    """Handle the barrier of a shuffle run and notify participating workers.

    Parameters
    ----------
    id
        ID of the shuffle; looked up in ``self.active_shuffles``.
    run_id
        Run ID reported by the caller; must equal the active run's ID.
    consistent
        If false, the shuffle is restarted via ``self._restart_shuffle``
        and no broadcast takes place.

    Raises
    ------
    KeyError
        If ``id`` is not an active shuffle.
    ValueError
        If ``run_id`` does not match the active run.
    P2PIllegalStateError
        If a worker that still has to be retried is unknown to the
        scheduler while the shuffle is not archived.
    RuntimeError
        If a worker answered with something other than ``None`` or an
        ``OSError``, if a worker left during the shuffle, or if three
        retry rounds did not reduce the number of pending workers.
    """
    shuffle = self.active_shuffles[id]
    if shuffle.run_id != run_id:
        raise ValueError(f"{run_id=} does not match {shuffle}")
    if not consistent:
        logger.warning(
            "Shuffle %s restarted due to data inconsistency during barrier",
            shuffle.id,
        )
        return self._restart_shuffle(
            shuffle.id,
            self.scheduler,
            stimulus_id=f"p2p-barrier-inconsistent-{time()}",
        )
    msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
    workers = list(shuffle.participating_workers)
    no_progress = 0
    while workers:
        res = await self.scheduler.broadcast(
            msg=msg,
            workers=workers,
            on_error="return",
        )
        before = len(workers)
        # Rebuild the retry list: a ``None`` result needs no retry, an
        # ``OSError`` is retried, anything else aborts the barrier. After this
        # loop ``workers`` holds exactly the workers with a non-``None``
        # result, so no second pass over ``res`` is needed.
        workers = []
        for w, r in res.items():
            if r is None:
                continue
            if isinstance(r, OSError):
                workers.append(w)
            else:
                raise RuntimeError(
                    f"Unexpected error encountered during P2P barrier: {r!r}"
                )
        if workers:
            logger.warning(
                "Failure during broadcast of %s, retrying.",
                shuffle.id,
            )
            if any(w not in self.scheduler.workers for w in workers):
                if not shuffle.archived:
                    # If the shuffle is not yet archived, this could mean that the barrier task fails
                    # before the P2P restarting mechanism can kick in.
                    raise P2PIllegalStateError(
                        "Expected shuffle to be archived if participating worker is not known by scheduler"
                    )
                raise RuntimeError(
                    f"Worker {workers} left during shuffle {shuffle}"
                )
            await asyncio.sleep(0.1)
            if len(workers) == before:
                no_progress += 1
                if no_progress >= 3:
                    raise RuntimeError(
                        f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
                    )
def restrict_task(
    self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
    """Restrict the task ``key`` to ``worker`` on behalf of a shuffle run.

    The request is only honoured when ``run_id`` equals the run ID of the
    active shuffle ``id``; otherwise a ``P2PConsistencyError`` is turned
    into an error message via ``error_message``. On success the task state
    is looked up in ``self.scheduler.tasks`` and passed to
    ``self._set_restriction``, and ``{"status": "OK"}`` is returned.
    """
    try:
        shuffle = self.active_shuffles[id]
        active_run_id = shuffle.run_id
        if active_run_id > run_id:
            raise P2PConsistencyError(
                f"Request stale, expected {run_id=} for {shuffle}"
            )
        if active_run_id < run_id:
            raise P2PConsistencyError(
                f"Request invalid, expected {run_id=} for {shuffle}"
            )
        self._set_restriction(self.scheduler.tasks[key], worker)
    except P2PConsistencyError as e:
        return error_message(e)
    return {"status": "OK"}
def heartbeat(self, ws: WorkerState, data: dict) -> None:
    """Merge per-shuffle heartbeat data reported by a worker.

    ``data`` maps shuffle IDs to dicts. Entries whose shuffle ID is not in
    ``self.active_shuffles`` are ignored; all others are merged into
    ``self.heartbeats[shuffle_id][ws.address]``.
    """
    # Test membership against the dict itself rather than building a new set
    # of all active IDs (``shuffle_ids()``) for every heartbeat entry.
    active = self.active_shuffles
    for shuffle_id, d in data.items():
        if shuffle_id in active:
            self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
    """Return the run spec of the active shuffle ``id`` for ``worker``.

    On success the run spec from ``self._get`` is wrapped in ``ToPickle``
    and returned in an ``"OK"`` message. A ``KeyError`` is converted into
    a ``P2PConsistencyError``, and any ``P2PConsistencyError`` is returned
    as an error message built by ``error_message``.
    """
    try:
        try:
            spec = self._get(id, worker)
            response: RunSpecMessage = {"status": "OK", "run_spec": ToPickle(spec)}
            return response
        except KeyError as missing:
            raise P2PConsistencyError(
                f"No active shuffle with {id=!r} found"
            ) from missing
    except P2PConsistencyError as inconsistent:
        return error_message(inconsistent)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
> state = self.active_shuffles[id]
E KeyError: '53a5056f80ef44da75580c3c1aff19ce'
distributed\shuffle\_scheduler_plugin.py:190: KeyError
During handling of the above exception, another exception occurred:
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
    """Report to the scheduler whether every id in ``run_ids`` matches this run.

    Raises through ``raise_if_closed`` when the run is already closed,
    otherwise awaits ``scheduler.shuffle_barrier`` and returns this run's id.
    """
    self.raise_if_closed()
    mismatch_found = any(other != self.run_id for other in run_ids)
    # TODO: Consider broadcast pinging once when the shuffle starts to warm
    # up the comm pool on scheduler side
    await self.scheduler.shuffle_barrier(
        id=self.id,
        run_id=self.run_id,
        consistent=not mismatch_found,
    )
    return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
    """Collect the heartbeat payload of this run.

    The comm buffer's heartbeat dict is extended in place with a ``"read"``
    entry holding ``self.total_recvd``.
    """
    comm_stats = self._comm_buffer.heartbeat()
    comm_stats["read"] = self.total_recvd
    report = {}
    report["disk"] = self._disk_buffer.heartbeat()
    report["comm"] = comm_stats
    report["start"] = self.start_time
    return report
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
    """Raise if this run has been closed; do nothing otherwise.

    A stored ``self._exception`` takes precedence over the generic
    ``ShuffleClosedError``.
    """
    if not self.closed:
        return
    if self._exception:
        raise self._exception
    raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
    """Remember ``exception`` on this run unless the run is already closed."""
    if self.closed:
        return
    self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and …n _refresh
result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed\shuffle\_worker_plugin.py:190: in _fetch
response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed\core.py:1259: in send_recv_from_rpc
return await send_recv(comm=comm, op=key, **kwargs)
distributed\core.py:1043: in send_recv
raise exc.with_traceback(tb)
distributed\core.py:832: in _handle_comm
result = handler(**msg)
distributed\shuffle\_scheduler_plugin.py:234: in get_or_create
run_spec = self._create(shuffle_id, key, worker)
distributed\shuffle\_scheduler_plugin.py:205: in _create
spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
    """Install the shuffle plugin on ``scheduler``.

    Registers the four ``shuffle_*`` handlers, adds this object as the
    scheduler plugin named ``"shuffle"`` and creates empty bookkeeping
    containers.
    """
    self.scheduler = scheduler
    rpc_handlers = {
        "shuffle_barrier": self.barrier,
        "shuffle_get": self.get,
        "shuffle_get_or_create": self.get_or_create,
        "shuffle_restrict_task": self.restrict_task,
    }
    scheduler.handlers.update(rpc_handlers)
    self.heartbeats = defaultdict(lambda: defaultdict(dict))
    self.active_shuffles = {}
    # Statement order is kept as in the original: the plugin is registered
    # before the remaining containers are created.
    scheduler.add_plugin(self, name="shuffle")
    self._shuffles = defaultdict(set)
    self._archived_by_stimulus = defaultdict(set)
    self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
    """Register a fresh ``ShuffleWorkerPlugin`` under the name ``"shuffle"``.

    The registration goes through ``self.scheduler``; the ``scheduler``
    argument itself is not used.
    """
    serialized_plugin = dumps(ShuffleWorkerPlugin())
    await self.scheduler.register_worker_plugin(
        None,
        serialized_plugin,
        name="shuffle",
        idempotent=False,
    )
def shuffle_ids(self) -> set[ShuffleId]:
    """Return a new set holding the ids of all currently active shuffles."""
    return {*self.active_shuffles}
# Scheduler-side barrier handler for shuffle `id`.
#
# Steps visible in this excerpt:
#   1. Look up the active shuffle and reject a `run_id` that does not match it.
#   2. If the caller reports `consistent=False`, log a warning and restart the
#      shuffle through `_restart_shuffle` instead of broadcasting.
#   3. Otherwise broadcast `shuffle_inputs_done` to the participating workers
#      and keep retrying the workers whose reply was an `OSError`.
#
# NOTE(review): leading indentation was lost in this dump, so the nesting of
# the statements below (in particular of the sleep / no-progress accounting)
# has to be confirmed against the real module.
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
# Number of broadcast rounds in which the set of pending workers did not shrink.
no_progress = 0
while workers:
# on_error="return": failures come back as values in `res` instead of raising.
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
# Rebuild the retry list from this round's replies.
workers = []
for w, r in res.items():
# A `None` reply is treated as success for that worker.
if r is None:
continue
# Only OSError replies are retried; any other reply aborts the barrier.
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
# NOTE(review): the loop above already leaves exactly the workers with a
# non-None reply in `workers` (anything else raised), so this reassignment
# looks redundant -- confirm before removing.
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
# Short pause before the next broadcast round.
await asyncio.sleep(0.1)
# Give up after three rounds in which the number of pending workers did not change.
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
    self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
    """Apply a worker restriction to task ``key`` for the active shuffle ``id``.

    The restriction is only applied when ``run_id`` equals the run id of the
    shuffle stored in ``self.active_shuffles``. A mismatch is converted into
    an error message instead of being raised to the caller.
    """
    try:
        shuffle = self.active_shuffles[id]
        if shuffle.run_id != run_id:
            if shuffle.run_id > run_id:
                # The caller refers to an older run than the active one.
                raise P2PConsistencyError(
                    f"Request stale, expected {run_id=} for {shuffle}"
                )
            # The caller refers to a run newer than the active one.
            raise P2PConsistencyError(
                f"Request invalid, expected {run_id=} for {shuffle}"
            )
        self._set_restriction(self.scheduler.tasks[key], worker)
    except P2PConsistencyError as e:
        return error_message(e)
    return {"status": "OK"}
def heartbeat(self, ws: WorkerState, data: dict) -> None:
    """Merge the per-shuffle heartbeat payloads reported by one worker.

    ``data`` maps shuffle ids to update dicts. Entries whose id is not
    returned by ``self.shuffle_ids()`` are ignored; the others are merged
    into ``self.heartbeats[shuffle_id][ws.address]``.
    """
    # Build the id set once rather than once per reported shuffle.
    known_ids = self.shuffle_ids()
    for shuffle_id, d in data.items():
        if shuffle_id in known_ids:
            self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
    """Return the run spec of the active shuffle ``id`` wrapped for pickling.

    A ``KeyError`` raised while looking the shuffle up is re-raised as
    ``P2PConsistencyError``; every ``P2PConsistencyError`` is then turned
    into an error message rather than propagated.
    """
    try:
        try:
            payload = ToPickle(self._get(id, worker))
        except KeyError as missing:
            raise P2PConsistencyError(
                f"No active shuffle with {id=!r} found"
            ) from missing
    except P2PConsistencyError as inconsistent:
        return error_message(inconsistent)
    return {"status": "OK", "run_spec": payload}
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
    """Record ``worker`` as a participant of shuffle ``id`` and return its run spec.

    Raises ``KeyError`` when ``id`` is not in ``self.active_shuffles`` and
    ``P2PConsistencyError`` when the scheduler does not know ``worker``.
    """
    if worker in self.scheduler.workers:
        entry = self.active_shuffles[id]
        entry.participating_workers.add(worker)
        return entry.run_spec
    # This should never happen
    raise P2PConsistencyError(
        f"Scheduler is unaware of this worker {worker!r}"
    )  # pragma: nocover
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
> assert isinstance(barrier_task_spec, P2PBarrierTask)
E AssertionError
distributed\shuffle\_scheduler_plugin.py:196: AssertionError
The above exception was the direct cause of the following exception:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:64605', workers: 0, cores: 0, tasks: 0>
x = dask.array<ones_like, shape=(50, 10), dtype=float64, chunksize=(25, 10), chunktype=numpy.ndarray>
chunks = (None, 5)
ws = (<Worker 'tcp://127.0.0.1:64606', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:64609', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
dd = <module 'dask.dataframe' from 'C:\\Users\\runneradmin\\miniconda3\\envs\\dask-distributed\\Lib\\site-packages\\dask\\dataframe\\__init__.py'>
@pytest.mark.parametrize(
"x, chunks",
[
(da.ones(shape=(50, 10), chunks=(25, 10)), (None, 5)),
(da.ones(shape=(50, 10), chunks=(25, 10)), {1: 5}),
(da.ones(shape=(50, 10), chunks=(25, 10)), (None, (5, 5))),
(da.ones(shape=(100, 10), chunks=(5, 10)), (None, 5)),
(da.ones(shape=(100, 10), chunks=(5, 10)), {1: 5}),
(da.ones(shape=(100, 10), chunks=(5, 10)), (None, (5, 5))),
(da.ones(shape=(10, 10), chunks=(10, 10)), (None, 5)),
(da.ones(shape=(10, 10), chunks=(10, 10)), {1: 5}),
(da.ones(shape=(10, 10), chunks=(10, 10)), (None, (5, 5))),
(da.ones(shape=(10, 10), chunks=(10, 2)), (None, 5)),
(da.ones(shape=(10, 10), chunks=(10, 2)), {1: 5}),
(da.ones(shape=(10, 10), chunks=(10, 2)), (None, (5, 5))),
],
)
@gen_cluster(client=True)
async def test_rechunk_with_fully_unknown_dimension(c, s, *ws, x, chunks):
"""
See Also
--------
dask.array.tests.test_rechunk.test_rechunk_with_fully_unknown_dimension
"""
pytest.importorskip("pandas")
dd = pytest.importorskip("dask.dataframe")
y = dd.from_array(x).values
result = y.rechunk(chunks, method="p2p")
expected = x.rechunk(chunks, method="p2p")
assert_chunks_match(result.chunks, expected.chunks)
> assert_eq(await c.compute(result), await c.compute(expected))
distributed\shuffle\tests\test_rechunk.py:757:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed\client.py:410: in _result
raise exc.with_traceback(tb)
distributed\shuffle\_rechunk.py:169: in rechunk_transfer
with handle_transfer_errors(id):
C:\Users\runneradmin\miniconda3\envs\dask-distributed\Lib\contextlib.py:158: in __exit__
self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
    """Report to the scheduler whether every id in ``run_ids`` matches this run.

    Raises through ``raise_if_closed`` when the run is already closed,
    otherwise awaits ``scheduler.shuffle_barrier`` and returns this run's id.
    """
    self.raise_if_closed()
    mismatch_found = any(other != self.run_id for other in run_ids)
    # TODO: Consider broadcast pinging once when the shuffle starts to warm
    # up the comm pool on scheduler side
    await self.scheduler.shuffle_barrier(
        id=self.id,
        run_id=self.run_id,
        consistent=not mismatch_found,
    )
    return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
    """Collect the heartbeat payload of this run.

    The comm buffer's heartbeat dict is extended in place with a ``"read"``
    entry holding ``self.total_recvd``.
    """
    comm_stats = self._comm_buffer.heartbeat()
    comm_stats["read"] = self.total_recvd
    report = {}
    report["disk"] = self._disk_buffer.heartbeat()
    report["comm"] = comm_stats
    report["start"] = self.start_time
    return report
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
    """Raise if this run has been closed; do nothing otherwise.

    A stored ``self._exception`` takes precedence over the generic
    ``ShuffleClosedError``.
    """
    if not self.closed:
        return
    if self._exception:
        raise self._exception
    raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
    """Mark the transfer as finished and flush the comm buffer.

    Any exception surfaced by the comm buffer afterwards is stored on
    ``self._exception`` and re-raised.
    """
    self.raise_if_closed()
    self.transferred = True
    await self._flush_comm()
    try:
        self._comm_buffer.raise_on_exception()
    except Exception as comm_error:
        # Keep the failure so that later `raise_if_closed` calls can surface it.
        self._exception = comm_error
        raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
    """Close the comm and disk buffers once.

    A call made after the run has been flagged closed only waits for
    ``self._closed_event`` and returns.
    """
    if not self.closed:
        self.closed = True
        await self._comm_buffer.close()
        await self._disk_buffer.close()
        self._closed_event.set()
    else:  # pragma: no cover
        await self._closed_event.wait()
def fail(self, exception: Exception) -> None:
    """Remember ``exception`` on this run unless the run is already closed."""
    if self.closed:
        return
    self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
    self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
    """Hand incoming shards to ``_receive`` and acknowledge them.

    ``data`` is either a list of shards or a pickled blob of such a list.
    A ``P2PConsistencyError`` is reported as an error message instead of
    being raised.
    """
    try:
        shards = data
        if isinstance(shards, bytes):
            # Unpack opaque blob. See send()
            # NOTE(review): this unpickles bytes handed in by the caller;
            # presumably they always originate from a peer's send() -- confirm.
            shards = cast(list[tuple[_T_partition_id, Any]], pickle.loads(shards))
        await self._receive(shards)
    except P2PConsistencyError as e:
        return error_message(e)
    return {"status": "OK"}
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and in the "p2p" contexts
self.validate_data(data)
with self._capture_metrics("foreground"):
with (
context_meter.meter("p2p-shard-partition-noncpu"),
context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
):
shards = self._shard_partition(data, partition_id)
sync(self._loop, self._write_to_comm, shards)
return self.run_id
@abc.abstractmethod
def _shard_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> dict[str, tuple[_T_partition_id, Any]]:
"""Shard an input partition by the assigned output workers"""
def get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
self.raise_if_closed()
sync(self._loop, self._ensure_output_worker, partition_id, key)
if not self.transferred:
raise RuntimeError("`get_output_partition` called before barrier task")
sync(self._loop, self.flush_receive)
with (
# Log metrics both in the "execute" and in the "p2p" contexts
self._capture_metrics("foreground"),
context_meter.meter("p2p-get-output-noncpu"),
context_meter.meter("p2p-get-output-cpu", func=thread_time),
):
return self._get_output_partition(partition_id, key, **kwargs)
@abc.abstractmethod
def _get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
"""Get an output partition to the shuffle run"""
@abc.abstractmethod
def read(self, path: Path) -> tuple[Any, int]:
"""Read shards from disk"""
@abc.abstractmethod
def deserialize(self, buffer: Any) -> Any:
"""Deserialize shards"""
def validate_data(self, data: Any) -> None:
"""Validate payload data before shuffling"""
def get_worker_plugin() -> ShuffleWorkerPlugin:
    """Return the plugin registered as ``"shuffle"`` on the current worker.

    Raises ``RuntimeError`` when ``get_worker`` fails with ``ValueError`` or
    when the worker has no plugin stored under ``"shuffle"``.
    """
    from distributed import get_worker

    try:
        worker = get_worker()
    except ValueError as e:
        raise RuntimeError(
            "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
            "please confirm that you've created a distributed Client and are submitting this computation through it."
        ) from e

    try:
        plugin = worker.plugins["shuffle"]
    except KeyError as e:
        raise RuntimeError(
            f"The worker {worker.address} does not have a P2P shuffle plugin."
        ) from e
    return plugin  # type: ignore
# Prefix shared by the keys of all P2P barrier tasks.
_BARRIER_PREFIX = "shuffle-barrier-"
def barrier_key(shuffle_id: ShuffleId) -> str:
    """Return the barrier task key for ``shuffle_id`` (prefix + id)."""
    return "".join((_BARRIER_PREFIX, shuffle_id))
def id_from_key(key: Key) -> ShuffleId | None:
    """Extract the shuffle id from a barrier task key.

    Returns ``None`` when ``key`` is not a string starting with the barrier
    prefix.
    """
    if isinstance(key, str) and key.startswith(_BARRIER_PREFIX):
        return ShuffleId(key.removeprefix(_BARRIER_PREFIX))
    return None
class ShuffleType(Enum):
DATAFRAME = "DataFrameShuffle"
ARRAY_RECHUNK = "ArrayRechunk"
@dataclass(frozen=True)
class ShuffleRunSpec(Generic[_T_partition_id]):
run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
spec: ShuffleSpec
worker_for: dict[_T_partition_id, str]
span_id: str | None
@property
def id(self) -> ShuffleId:
return self.spec.id
@dataclass(frozen=True)
class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
id: ShuffleId
disk: bool
@property
@abc.abstractmethod
def output_partitions(self) -> Generator[_T_partition_id]:
"""Output partitions"""
@abc.abstractmethod
def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
"""Pick a worker for a partition"""
def create_new_run(
self,
worker_for: dict[_T_partition_id, str],
span_id: str | None,
) -> SchedulerShuffleState:
return SchedulerShuffleState(
run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
participating_workers=set(worker_for.values()),
)
@abc.abstractmethod
def create_run_on_worker(
self,
run_id: int,
span_id: str | None,
worker_for: dict[_T_partition_id, str],
plugin: ShuffleWorkerPlugin,
) -> ShuffleRun:
"""Create the new shuffle run on the worker."""
@dataclass(eq=False)
class SchedulerShuffleState(Generic[_T_partition_id]):
run_spec: ShuffleRunSpec
participating_workers: set[str]
_archived_by: str | None = field(default=None, init=False)
_failed: bool = False
@property
def id(self) -> ShuffleId:
return self.run_spec.id
@property
def run_id(self) -> int:
return self.run_spec.run_id
@property
def archived(self) -> bool:
return self._archived_by is not None
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
def __hash__(self) -> int:
return hash(self.run_id)
@contextlib.contextmanager
def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
try:
yield
except ShuffleClosedError:
raise Reschedule()
except P2PConsistencyError:
raise
except P2POutOfDiskError:
raise
except Exception as e:
> raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E RuntimeError: P2P 53a5056f80ef44da75580c3c1aff19ce failed during transfer phase
distributed\shuffle\_core.py:531: RuntimeError
Check warning on line 0 in distributed.shuffle.tests.test_rechunk
github-actions / Unit Test Results
10 out of 11 runs failed: test_rechunk_with_fully_unknown_dimension[x1-chunks1] (distributed.shuffle.tests.test_rechunk)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 53a5056f80ef44da75580c3c1aff19ce failed during transfer phase
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
assert isinstance(barrier_task_spec, P2PBarrierTask)
return barrier_task_spec.spec
def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
# FIXME: The current implementation relies on the barrier task to be
# known by its name. If the name has been mangled, we cannot guarantee
# that the shuffle works as intended and should fail instead.
self._raise_if_barrier_unknown(shuffle_id)
self._raise_if_task_not_processing(key)
spec = self._retrieve_spec(shuffle_id)
worker_for = self._calculate_worker_for(spec)
self._ensure_output_tasks_are_non_rootish(spec)
state = spec.create_new_run(
worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
)
self.active_shuffles[shuffle_id] = state
self._shuffles[shuffle_id].add(state)
state.participating_workers.add(worker)
logger.warning(
"Shuffle %s initialized by task %r executed on worker %s",
shuffle_id,
key,
worker,
)
return state.run_spec
def get_or_create(
self,
shuffle_id: ShuffleId,
key: Key,
worker: str,
) -> RunSpecMessage | ErrorMessage:
try:
> run_spec = self._get(shuffle_id, worker)
distributed\shuffle\_scheduler_plugin.py:229:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
> state = self.active_shuffles[id]
E KeyError: '53a5056f80ef44da75580c3c1aff19ce'
distributed\shuffle\_scheduler_plugin.py:190: KeyError
During handling of the above exception, another exception occurred:
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and …: in _refresh
result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed\shuffle\_worker_plugin.py:190: in _fetch
response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed\core.py:1259: in send_recv_from_rpc
return await send_recv(comm=comm, op=key, **kwargs)
distributed\core.py:1043: in send_recv
raise exc.with_traceback(tb)
distributed\core.py:832: in _handle_comm
result = handler(**msg)
distributed\shuffle\_scheduler_plugin.py:234: in get_or_create
run_spec = self._create(shuffle_id, key, worker)
distributed\shuffle\_scheduler_plugin.py:205: in _create
spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
> assert isinstance(barrier_task_spec, P2PBarrierTask)
E AssertionError
distributed\shuffle\_scheduler_plugin.py:196: AssertionError
The above exception was the direct cause of the following exception:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:64619', workers: 0, cores: 0, tasks: 0>
x = dask.array<ones_like, shape=(50, 10), dtype=float64, chunksize=(25, 10), chunktype=numpy.ndarray>
chunks = {1: 5}
ws = (<Worker 'tcp://127.0.0.1:64620', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:64623', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
dd = <module 'dask.dataframe' from 'C:\\Users\\runneradmin\\miniconda3\\envs\\dask-distributed\\Lib\\site-packages\\dask\\dataframe\\__init__.py'>
@pytest.mark.parametrize(
"x, chunks",
[
(da.ones(shape=(50, 10), chunks=(25, 10)), (None, 5)),
(da.ones(shape=(50, 10), chunks=(25, 10)), {1: 5}),
(da.ones(shape=(50, 10), chunks=(25, 10)), (None, (5, 5))),
(da.ones(shape=(100, 10), chunks=(5, 10)), (None, 5)),
(da.ones(shape=(100, 10), chunks=(5, 10)), {1: 5}),
(da.ones(shape=(100, 10), chunks=(5, 10)), (None, (5, 5))),
(da.ones(shape=(10, 10), chunks=(10, 10)), (None, 5)),
(da.ones(shape=(10, 10), chunks=(10, 10)), {1: 5}),
(da.ones(shape=(10, 10), chunks=(10, 10)), (None, (5, 5))),
(da.ones(shape=(10, 10), chunks=(10, 2)), (None, 5)),
(da.ones(shape=(10, 10), chunks=(10, 2)), {1: 5}),
(da.ones(shape=(10, 10), chunks=(10, 2)), (None, (5, 5))),
],
)
@gen_cluster(client=True)
async def test_rechunk_with_fully_unknown_dimension(c, s, *ws, x, chunks):
"""
See Also
--------
dask.array.tests.test_rechunk.test_rechunk_with_fully_unknown_dimension
"""
pytest.importorskip("pandas")
dd = pytest.importorskip("dask.dataframe")
y = dd.from_array(x).values
result = y.rechunk(chunks, method="p2p")
expected = x.rechunk(chunks, method="p2p")
assert_chunks_match(result.chunks, expected.chunks)
> assert_eq(await c.compute(result), await c.compute(expected))
distributed\shuffle\tests\test_rechunk.py:757:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed\client.py:410: in _result
raise exc.with_traceback(tb)
distributed\shuffle\_rechunk.py:169: in rechunk_transfer
with handle_transfer_errors(id):
C:\Users\runneradmin\miniconda3\envs\dask-distributed\Lib\contextlib.py:158: in __exit__
self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
# Identifier of a shuffle; a distinct static type layered over ``str``.
ShuffleId = NewType("ShuffleId", str)
# N-dimensional index expressed as a tuple of ints.
NDIndex: TypeAlias = tuple[int, ...]
# Generic parameters for the partition identifier and the partition payload.
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
# Generic return type for helpers that forward a callable's result.
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
    """OK message that additionally carries a shuffle run specification."""

    # The run spec itself, or the run spec wrapped in ``ToPickle``.
    run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
    """Base class holding the local state of a single shuffle run.

    It owns a comm buffer for outgoing shards and a disk-or-memory buffer for
    received shards. Concrete shuffle types implement the abstract hooks
    ``_get_assigned_worker``, ``_receive``, ``_shard_partition``,
    ``_get_output_partition``, ``read`` and ``deserialize``.
    """

    id: ShuffleId
    run_id: int
    span_id: str | None
    local_address: str
    executor: ThreadPoolExecutor
    rpc: Callable[[str], PooledRPCCall]
    digest_metric: Callable[[Hashable, float], None]
    scheduler: PooledRPCCall
    closed: bool
    _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
    _comm_buffer: CommShardsBuffer
    received: set[_T_partition_id]
    total_recvd: int
    start_time: float
    _exception: Exception | None
    _closed_event: asyncio.Event
    _loop: IOLoop
    RETRY_COUNT: int
    RETRY_DELAY_MIN: float
    RETRY_DELAY_MAX: float

    def __init__(
        self,
        id: ShuffleId,
        run_id: int,
        span_id: str | None,
        local_address: str,
        directory: str,
        executor: ThreadPoolExecutor,
        rpc: Callable[[str], PooledRPCCall],
        digest_metric: Callable[[Hashable, float], None],
        scheduler: PooledRPCCall,
        memory_limiter_disk: ResourceLimiter,
        memory_limiter_comms: ResourceLimiter,
        disk: bool,
        loop: IOLoop,
    ):
        """Store the run's identity and collaborators and create its buffers.

        ``disk`` selects between a ``DiskShardsBuffer`` rooted at
        ``directory`` and an in-memory ``MemoryShardsBuffer`` for received
        shards. Comm limits and retry settings are read from ``dask.config``
        under ``distributed.p2p.comm``.
        """
        self.id = id
        self.run_id = run_id
        self.span_id = span_id
        self.local_address = local_address
        self.executor = executor
        self.rpc = rpc
        self.digest_metric = digest_metric
        self.scheduler = scheduler
        self.closed = False
        # Initialize buffers and start background tasks
        # Don't log metrics issued by the background tasks onto the dask task that
        # spawned this object
        with context_meter.clear_callbacks():
            with self._capture_metrics("background-disk"):
                if disk:
                    self._disk_buffer = DiskShardsBuffer(
                        directory=directory,
                        read=self.read,
                        memory_limiter=memory_limiter_disk,
                    )
                else:
                    self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
            with self._capture_metrics("background-comms"):
                max_message_size = parse_bytes(
                    dask.config.get("distributed.p2p.comm.message-bytes-limit")
                )
                concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                self._comm_buffer = CommShardsBuffer(
                    send=self.send,
                    max_message_size=max_message_size,
                    memory_limiter=memory_limiter_comms,
                    concurrency_limit=concurrency_limit,
                )
        # TODO: reduce number of connections to number of workers
        # MultiComm.max_connections = min(10, n_workers)
        self.transferred = False
        self.received = set()
        self.total_recvd = 0
        self.start_time = time.time()
        self._exception = None
        self._closed_event = asyncio.Event()
        self._loop = loop
        self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
        self.RETRY_DELAY_MIN = parse_timedelta(
            dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
        )
        self.RETRY_DELAY_MAX = parse_timedelta(
            dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
        )

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"

    def __str__(self) -> str:
        return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"

    def __hash__(self) -> int:
        # The run id doubles as the hash value.
        return self.run_id

    @contextlib.contextmanager
    def _capture_metrics(self, where: str) -> Iterator[None]:
        """Capture context_meter metrics as

            {('p2p', <span id>, 'foreground|background...', label, unit): value}

        **Note 1:** When the metric is not logged by a background task
        (where='foreground'), this produces a duplicated metric under

            {('execute', <span id>, <task prefix>, label, unit): value}

        This is by design so that one can have a holistic view of the whole shuffle
        process.

        **Note 2:** We're immediately writing to Worker.digests.
        We don't temporarily store metrics under ShuffleRun as we would lose those
        recorded between the heartbeat and when the ShuffleRun object is deleted at the
        end of a run.
        """

        def callback(label: Hashable, value: float, unit: str) -> None:
            # Normalise the label to a tuple and drop a leading "p2p-" prefix,
            # since "p2p" is already the first element of the metric name.
            if not isinstance(label, tuple):
                label = (label,)
            if isinstance(label[0], str) and label[0].startswith("p2p-"):
                label = (label[0][len("p2p-") :], *label[1:])
            name = ("p2p", self.span_id, where, *label, unit)
            self.digest_metric(name, value)

        with context_meter.add_callback(callback, allow_offload="background" in where):
            yield

    async def barrier(self, run_ids: Sequence[int]) -> int:
        """Report the barrier to the scheduler and return this run's id.

        ``consistent`` is sent as True only if every entry of ``run_ids``
        equals ``self.run_id``.
        """
        self.raise_if_closed()
        consistent = all(run_id == self.run_id for run_id in run_ids)
        # TODO: Consider broadcast pinging once when the shuffle starts to warm
        # up the comm pool on scheduler side
        await self.scheduler.shuffle_barrier(
            id=self.id, run_id=self.run_id, consistent=consistent
        )
        return self.run_id

    async def _send(
        self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
    ) -> OKMessage | ErrorMessage:
        """Make a single ``shuffle_receive`` call to the worker at ``address``."""
        self.raise_if_closed()
        return await self.rpc(address).shuffle_receive(
            data=to_serialize(shards),
            shuffle_id=self.id,
            run_id=self.run_id,
        )

    async def send(
        self, address: str, shards: list[tuple[_T_partition_id, Any]]
    ) -> OKMessage | ErrorMessage:
        """Send ``shards`` to ``address``, retrying according to the retry settings."""
        if _mean_shard_size(shards) < 65536:
            # Don't send buffers individually over the tcp comms.
            # Instead, merge everything into an opaque bytes blob, send it all at once,
            # and unpickle it on the other side.
            # Performance tests informing the size threshold:
            # https://github.com/dask/distributed/pull/8318
            shards_or_bytes: list | bytes = pickle.dumps(shards)
        else:
            shards_or_bytes = shards

        def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
            return self._send(address, shards_or_bytes)

        return await retry(
            _send,
            count=self.RETRY_COUNT,
            delay_min=self.RETRY_DELAY_MIN,
            delay_max=self.RETRY_DELAY_MAX,
        )

    async def offload(
        self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
    ) -> _T:
        """Run ``func`` on this run's executor, metered under ``"offload"``."""
        self.raise_if_closed()
        with context_meter.meter("offload"):
            return await run_in_executor_with_context(
                self.executor, func, *args, **kwargs
            )

    def heartbeat(self) -> dict[str, Any]:
        """Return buffer heartbeats plus the start time of this run."""
        comm_heartbeat = self._comm_buffer.heartbeat()
        comm_heartbeat["read"] = self.total_recvd
        return {
            "disk": self._disk_buffer.heartbeat(),
            "comm": comm_heartbeat,
            "start": self.start_time,
        }

    async def _write_to_comm(
        self, data: dict[str, tuple[_T_partition_id, Any]]
    ) -> None:
        self.raise_if_closed()
        await self._comm_buffer.write(data)

    async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
        self.raise_if_closed()
        # Buffer keys are the index components joined with "_".
        await self._disk_buffer.write(
            {"_".join(str(i) for i in k): v for k, v in data.items()}
        )

    def raise_if_closed(self) -> None:
        """Raise the recorded exception, or ``ShuffleClosedError``, if closed."""
        if self.closed:
            if self._exception:
                raise self._exception
            raise ShuffleClosedError(f"{self} has already been closed")

    async def inputs_done(self) -> None:
        """Mark the transfer as finished and flush the comm buffer.

        An exception raised by the comm buffer is recorded on the run and
        re-raised.
        """
        self.raise_if_closed()
        self.transferred = True
        await self._flush_comm()
        try:
            self._comm_buffer.raise_on_exception()
        except Exception as e:
            self._exception = e
            raise

    async def _flush_comm(self) -> None:
        self.raise_if_closed()
        await self._comm_buffer.flush()

    async def flush_receive(self) -> None:
        self.raise_if_closed()
        await self._disk_buffer.flush()

    async def close(self) -> None:
        """Close both buffers; later callers wait for the first close to finish."""
        if self.closed:  # pragma: no cover
            await self._closed_event.wait()
            return
        self.closed = True
        await self._comm_buffer.close()
        await self._disk_buffer.close()
        self._closed_event.set()

    def fail(self, exception: Exception) -> None:
        """Record ``exception`` on the run unless it is already closed."""
        if not self.closed:
            self._exception = exception

    def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
        self.raise_if_closed()
        return self._disk_buffer.read("_".join(str(i) for i in id))

    async def receive(
        self, data: list[tuple[_T_partition_id, Any]] | bytes
    ) -> OKMessage | ErrorMessage:
        """Accept shards from a peer.

        Returns an OK message, or an error message if ``_receive`` raised a
        ``P2PConsistencyError``.
        """
        try:
            if isinstance(data, bytes):
                # Unpack opaque blob. See send()
                data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
            await self._receive(data)
            return {"status": "OK"}
        except P2PConsistencyError as e:
            return error_message(e)

    async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
        """Reschedule the task if partition ``i`` is assigned to another worker.

        When the assigned worker differs from ``self.local_address``, the
        scheduler is asked to restrict ``key`` to that worker and
        ``Reschedule`` is raised.
        """
        assigned_worker = self._get_assigned_worker(i)
        if assigned_worker != self.local_address:
            result = await self.scheduler.shuffle_restrict_task(
                id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
            )
            if result["status"] == "error":
                raise RuntimeError(result["message"])
            assert result["status"] == "OK"
            raise Reschedule()

    @abc.abstractmethod
    def _get_assigned_worker(self, i: _T_partition_id) -> str:
        """Get the address of the worker assigned to the output partition"""

    @abc.abstractmethod
    async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
        """Receive shards belonging to output partitions of this shuffle run"""

    def add_partition(
        self, data: _T_partition_type, partition_id: _T_partition_id
    ) -> int:
        """Shard one input partition and queue the shards for sending.

        Returns ``self.run_id``. Raises ``RuntimeError`` if the transfer has
        already been marked as finished.
        """
        self.raise_if_closed()
        if self.transferred:
            raise RuntimeError(f"Cannot add more partitions to {self}")
        # Log metrics both in the "execute" and in the "p2p" contexts
        self.validate_data(data)
        with self._capture_metrics("foreground"):
            with (
                context_meter.meter("p2p-shard-partition-noncpu"),
                context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
            ):
                shards = self._shard_partition(data, partition_id)
            sync(self._loop, self._write_to_comm, shards)
        return self.run_id

    @abc.abstractmethod
    def _shard_partition(
        self, data: _T_partition_type, partition_id: _T_partition_id
    ) -> dict[str, tuple[_T_partition_id, Any]]:
        """Shard an input partition by the assigned output workers"""

    def get_output_partition(
        self, partition_id: _T_partition_id, key: Key, **kwargs: Any
    ) -> _T_partition_type:
        """Return one output partition after flushing received shards.

        Raises ``RuntimeError`` if called before the transfer was marked as
        finished.
        """
        self.raise_if_closed()
        sync(self._loop, self._ensure_output_worker, partition_id, key)
        if not self.transferred:
            raise RuntimeError("`get_output_partition` called before barrier task")
        sync(self._loop, self.flush_receive)
        with (
            # Log metrics both in the "execute" and in the "p2p" contexts
            self._capture_metrics("foreground"),
            context_meter.meter("p2p-get-output-noncpu"),
            context_meter.meter("p2p-get-output-cpu", func=thread_time),
        ):
            return self._get_output_partition(partition_id, key, **kwargs)

    @abc.abstractmethod
    def _get_output_partition(
        self, partition_id: _T_partition_id, key: Key, **kwargs: Any
    ) -> _T_partition_type:
        """Get an output partition to the shuffle run"""

    @abc.abstractmethod
    def read(self, path: Path) -> tuple[Any, int]:
        """Read shards from disk"""

    @abc.abstractmethod
    def deserialize(self, buffer: Any) -> Any:
        """Deserialize shards"""

    def validate_data(self, data: Any) -> None:
        """Validate payload data before shuffling"""
def get_worker_plugin() -> ShuffleWorkerPlugin:
    """Return the P2P shuffle plugin registered on the current worker.

    Raises
    ------
    RuntimeError
        If ``get_worker()`` raises ``ValueError``, or if the worker has no
        plugin registered under the name ``"shuffle"``.
    """
    from distributed import get_worker

    try:
        current_worker = get_worker()
    except ValueError as exc:
        raise RuntimeError(
            "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
            "please confirm that you've created a distributed Client and are submitting this computation through it."
        ) from exc

    registered_plugins = current_worker.plugins
    try:
        shuffle_plugin = registered_plugins["shuffle"]
    except KeyError as exc:
        raise RuntimeError(
            f"The worker {current_worker.address} does not have a P2P shuffle plugin."
        ) from exc
    return shuffle_plugin  # type: ignore
# Prefix shared by the keys of all shuffle barrier tasks.
_BARRIER_PREFIX = "shuffle-barrier-"


def barrier_key(shuffle_id: ShuffleId) -> str:
    """Return the barrier task key for ``shuffle_id``."""
    return "".join((_BARRIER_PREFIX, shuffle_id))
def id_from_key(key: Key) -> ShuffleId | None:
    """Extract the shuffle id from a barrier task key.

    Returns ``None`` when ``key`` is not a string or does not start with the
    barrier prefix.
    """
    if isinstance(key, str) and key.startswith(_BARRIER_PREFIX):
        suffix = key[len(_BARRIER_PREFIX) :]
        return ShuffleId(suffix)
    return None
class ShuffleType(Enum):
    """Kinds of P2P operation, each identified by a string value."""

    DATAFRAME = "DataFrameShuffle"
    ARRAY_RECHUNK = "ArrayRechunk"
@dataclass(frozen=True)
class ShuffleRunSpec(Generic[_T_partition_id]):
    """Immutable description of one run of a shuffle.

    ``run_id`` is not an ``__init__`` argument: each new instance takes the
    next value of a single ``itertools.count(1)`` created with the class.
    """

    run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
    # The shuffle specification this run executes.
    spec: ShuffleSpec
    # Output partition id -> worker address.
    worker_for: dict[_T_partition_id, str]
    span_id: str | None

    @property
    def id(self) -> ShuffleId:
        """Id of the underlying shuffle spec."""
        return self.spec.id
@dataclass(frozen=True)
class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
    """Abstract, immutable specification of a shuffle.

    Subclasses define the output partitions, how a partition is mapped to a
    worker, and how a run is created on a worker.
    """

    id: ShuffleId
    # Passed to the run as its ``disk`` flag by concrete implementations
    # (presumably; the consumers are not visible here).
    disk: bool

    @property
    @abc.abstractmethod
    def output_partitions(self) -> Generator[_T_partition_id]:
        """Output partitions"""

    @abc.abstractmethod
    def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
        """Pick a worker for a partition"""

    def create_new_run(
        self,
        worker_for: dict[_T_partition_id, str],
        span_id: str | None,
    ) -> SchedulerShuffleState:
        """Create the scheduler-side state for a new run of this spec.

        The participating workers start out as the set of workers that appear
        as values of ``worker_for``.
        """
        return SchedulerShuffleState(
            run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
            participating_workers=set(worker_for.values()),
        )

    @abc.abstractmethod
    def create_run_on_worker(
        self,
        run_id: int,
        span_id: str | None,
        worker_for: dict[_T_partition_id, str],
        plugin: ShuffleWorkerPlugin,
    ) -> ShuffleRun:
        """Create the new shuffle run on the worker."""
@dataclass(eq=False)
class SchedulerShuffleState(Generic[_T_partition_id]):
    """Scheduler-side bookkeeping for one shuffle run.

    ``eq=False`` keeps the default identity-based equality; the hash is
    derived from the run id.
    """

    run_spec: ShuffleRunSpec
    # Addresses of the workers taking part in this run.
    participating_workers: set[str]
    # ``None`` until the run is archived; see ``archived``.
    _archived_by: str | None = field(default=None, init=False)
    _failed: bool = False

    @property
    def id(self) -> ShuffleId:
        """Id of the shuffle this run belongs to."""
        return self.run_spec.id

    @property
    def run_id(self) -> int:
        """Id of this particular run."""
        return self.run_spec.run_id

    @property
    def archived(self) -> bool:
        """True once ``_archived_by`` has been set."""
        return self._archived_by is not None

    def __str__(self) -> str:
        return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"

    def __hash__(self) -> int:
        return hash(self.run_id)
@contextlib.contextmanager
def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
try:
yield
except ShuffleClosedError:
raise Reschedule()
except P2PConsistencyError:
raise
except P2POutOfDiskError:
raise
except Exception as e:
> raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E RuntimeError: P2P 53a5056f80ef44da75580c3c1aff19ce failed during transfer phase
distributed\shuffle\_core.py:531: RuntimeError
Check warning on line 0 in distributed.shuffle.tests.test_rechunk
github-actions / Unit Test Results
10 out of 11 runs failed: test_rechunk_with_fully_unknown_dimension[x2-chunks2] (distributed.shuffle.tests.test_rechunk)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 53a5056f80ef44da75580c3c1aff19ce failed during transfer phase
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
    """Attach the plugin to ``scheduler`` and register its RPC handlers."""
    self.scheduler = scheduler
    # Expose the plugin's entry points as scheduler handlers.
    self.scheduler.handlers.update(
        {
            "shuffle_barrier": self.barrier,
            "shuffle_get": self.get,
            "shuffle_get_or_create": self.get_or_create,
            "shuffle_restrict_task": self.restrict_task,
        }
    )
    # shuffle id -> worker address -> heartbeat data
    self.heartbeats = defaultdict(lambda: defaultdict(dict))
    self.active_shuffles = {}
    # The plugin registers itself under the name "shuffle".
    self.scheduler.add_plugin(self, name="shuffle")
    self._shuffles = defaultdict(set)
    self._archived_by_stimulus = defaultdict(set)
    self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
    """Register a ``ShuffleWorkerPlugin`` on the workers under the name "shuffle"."""
    pickled_plugin = dumps(ShuffleWorkerPlugin())
    await self.scheduler.register_worker_plugin(
        None, pickled_plugin, name="shuffle", idempotent=False
    )
def shuffle_ids(self) -> set[ShuffleId]:
    """Return a new set holding the ids of all active shuffles."""
    active_ids = set()
    active_ids.update(self.active_shuffles)
    return active_ids
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
    """Broadcast ``shuffle_inputs_done`` to all participating workers.

    Parameters
    ----------
    id
        Id of the shuffle; must be an active shuffle.
    run_id
        Run the caller believes is active; must match the active run.
    consistent
        False if the caller observed mismatching run ids, in which case the
        shuffle is restarted instead of broadcasting.

    Raises
    ------
    ValueError
        If ``run_id`` does not match the active run.
    RuntimeError
        If a worker returns something other than ``None`` or an ``OSError``,
        if a worker that still needs the message has left the scheduler, or
        if retries stop making progress.
    P2PIllegalStateError
        If a worker has left but the shuffle has not been archived.
    """
    shuffle = self.active_shuffles[id]
    if shuffle.run_id != run_id:
        raise ValueError(f"{run_id=} does not match {shuffle}")
    if not consistent:
        logger.warning(
            "Shuffle %s restarted due to data inconsistency during barrier",
            shuffle.id,
        )
        return self._restart_shuffle(
            shuffle.id,
            self.scheduler,
            stimulus_id=f"p2p-barrier-inconsistent-{time()}",
        )
    msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
    workers = list(shuffle.participating_workers)
    no_progress = 0
    while workers:
        res = await self.scheduler.broadcast(
            msg=msg,
            workers=workers,
            on_error="return",
        )
        before = len(workers)
        # Keep only workers whose call failed with an OSError; those are
        # retried. Any other non-None result is fatal, so this loop alone
        # determines the retry list.
        workers = []
        for w, r in res.items():
            if r is None:
                continue
            if isinstance(r, OSError):
                workers.append(w)
            else:
                raise RuntimeError(
                    f"Unexpected error encountered during P2P barrier: {r!r}"
                )
        if workers:
            logger.warning(
                "Failure during broadcast of %s, retrying.",
                shuffle.id,
            )
            if any(w not in self.scheduler.workers for w in workers):
                if not shuffle.archived:
                    # If the shuffle is not yet archived, this could mean that the barrier task fails
                    # before the P2P restarting mechanism can kick in.
                    raise P2PIllegalStateError(
                        "Expected shuffle to be archived if participating worker is not known by scheduler"
                    )
                raise RuntimeError(
                    f"Worker {workers} left during shuffle {shuffle}"
                )
            await asyncio.sleep(0.1)
            if len(workers) == before:
                no_progress += 1
                if no_progress >= 3:
                    raise RuntimeError(
                        f"""Broadcast not making progress for {shuffle}.
                        Aborting. This is possibly due to overloaded
                        workers. Increasing config
                        `distributed.comm.timeouts.connect` timeout may
                        help."""
                    )
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
    """Merge per-shuffle heartbeat payloads reported by a worker.

    Parameters
    ----------
    ws
        State of the reporting worker; only ``ws.address`` is read.
    data
        Mapping of shuffle id to the heartbeat dict for that shuffle.
        Entries for shuffles that are not currently active are ignored.
    """
    # ``shuffle_ids()`` returns a new set on every call, so take one snapshot
    # instead of rebuilding it for each reported shuffle.
    active_ids = self.shuffle_ids()
    for shuffle_id, d in data.items():
        if shuffle_id in active_ids:
            self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
    """Return the run spec of an active shuffle as an RPC message.

    An unknown shuffle id (``KeyError`` from ``_get``) is converted into a
    ``P2PConsistencyError``; any ``P2PConsistencyError`` is returned to the
    caller as an error message instead of being raised.
    """
    try:
        try:
            spec = self._get(id, worker)
        except KeyError as exc:
            raise P2PConsistencyError(
                f"No active shuffle with {id=!r} found"
            ) from exc
        return {"status": "OK", "run_spec": ToPickle(spec)}
    except P2PConsistencyError as exc:
        return error_message(exc)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
    """Return the shuffle spec carried by the barrier task of ``shuffle_id``.

    The barrier task is looked up by its key in ``self.scheduler.tasks``;
    its ``run_spec`` is asserted to be a ``P2PBarrierTask``.
    """
    barrier_ts = self.scheduler.tasks[barrier_key(shuffle_id)]
    barrier_run_spec = barrier_ts.run_spec
    assert isinstance(barrier_run_spec, P2PBarrierTask)
    return barrier_run_spec.spec
def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
    """Create and register a new run for ``shuffle_id`` and return its spec.

    ``key`` is the task that triggered the creation and ``worker`` the worker
    it runs on; the worker is added to the run's participating workers.
    """
    # FIXME: The current implementation relies on the barrier task to be
    # known by its name. If the name has been mangled, we cannot guarantee
    # that the shuffle works as intended and should fail instead.
    self._raise_if_barrier_unknown(shuffle_id)
    self._raise_if_task_not_processing(key)
    spec = self._retrieve_spec(shuffle_id)
    worker_for = self._calculate_worker_for(spec)
    self._ensure_output_tasks_are_non_rootish(spec)
    # The span id is taken from the task group of the triggering task.
    state = spec.create_new_run(
        worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
    )
    # The new run replaces any previous entry as the active run and is also
    # added to the set of runs recorded for this shuffle id.
    self.active_shuffles[shuffle_id] = state
    self._shuffles[shuffle_id].add(state)
    state.participating_workers.add(worker)
    logger.warning(
        "Shuffle %s initialized by task %r executed on worker %s",
        shuffle_id,
        key,
        worker,
    )
    return state.run_spec
def get_or_create(
self,
shuffle_id: ShuffleId,
key: Key,
worker: str,
) -> RunSpecMessage | ErrorMessage:
try:
> run_spec = self._get(shuffle_id, worker)
distributed\shuffle\_scheduler_plugin.py:229:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
    """Attach the plugin to ``scheduler`` and register its RPC handlers."""
    self.scheduler = scheduler
    # Expose the plugin's entry points as scheduler handlers.
    self.scheduler.handlers.update(
        {
            "shuffle_barrier": self.barrier,
            "shuffle_get": self.get,
            "shuffle_get_or_create": self.get_or_create,
            "shuffle_restrict_task": self.restrict_task,
        }
    )
    # shuffle id -> worker address -> heartbeat data
    self.heartbeats = defaultdict(lambda: defaultdict(dict))
    self.active_shuffles = {}
    # The plugin registers itself under the name "shuffle".
    self.scheduler.add_plugin(self, name="shuffle")
    self._shuffles = defaultdict(set)
    self._archived_by_stimulus = defaultdict(set)
    self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
    """Register a ``ShuffleWorkerPlugin`` on the workers under the name "shuffle"."""
    pickled_plugin = dumps(ShuffleWorkerPlugin())
    await self.scheduler.register_worker_plugin(
        None, pickled_plugin, name="shuffle", idempotent=False
    )
def shuffle_ids(self) -> set[ShuffleId]:
    """Return a new set holding the ids of all active shuffles."""
    active_ids = set()
    active_ids.update(self.active_shuffles)
    return active_ids
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
    """Broadcast ``shuffle_inputs_done`` to all participating workers.

    Parameters
    ----------
    id
        Id of the shuffle; must be an active shuffle.
    run_id
        Run the caller believes is active; must match the active run.
    consistent
        False if the caller observed mismatching run ids, in which case the
        shuffle is restarted instead of broadcasting.

    Raises
    ------
    ValueError
        If ``run_id`` does not match the active run.
    RuntimeError
        If a worker returns something other than ``None`` or an ``OSError``,
        if a worker that still needs the message has left the scheduler, or
        if retries stop making progress.
    P2PIllegalStateError
        If a worker has left but the shuffle has not been archived.
    """
    shuffle = self.active_shuffles[id]
    if shuffle.run_id != run_id:
        raise ValueError(f"{run_id=} does not match {shuffle}")
    if not consistent:
        logger.warning(
            "Shuffle %s restarted due to data inconsistency during barrier",
            shuffle.id,
        )
        return self._restart_shuffle(
            shuffle.id,
            self.scheduler,
            stimulus_id=f"p2p-barrier-inconsistent-{time()}",
        )
    msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
    workers = list(shuffle.participating_workers)
    no_progress = 0
    while workers:
        res = await self.scheduler.broadcast(
            msg=msg,
            workers=workers,
            on_error="return",
        )
        before = len(workers)
        # Keep only workers whose call failed with an OSError; those are
        # retried. Any other non-None result is fatal, so this loop alone
        # determines the retry list.
        workers = []
        for w, r in res.items():
            if r is None:
                continue
            if isinstance(r, OSError):
                workers.append(w)
            else:
                raise RuntimeError(
                    f"Unexpected error encountered during P2P barrier: {r!r}"
                )
        if workers:
            logger.warning(
                "Failure during broadcast of %s, retrying.",
                shuffle.id,
            )
            if any(w not in self.scheduler.workers for w in workers):
                if not shuffle.archived:
                    # If the shuffle is not yet archived, this could mean that the barrier task fails
                    # before the P2P restarting mechanism can kick in.
                    raise P2PIllegalStateError(
                        "Expected shuffle to be archived if participating worker is not known by scheduler"
                    )
                raise RuntimeError(
                    f"Worker {workers} left during shuffle {shuffle}"
                )
            await asyncio.sleep(0.1)
            if len(workers) == before:
                no_progress += 1
                if no_progress >= 3:
                    raise RuntimeError(
                        f"""Broadcast not making progress for {shuffle}.
                        Aborting. This is possibly due to overloaded
                        workers. Increasing config
                        `distributed.comm.timeouts.connect` timeout may
                        help."""
                    )
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
    """Merge per-shuffle heartbeat payloads reported by a worker.

    Parameters
    ----------
    ws
        State of the reporting worker; only ``ws.address`` is read.
    data
        Mapping of shuffle id to the heartbeat dict for that shuffle.
        Entries for shuffles that are not currently active are ignored.
    """
    # ``shuffle_ids()`` returns a new set on every call, so take one snapshot
    # instead of rebuilding it for each reported shuffle.
    active_ids = self.shuffle_ids()
    for shuffle_id, d in data.items():
        if shuffle_id in active_ids:
            self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
    """Return the run spec of an active shuffle as an RPC message.

    An unknown shuffle id (``KeyError`` from ``_get``) is converted into a
    ``P2PConsistencyError``; any ``P2PConsistencyError`` is returned to the
    caller as an error message instead of being raised.
    """
    try:
        try:
            spec = self._get(id, worker)
        except KeyError as exc:
            raise P2PConsistencyError(
                f"No active shuffle with {id=!r} found"
            ) from exc
        return {"status": "OK", "run_spec": ToPickle(spec)}
    except P2PConsistencyError as exc:
        return error_message(exc)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
> state = self.active_shuffles[id]
E KeyError: '53a5056f80ef44da75580c3c1aff19ce'
distributed\shuffle\_scheduler_plugin.py:190: KeyError
During handling of the above exception, another exception occurred:
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
# Identifier of a shuffle; a distinct static type layered over ``str``.
ShuffleId = NewType("ShuffleId", str)
# N-dimensional index expressed as a tuple of ints.
NDIndex: TypeAlias = tuple[int, ...]
# Generic parameters for the partition identifier and the partition payload.
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
# Generic return type for helpers that forward a callable's result.
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
    """OK message that additionally carries a shuffle run specification."""

    # The run spec itself, or the run spec wrapped in ``ToPickle``.
    run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
    self,
    id: ShuffleId,
    run_id: int,
    span_id: str | None,
    local_address: str,
    directory: str,
    executor: ThreadPoolExecutor,
    rpc: Callable[[str], PooledRPCCall],
    digest_metric: Callable[[Hashable, float], None],
    scheduler: PooledRPCCall,
    memory_limiter_disk: ResourceLimiter,
    memory_limiter_comms: ResourceLimiter,
    disk: bool,
    loop: IOLoop,
):
    """Store the run's identity and collaborators and create its buffers.

    ``disk`` selects the receive-side buffer: a ``DiskShardsBuffer`` rooted at
    ``directory`` when true, otherwise a ``MemoryShardsBuffer``. The comm
    buffer limits and the retry settings are read from ``dask.config``.
    """
    self.id, self.run_id, self.span_id = id, run_id, span_id
    self.local_address = local_address
    self.executor, self.rpc = executor, rpc
    self.digest_metric, self.scheduler = digest_metric, scheduler
    self.closed = False

    # Initialize buffers and start background tasks.
    # Metrics issued by the background tasks must not be logged onto the dask
    # task that spawned this object, hence the cleared callbacks.
    with context_meter.clear_callbacks():
        with self._capture_metrics("background-disk"):
            self._disk_buffer = (
                DiskShardsBuffer(
                    directory=directory,
                    read=self.read,
                    memory_limiter=memory_limiter_disk,
                )
                if disk
                else MemoryShardsBuffer(deserialize=self.deserialize)
            )

        with self._capture_metrics("background-comms"):
            message_limit = parse_bytes(
                dask.config.get("distributed.p2p.comm.message-bytes-limit")
            )
            n_concurrent = dask.config.get("distributed.p2p.comm.concurrency")
            self._comm_buffer = CommShardsBuffer(
                send=self.send,
                max_message_size=message_limit,
                memory_limiter=memory_limiter_comms,
                concurrency_limit=n_concurrent,
            )

    # TODO: reduce number of connections to number of workers
    # MultiComm.max_connections = min(10, n_workers)

    self.transferred = False
    self.received = set()
    self.total_recvd = 0
    self.start_time = time.time()
    self._exception = None
    self._closed_event = asyncio.Event()
    self._loop = loop

    self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
    delay_min = dask.config.get("distributed.p2p.comm.retry.delay.min")
    self.RETRY_DELAY_MIN = parse_timedelta(delay_min, default="s")
    delay_max = dask.config.get("distributed.p2p.comm.retry.delay.max")
    self.RETRY_DELAY_MAX = parse_timedelta(delay_max, default="s")
def __repr__(self) -> str:
    """Return a debug representation listing id, run id, address and state flags."""
    shown = ("id", "run_id", "local_address", "closed", "transferred")
    details = ", ".join(f"{attr}={getattr(self, attr)!r}" for attr in shown)
    return f"<{self.__class__.__name__}: {details}>"
def __str__(self) -> str:
    """Return ``ClassName<id[run_id]> on local_address``."""
    cls_name = self.__class__.__name__
    return f"{cls_name}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
    """Hash a run by its ``run_id``."""
    return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
    """Forward context_meter metrics to ``digest_metric`` while the context is open.

    Metrics are recorded as
    {('p2p', <span id>, 'foreground|background...', label, unit): value}

    **Note 1:** A metric that is not logged by a background task
    (where='foreground') also shows up, duplicated, under
    {('execute', <span id>, <task prefix>, label, unit): value}
    This is intentional: it gives a holistic view of the whole shuffle
    process.

    **Note 2:** Values are written to Worker.digests immediately.
    Storing them temporarily on the ShuffleRun would lose whatever is
    recorded between the heartbeat and the deletion of the ShuffleRun object
    at the end of a run.
    """
    prefix = "p2p-"

    def forward(label: Hashable, value: float, unit: str) -> None:
        parts = label if isinstance(label, tuple) else (label,)
        head = parts[0]
        if isinstance(head, str) and head.startswith(prefix):
            # Drop the redundant prefix; the key already starts with "p2p".
            parts = (head[len(prefix) :], *parts[1:])
        self.digest_metric(("p2p", self.span_id, where, *parts, unit), value)

    with context_meter.add_callback(forward, allow_offload="background" in where):
        yield
async def barrier(self, run_ids: Sequence[int]) -> int:
    """Tell the scheduler whether every id in ``run_ids`` matches this run.

    Calls ``raise_if_closed`` first and returns this run's ``run_id``.
    """
    self.raise_if_closed()
    mismatch = any(other != self.run_id for other in run_ids)
    # TODO: Consider broadcast pinging once when the shuffle starts to warm
    # up the comm pool on scheduler side
    await self.scheduler.shuffle_barrier(
        id=self.id, run_id=self.run_id, consistent=not mismatch
    )
    return self.run_id
async def _send(
    self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
    """Deliver ``shards`` to ``address`` through its ``shuffle_receive`` RPC."""
    self.raise_if_closed()
    remote_receive = self.rpc(address).shuffle_receive
    return await remote_receive(
        data=to_serialize(shards),
        shuffle_id=self.id,
        run_id=self.run_id,
    )
async def send(
    self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
    """Send ``shards`` to ``address``, retrying according to the RETRY_* settings.

    When the mean shard size is below 65536 the shards are pickled into a
    single bytes blob before sending.
    """
    # Don't send small buffers individually over the tcp comms.
    # Instead, merge everything into an opaque bytes blob, send it all at once,
    # and unpickle it on the other side.
    # Performance tests informing the size threshold:
    # https://github.com/dask/distributed/pull/8318
    use_blob = _mean_shard_size(shards) < 65536
    shards_or_bytes: list | bytes = pickle.dumps(shards) if use_blob else shards

    def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
        return self._send(address, shards_or_bytes)

    return await retry(
        _send,
        count=self.RETRY_COUNT,
        delay_min=self.RETRY_DELAY_MIN,
        delay_max=self.RETRY_DELAY_MAX,
    )
async def offload(
    self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
    """Run ``func(*args, **kwargs)`` on this run's executor, metered as "offload"."""
    self.raise_if_closed()
    with context_meter.meter("offload"):
        outcome = await run_in_executor_with_context(
            self.executor, func, *args, **kwargs
        )
        return outcome
def heartbeat(self) -> dict[str, Any]:
    """Collect heartbeat data from both buffers plus the run's start time.

    The comm buffer's heartbeat is extended with a ``"read"`` entry holding
    ``total_recvd``.
    """
    comm_stats = self._comm_buffer.heartbeat()
    comm_stats["read"] = self.total_recvd
    disk_stats = self._disk_buffer.heartbeat()
    return {"disk": disk_stats, "comm": comm_stats, "start": self.start_time}
async def _write_to_comm(
    self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
    """Write ``data`` to the comm buffer after checking the run is still open."""
    self.raise_if_closed()
    comm_buffer = self._comm_buffer
    await comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
    """Write ``data`` to the disk buffer, keyed by the index joined with "_"."""
    self.raise_if_closed()
    disk_buffer = self._disk_buffer
    by_name = {
        "_".join(str(part) for part in index): shard
        for index, shard in data.items()
    }
    await disk_buffer.write(by_name)
def raise_if_closed(self) -> None:
    """Raise if the run is closed.

    Re-raises the stored exception when there is one, otherwise raises
    ``ShuffleClosedError``. Does nothing while the run is open.
    """
    if not self.closed:
        return
    if self._exception:
        raise self._exception
    raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
    """Mark the run as transferred and flush the comm buffer.

    Any exception surfaced by the comm buffer afterwards is stored on the run
    and re-raised.
    """
    self.raise_if_closed()
    self.transferred = True
    await self._flush_comm()
    try:
        self._comm_buffer.raise_on_exception()
    except Exception as comm_error:
        self._exception = comm_error
        raise
async def _flush_comm(self) -> None:
    """Flush the comm buffer after checking the run is still open."""
    self.raise_if_closed()
    comm_buffer = self._comm_buffer
    await comm_buffer.flush()
async def flush_receive(self) -> None:
    """Flush the disk buffer after checking the run is still open."""
    self.raise_if_closed()
    disk_buffer = self._disk_buffer
    await disk_buffer.flush()
async def close(self) -> None:
    """Close both buffers and signal that the run is closed.

    A second caller arriving while the run is already marked closed waits for
    ``_closed_event`` instead of closing again.

    The disk buffer is closed and ``_closed_event`` is set even if closing a
    buffer raises; otherwise a failure would leave concurrent callers waiting
    on the event forever. The exception still propagates to the first caller.
    """
    if self.closed:  # pragma: no cover
        await self._closed_event.wait()
        return

    self.closed = True
    try:
        try:
            await self._comm_buffer.close()
        finally:
            await self._disk_buffer.close()
    finally:
        self._closed_event.set()
def fail(self, exception: Exception) -> None:
    """Store ``exception`` on the run unless the run is already closed."""
    if self.closed:
        return
    self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
    """Read from the disk buffer the entry named by ``id`` joined with "_"."""
    self.raise_if_closed()
    disk_buffer = self._disk_buffer
    entry_name = "_".join(str(part) for part in id)
    return disk_buffer.read(entry_name)
async def receive(
    self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
    """Accept shards sent by ``send`` and pass them on to ``_receive``.

    Returns ``{"status": "OK"}`` on success; a ``P2PConsistencyError`` is
    converted into an error message instead of being raised.
    """
    try:
        if isinstance(data, bytes):
            # Unpack opaque blob. See send()
            # NOTE(review): this unpickles bytes received over the comm;
            # confirm all peers are trusted.
            unpacked = pickle.loads(data)
            data = cast(list[tuple[_T_partition_id, Any]], unpacked)
        await self._receive(data)
        return {"status": "OK"}
    except P2PConsistencyError as inconsistency:
        return error_message(inconsistency)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
    """Make sure output partition ``i`` is handled on its assigned worker.

    Returns normally when the assigned worker is this one. Otherwise asks the
    scheduler to restrict ``key`` to the assigned worker and raises
    ``Reschedule``; an error response from the scheduler is raised as
    ``RuntimeError``.
    """
    target = self._get_assigned_worker(i)
    if target == self.local_address:
        return
    response = await self.scheduler.shuffle_restrict_task(
        id=self.id, run_id=self.run_id, key=key, worker=target
    )
    if response["status"] == "error":
        raise RuntimeError(response["message"])
    assert response["status"] == "OK"
    raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
    """Return the address of the worker that output partition ``i`` is assigned to."""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
    """Handle received shards that belong to output partitions of this shuffle run."""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and …fresh
result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed\shuffle\_worker_plugin.py:190: in _fetch
response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed\core.py:1259: in send_recv_from_rpc
return await send_recv(comm=comm, op=key, **kwargs)
distributed\core.py:1043: in send_recv
raise exc.with_traceback(tb)
distributed\core.py:832: in _handle_comm
result = handler(**msg)
distributed\shuffle\_scheduler_plugin.py:234: in get_or_create
run_spec = self._create(shuffle_id, key, worker)
distributed\shuffle\_scheduler_plugin.py:205: in _create
spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
# Scheduler this plugin registers its handlers and itself on (see __init__).
scheduler: Scheduler
# State of the currently active run per shuffle id.
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
# Heartbeat payloads, keyed by shuffle id and then by worker address.
heartbeats: defaultdict[ShuffleId, dict]
# The three attributes below are initialised empty in __init__; their use is
# outside the visible part of this file.
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
    """Install the shuffle RPC handlers and register this plugin as "shuffle"."""
    self.scheduler = scheduler
    shuffle_handlers = {
        "shuffle_barrier": self.barrier,
        "shuffle_get": self.get,
        "shuffle_get_or_create": self.get_or_create,
        "shuffle_restrict_task": self.restrict_task,
    }
    self.scheduler.handlers.update(shuffle_handlers)
    self.heartbeats = defaultdict(lambda: defaultdict(dict))
    self.active_shuffles = {}
    self.scheduler.add_plugin(self, name="shuffle")
    self._shuffles = defaultdict(set)
    self._archived_by_stimulus = defaultdict(set)
    self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
    """Register a pickled ``ShuffleWorkerPlugin`` on the scheduler as "shuffle".

    Note that the registration goes through ``self.scheduler``; the
    ``scheduler`` argument is not used.
    """
    pickled_plugin = dumps(ShuffleWorkerPlugin())
    await self.scheduler.register_worker_plugin(
        None, pickled_plugin, name="shuffle", idempotent=False
    )
def shuffle_ids(self) -> set[ShuffleId]:
    """Return a new set holding the ids of all active shuffles."""
    return {*self.active_shuffles}
# Scheduler-side handler registered for the "shuffle_barrier" op.
# Looks up the active shuffle for `id` and rejects a mismatching `run_id` with
# ValueError. When the caller reports `consistent=False` the shuffle is
# restarted. Otherwise a "shuffle_inputs_done" message is broadcast to the
# participating workers, and workers whose response is an OSError are retried.
# NOTE(review): indentation was lost in this dump, so the exact nesting of the
# retry / sleep / no-progress logic below cannot be confirmed from here.
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
# NOTE(review): returns the result of _restart_shuffle although the
# signature is annotated `-> None`; confirm that helper returns None.
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
# Counts rounds in which the number of pending workers did not shrink.
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
# A None response means the worker is done; an OSError response is retried;
# any other response aborts the barrier.
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
# NOTE(review): this looks like it recomputes the list the loop above already
# built (any non-None, non-OSError response raises) - confirm and consider
# removing.
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
# Give up once three rounds have made no progress.
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
    self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
    """Restrict task ``key`` to ``worker`` for the active run of shuffle ``id``.

    Returns ``{"status": "OK"}`` on success. A ``run_id`` that differs from
    the active run's id is reported as an error message rather than raised.
    """
    try:
        shuffle = self.active_shuffles[id]
        active_run_id = shuffle.run_id
        if active_run_id > run_id:
            raise P2PConsistencyError(
                f"Request stale, expected {run_id=} for {shuffle}"
            )
        if active_run_id < run_id:
            raise P2PConsistencyError(
                f"Request invalid, expected {run_id=} for {shuffle}"
            )
        task_state = self.scheduler.tasks[key]
        self._set_restriction(task_state, worker)
        return {"status": "OK"}
    except P2PConsistencyError as inconsistency:
        return error_message(inconsistency)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
    """Merge a worker's heartbeat payloads into ``self.heartbeats``.

    ``data`` maps shuffle ids to payload dicts. Payloads for ids that are not
    returned by ``shuffle_ids()`` are ignored; the rest update the entry for
    ``ws.address`` under that shuffle id.
    """
    # shuffle_ids() builds a fresh set on every call, so look it up once
    # instead of once per payload.
    known_ids = self.shuffle_ids()
    for shuffle_id, payload in data.items():
        if shuffle_id in known_ids:
            self.heartbeats[shuffle_id][ws.address].update(payload)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
    """Return the run spec of the active shuffle ``id`` wrapped in an OK message.

    A missing shuffle (``KeyError``) is turned into ``P2PConsistencyError``,
    and any ``P2PConsistencyError`` is returned as an error message.
    """
    try:
        try:
            spec = self._get(id, worker)
            return {"status": "OK", "run_spec": ToPickle(spec)}
        except KeyError as missing:
            raise P2PConsistencyError(
                f"No active shuffle with {id=!r} found"
            ) from missing
    except P2PConsistencyError as inconsistency:
        return error_message(inconsistency)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
    """Record ``worker`` as a participant of shuffle ``id`` and return its run spec.

    Raises ``P2PConsistencyError`` for a worker unknown to the scheduler and
    ``KeyError`` when no shuffle with ``id`` is active.
    """
    known_workers = self.scheduler.workers
    if worker not in known_workers:
        # This should never happen
        raise P2PConsistencyError(
            f"Scheduler is unaware of this worker {worker!r}"
        )  # pragma: nocover
    shuffle_state = self.active_shuffles[id]
    shuffle_state.participating_workers.add(worker)
    return shuffle_state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
> assert isinstance(barrier_task_spec, P2PBarrierTask)
E AssertionError
distributed\shuffle\_scheduler_plugin.py:196: AssertionError
The above exception was the direct cause of the following exception:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:64633', workers: 0, cores: 0, tasks: 0>
x = dask.array<ones_like, shape=(50, 10), dtype=float64, chunksize=(25, 10), chunktype=numpy.ndarray>
chunks = (None, (5, 5))
ws = (<Worker 'tcp://127.0.0.1:64634', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:64637', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
dd = <module 'dask.dataframe' from 'C:\\Users\\runneradmin\\miniconda3\\envs\\dask-distributed\\Lib\\site-packages\\dask\\dataframe\\__init__.py'>
@pytest.mark.parametrize(
"x, chunks",
[
(da.ones(shape=(50, 10), chunks=(25, 10)), (None, 5)),
(da.ones(shape=(50, 10), chunks=(25, 10)), {1: 5}),
(da.ones(shape=(50, 10), chunks=(25, 10)), (None, (5, 5))),
(da.ones(shape=(100, 10), chunks=(5, 10)), (None, 5)),
(da.ones(shape=(100, 10), chunks=(5, 10)), {1: 5}),
(da.ones(shape=(100, 10), chunks=(5, 10)), (None, (5, 5))),
(da.ones(shape=(10, 10), chunks=(10, 10)), (None, 5)),
(da.ones(shape=(10, 10), chunks=(10, 10)), {1: 5}),
(da.ones(shape=(10, 10), chunks=(10, 10)), (None, (5, 5))),
(da.ones(shape=(10, 10), chunks=(10, 2)), (None, 5)),
(da.ones(shape=(10, 10), chunks=(10, 2)), {1: 5}),
(da.ones(shape=(10, 10), chunks=(10, 2)), (None, (5, 5))),
],
)
@gen_cluster(client=True)
async def test_rechunk_with_fully_unknown_dimension(c, s, *ws, x, chunks):
"""
See Also
--------
dask.array.tests.test_rechunk.test_rechunk_with_fully_unknown_dimension
"""
pytest.importorskip("pandas")
dd = pytest.importorskip("dask.dataframe")
y = dd.from_array(x).values
result = y.rechunk(chunks, method="p2p")
expected = x.rechunk(chunks, method="p2p")
assert_chunks_match(result.chunks, expected.chunks)
> assert_eq(await c.compute(result), await c.compute(expected))
distributed\shuffle\tests\test_rechunk.py:757:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed\client.py:410: in _result
raise exc.with_traceback(tb)
distributed\shuffle\_rechunk.py:169: in rechunk_transfer
with handle_transfer_errors(id):
C:\Users\runneradmin\miniconda3\envs\dask-distributed\Lib\contextlib.py:158: in __exit__
self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
# Message type extending ``OKMessage`` with a ``run_spec`` entry that holds a
# ``ShuffleRunSpec``, either directly or wrapped in ``ToPickle``.
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
# Base class for one run of a P2P shuffle, generic over the partition id type
# and the partition payload type. The annotations below only declare the
# instance attributes; their values are assigned in __init__.
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
# Retry settings used by send(); read from the
# ``distributed.p2p.comm.retry.*`` config keys in __init__.
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
    self,
    id: ShuffleId,
    run_id: int,
    span_id: str | None,
    local_address: str,
    directory: str,
    executor: ThreadPoolExecutor,
    rpc: Callable[[str], PooledRPCCall],
    digest_metric: Callable[[Hashable, float], None],
    scheduler: PooledRPCCall,
    memory_limiter_disk: ResourceLimiter,
    memory_limiter_comms: ResourceLimiter,
    disk: bool,
    loop: IOLoop,
):
    """Store the run's identity and collaborators and create its buffers.

    ``disk`` selects the receive-side buffer: a ``DiskShardsBuffer`` rooted at
    ``directory`` when true, otherwise a ``MemoryShardsBuffer``. The comm
    buffer limits and the retry settings are read from ``dask.config``.
    """
    self.id, self.run_id, self.span_id = id, run_id, span_id
    self.local_address = local_address
    self.executor, self.rpc = executor, rpc
    self.digest_metric, self.scheduler = digest_metric, scheduler
    self.closed = False

    # Initialize buffers and start background tasks.
    # Metrics issued by the background tasks must not be logged onto the dask
    # task that spawned this object, hence the cleared callbacks.
    with context_meter.clear_callbacks():
        with self._capture_metrics("background-disk"):
            self._disk_buffer = (
                DiskShardsBuffer(
                    directory=directory,
                    read=self.read,
                    memory_limiter=memory_limiter_disk,
                )
                if disk
                else MemoryShardsBuffer(deserialize=self.deserialize)
            )

        with self._capture_metrics("background-comms"):
            message_limit = parse_bytes(
                dask.config.get("distributed.p2p.comm.message-bytes-limit")
            )
            n_concurrent = dask.config.get("distributed.p2p.comm.concurrency")
            self._comm_buffer = CommShardsBuffer(
                send=self.send,
                max_message_size=message_limit,
                memory_limiter=memory_limiter_comms,
                concurrency_limit=n_concurrent,
            )

    # TODO: reduce number of connections to number of workers
    # MultiComm.max_connections = min(10, n_workers)

    self.transferred = False
    self.received = set()
    self.total_recvd = 0
    self.start_time = time.time()
    self._exception = None
    self._closed_event = asyncio.Event()
    self._loop = loop

    self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
    delay_min = dask.config.get("distributed.p2p.comm.retry.delay.min")
    self.RETRY_DELAY_MIN = parse_timedelta(delay_min, default="s")
    delay_max = dask.config.get("distributed.p2p.comm.retry.delay.max")
    self.RETRY_DELAY_MAX = parse_timedelta(delay_max, default="s")
def __repr__(self) -> str:
    """Return a debug representation listing id, run id, address and state flags."""
    shown = ("id", "run_id", "local_address", "closed", "transferred")
    details = ", ".join(f"{attr}={getattr(self, attr)!r}" for attr in shown)
    return f"<{self.__class__.__name__}: {details}>"
def __str__(self) -> str:
    """Return ``ClassName<id[run_id]> on local_address``."""
    cls_name = self.__class__.__name__
    return f"{cls_name}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
    """Hash a run by its ``run_id``."""
    return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
    """Forward context_meter metrics to ``digest_metric`` while the context is open.

    Metrics are recorded as
    {('p2p', <span id>, 'foreground|background...', label, unit): value}

    **Note 1:** A metric that is not logged by a background task
    (where='foreground') also shows up, duplicated, under
    {('execute', <span id>, <task prefix>, label, unit): value}
    This is intentional: it gives a holistic view of the whole shuffle
    process.

    **Note 2:** Values are written to Worker.digests immediately.
    Storing them temporarily on the ShuffleRun would lose whatever is
    recorded between the heartbeat and the deletion of the ShuffleRun object
    at the end of a run.
    """
    prefix = "p2p-"

    def forward(label: Hashable, value: float, unit: str) -> None:
        parts = label if isinstance(label, tuple) else (label,)
        head = parts[0]
        if isinstance(head, str) and head.startswith(prefix):
            # Drop the redundant prefix; the key already starts with "p2p".
            parts = (head[len(prefix) :], *parts[1:])
        self.digest_metric(("p2p", self.span_id, where, *parts, unit), value)

    with context_meter.add_callback(forward, allow_offload="background" in where):
        yield
async def barrier(self, run_ids: Sequence[int]) -> int:
    """Tell the scheduler whether every id in ``run_ids`` matches this run.

    Calls ``raise_if_closed`` first and returns this run's ``run_id``.
    """
    self.raise_if_closed()
    mismatch = any(other != self.run_id for other in run_ids)
    # TODO: Consider broadcast pinging once when the shuffle starts to warm
    # up the comm pool on scheduler side
    await self.scheduler.shuffle_barrier(
        id=self.id, run_id=self.run_id, consistent=not mismatch
    )
    return self.run_id
async def _send(
    self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
    """Deliver ``shards`` to ``address`` through its ``shuffle_receive`` RPC."""
    self.raise_if_closed()
    remote_receive = self.rpc(address).shuffle_receive
    return await remote_receive(
        data=to_serialize(shards),
        shuffle_id=self.id,
        run_id=self.run_id,
    )
async def send(
    self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
    """Send ``shards`` to ``address``, retrying according to the RETRY_* settings.

    When the mean shard size is below 65536 the shards are pickled into a
    single bytes blob before sending.
    """
    # Don't send small buffers individually over the tcp comms.
    # Instead, merge everything into an opaque bytes blob, send it all at once,
    # and unpickle it on the other side.
    # Performance tests informing the size threshold:
    # https://github.com/dask/distributed/pull/8318
    use_blob = _mean_shard_size(shards) < 65536
    shards_or_bytes: list | bytes = pickle.dumps(shards) if use_blob else shards

    def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
        return self._send(address, shards_or_bytes)

    return await retry(
        _send,
        count=self.RETRY_COUNT,
        delay_min=self.RETRY_DELAY_MIN,
        delay_max=self.RETRY_DELAY_MAX,
    )
async def offload(
    self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
    """Run ``func(*args, **kwargs)`` on this run's executor, metered as "offload"."""
    self.raise_if_closed()
    with context_meter.meter("offload"):
        outcome = await run_in_executor_with_context(
            self.executor, func, *args, **kwargs
        )
        return outcome
def heartbeat(self) -> dict[str, Any]:
    """Collect heartbeat data from both buffers plus the run's start time.

    The comm buffer's heartbeat is extended with a ``"read"`` entry holding
    ``total_recvd``.
    """
    comm_stats = self._comm_buffer.heartbeat()
    comm_stats["read"] = self.total_recvd
    disk_stats = self._disk_buffer.heartbeat()
    return {"disk": disk_stats, "comm": comm_stats, "start": self.start_time}
async def _write_to_comm(
    self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
    """Write ``data`` to the comm buffer after checking the run is still open."""
    self.raise_if_closed()
    comm_buffer = self._comm_buffer
    await comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
    """Write ``data`` to the disk buffer, keyed by the index joined with "_"."""
    self.raise_if_closed()
    disk_buffer = self._disk_buffer
    by_name = {
        "_".join(str(part) for part in index): shard
        for index, shard in data.items()
    }
    await disk_buffer.write(by_name)
def raise_if_closed(self) -> None:
    """Raise if the run is closed.

    Re-raises the stored exception when there is one, otherwise raises
    ``ShuffleClosedError``. Does nothing while the run is open.
    """
    if not self.closed:
        return
    if self._exception:
        raise self._exception
    raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
    """Mark the run as transferred and flush the comm buffer.

    Any exception surfaced by the comm buffer afterwards is stored on the run
    and re-raised.
    """
    self.raise_if_closed()
    self.transferred = True
    await self._flush_comm()
    try:
        self._comm_buffer.raise_on_exception()
    except Exception as comm_error:
        self._exception = comm_error
        raise
async def _flush_comm(self) -> None:
    """Flush the comm buffer after checking the run is still open."""
    self.raise_if_closed()
    comm_buffer = self._comm_buffer
    await comm_buffer.flush()
async def flush_receive(self) -> None:
    """Flush the disk buffer after checking the run is still open."""
    self.raise_if_closed()
    disk_buffer = self._disk_buffer
    await disk_buffer.flush()
async def close(self) -> None:
    """Close both buffers and signal that the run is closed.

    A second caller arriving while the run is already marked closed waits for
    ``_closed_event`` instead of closing again.

    The disk buffer is closed and ``_closed_event`` is set even if closing a
    buffer raises; otherwise a failure would leave concurrent callers waiting
    on the event forever. The exception still propagates to the first caller.
    """
    if self.closed:  # pragma: no cover
        await self._closed_event.wait()
        return

    self.closed = True
    try:
        try:
            await self._comm_buffer.close()
        finally:
            await self._disk_buffer.close()
    finally:
        self._closed_event.set()
def fail(self, exception: Exception) -> None:
    """Store ``exception`` on the run unless the run is already closed."""
    if self.closed:
        return
    self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
    """Read from the disk buffer the entry named by ``id`` joined with "_"."""
    self.raise_if_closed()
    disk_buffer = self._disk_buffer
    entry_name = "_".join(str(part) for part in id)
    return disk_buffer.read(entry_name)
async def receive(
    self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
    """Accept shards sent by ``send`` and pass them on to ``_receive``.

    Returns ``{"status": "OK"}`` on success; a ``P2PConsistencyError`` is
    converted into an error message instead of being raised.
    """
    try:
        if isinstance(data, bytes):
            # Unpack opaque blob. See send()
            # NOTE(review): this unpickles bytes received over the comm;
            # confirm all peers are trusted.
            unpacked = pickle.loads(data)
            data = cast(list[tuple[_T_partition_id, Any]], unpacked)
        await self._receive(data)
        return {"status": "OK"}
    except P2PConsistencyError as inconsistency:
        return error_message(inconsistency)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
    """Make sure output partition ``i`` is handled on its assigned worker.

    Returns normally when the assigned worker is this one. Otherwise asks the
    scheduler to restrict ``key`` to the assigned worker and raises
    ``Reschedule``; an error response from the scheduler is raised as
    ``RuntimeError``.
    """
    target = self._get_assigned_worker(i)
    if target == self.local_address:
        return
    response = await self.scheduler.shuffle_restrict_task(
        id=self.id, run_id=self.run_id, key=key, worker=target
    )
    if response["status"] == "error":
        raise RuntimeError(response["message"])
    assert response["status"] == "OK"
    raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
    """Return the address of the worker that output partition ``i`` is assigned to."""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
    """Handle received shards that belong to output partitions of this shuffle run."""
# Shard one input partition and hand the shards to the comm buffer (through
# _write_to_comm, run on self._loop via sync).
# Raises via raise_if_closed() when the run is closed, and RuntimeError once
# `transferred` has been set. Returns this run's run_id.
# NOTE(review): indentation was lost in this dump; whether the sync(...) call
# sits inside the "foreground" metrics capture cannot be confirmed from here.
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and in the "p2p" contexts
self.validate_data(data)
with self._capture_metrics("foreground"):
with (
context_meter.meter("p2p-shard-partition-noncpu"),
context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
):
shards = self._shard_partition(data, partition_id)
sync(self._loop, self._write_to_comm, shards)
return self.run_id
@abc.abstractmethod
def _shard_partition(
    self, data: _T_partition_type, partition_id: _T_partition_id
) -> dict[str, tuple[_T_partition_id, Any]]:
    """Split an input partition into shards according to the assigned output workers."""
def get_output_partition(
    self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
    """Return output partition ``partition_id`` after flushing received data.

    First ensures the partition is handled on its assigned worker (which may
    raise ``Reschedule``). Raises ``RuntimeError`` when called before the
    run has been marked as transferred.
    """
    self.raise_if_closed()
    sync(self._loop, self._ensure_output_worker, partition_id, key)
    if not self.transferred:
        raise RuntimeError("`get_output_partition` called before barrier task")
    sync(self._loop, self.flush_receive)
    # Log metrics both in the "execute" and in the "p2p" contexts
    with self._capture_metrics("foreground"):
        with context_meter.meter("p2p-get-output-noncpu"):
            with context_meter.meter("p2p-get-output-cpu", func=thread_time):
                return self._get_output_partition(partition_id, key, **kwargs)
@abc.abstractmethod
def _get_output_partition(
    self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
    """Produce the output partition ``partition_id`` of this shuffle run."""
@abc.abstractmethod
def read(self, path: Path) -> tuple[Any, int]:
    """Read shards from disk at ``path``."""
@abc.abstractmethod
def deserialize(self, buffer: Any) -> Any:
    """Deserialize shards from ``buffer``."""
def validate_data(self, data: Any) -> None:
    """Hook for validating payload data before shuffling.

    The base implementation does nothing.
    """
def get_worker_plugin() -> ShuffleWorkerPlugin:
    """Return the "shuffle" plugin of the worker this code is running on.

    Raises ``RuntimeError`` when no worker can be found, or when the worker
    has no plugin registered under "shuffle".
    """
    from distributed import get_worker

    try:
        current_worker = get_worker()
    except ValueError as no_worker:
        raise RuntimeError(
            "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
            "please confirm that you've created a distributed Client and are submitting this computation through it."
        ) from no_worker
    registered = current_worker.plugins
    try:
        return registered["shuffle"]  # type: ignore
    except KeyError as no_plugin:
        raise RuntimeError(
            f"The worker {current_worker.address} does not have a P2P shuffle plugin."
        ) from no_plugin
# Prefix shared by the keys of all shuffle barrier tasks.
_BARRIER_PREFIX = "shuffle-barrier-"


def barrier_key(shuffle_id: ShuffleId) -> str:
    """Return the barrier task key for ``shuffle_id``."""
    return f"{_BARRIER_PREFIX}{shuffle_id}"
def id_from_key(key: Key) -> ShuffleId | None:
    """Extract the shuffle id from a barrier task key.

    Returns ``None`` when ``key`` is not a string starting with the barrier
    prefix.
    """
    if isinstance(key, str) and key.startswith(_BARRIER_PREFIX):
        return ShuffleId(key[len(_BARRIER_PREFIX) :])
    return None
class ShuffleType(Enum):
    """The kinds of P2P shuffle named in this module."""

    DATAFRAME = "DataFrameShuffle"
    ARRAY_RECHUNK = "ArrayRechunk"
@dataclass(frozen=True)
class ShuffleRunSpec(Generic[_T_partition_id]):
    """Immutable description of a single run of a shuffle.

    ``run_id`` is not an ``__init__`` argument; each new instance draws the
    next value from a counter shared by the class, starting at 1.
    """

    run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
    spec: ShuffleSpec
    worker_for: dict[_T_partition_id, str]
    span_id: str | None

    @property
    def id(self) -> ShuffleId:
        """Id of the shuffle, taken from ``spec``."""
        shuffle_spec = self.spec
        return shuffle_spec.id
@dataclass(frozen=True)
class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
    """Abstract, immutable specification of a shuffle.

    Holds the shuffle ``id`` and the ``disk`` flag; subclasses supply the
    partitioning, the worker assignment and the worker-side run factory.
    """

    id: ShuffleId
    disk: bool

    @property
    @abc.abstractmethod
    def output_partitions(self) -> Generator[_T_partition_id]:
        """Output partitions of the shuffle."""

    @abc.abstractmethod
    def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
        """Choose which of ``workers`` handles ``partition``."""

    def create_new_run(
        self,
        worker_for: dict[_T_partition_id, str],
        span_id: str | None,
    ) -> SchedulerShuffleState:
        """Create the scheduler state for a new run of this spec.

        The workers named in ``worker_for`` become the initial set of
        participating workers.
        """
        run_spec = ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id)
        participants = set(worker_for.values())
        return SchedulerShuffleState(
            run_spec=run_spec,
            participating_workers=participants,
        )

    @abc.abstractmethod
    def create_run_on_worker(
        self,
        run_id: int,
        span_id: str | None,
        worker_for: dict[_T_partition_id, str],
        plugin: ShuffleWorkerPlugin,
    ) -> ShuffleRun:
        """Create the new shuffle run on the worker."""
@dataclass(eq=False)
class SchedulerShuffleState(Generic[_T_partition_id]):
    """State tracked for one shuffle run: its run spec and participating workers."""

    run_spec: ShuffleRunSpec
    participating_workers: set[str]
    _archived_by: str | None = field(default=None, init=False)
    _failed: bool = False

    @property
    def id(self) -> ShuffleId:
        """Shuffle id, taken from the run spec."""
        return self.run_spec.id

    @property
    def run_id(self) -> int:
        """Run id, taken from the run spec."""
        return self.run_spec.run_id

    @property
    def archived(self) -> bool:
        """True once ``_archived_by`` has been set."""
        return self._archived_by is not None

    def __str__(self) -> str:
        cls_name = self.__class__.__name__
        return f"{cls_name}<{self.id}[{self.run_id}]>"

    def __hash__(self) -> int:
        return hash(self.run_id)
@contextlib.contextmanager
def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
    """Translate exceptions raised inside the managed block.

    ``ShuffleClosedError`` becomes ``Reschedule``; ``P2PConsistencyError``
    and ``P2POutOfDiskError`` propagate unchanged; any other exception is
    wrapped in a ``RuntimeError`` naming the shuffle ``id``.
    """
    try:
        yield
    except ShuffleClosedError:
        raise Reschedule()
    except (P2PConsistencyError, P2POutOfDiskError):
        raise
    except Exception as cause:
        raise RuntimeError(f"P2P {id} failed during transfer phase") from cause
E RuntimeError: P2P 53a5056f80ef44da75580c3c1aff19ce failed during transfer phase
distributed\shuffle\_core.py:531: RuntimeError
Check warning on line 0 in distributed.shuffle.tests.test_rechunk
github-actions / Unit Test Results
10 out of 11 runs failed: test_rechunk_with_fully_unknown_dimension[x3-chunks3] (distributed.shuffle.tests.test_rechunk)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 103aea8fc8e7486296d569fc1f0805a8 failed during transfer phase
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
assert isinstance(barrier_task_spec, P2PBarrierTask)
return barrier_task_spec.spec
def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
# FIXME: The current implementation relies on the barrier task to be
# known by its name. If the name has been mangled, we cannot guarantee
# that the shuffle works as intended and should fail instead.
self._raise_if_barrier_unknown(shuffle_id)
self._raise_if_task_not_processing(key)
spec = self._retrieve_spec(shuffle_id)
worker_for = self._calculate_worker_for(spec)
self._ensure_output_tasks_are_non_rootish(spec)
state = spec.create_new_run(
worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
)
self.active_shuffles[shuffle_id] = state
self._shuffles[shuffle_id].add(state)
state.participating_workers.add(worker)
logger.warning(
"Shuffle %s initialized by task %r executed on worker %s",
shuffle_id,
key,
worker,
)
return state.run_spec
def get_or_create(
self,
shuffle_id: ShuffleId,
key: Key,
worker: str,
) -> RunSpecMessage | ErrorMessage:
try:
> run_spec = self._get(shuffle_id, worker)
distributed\shuffle\_scheduler_plugin.py:229:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
> state = self.active_shuffles[id]
E KeyError: '103aea8fc8e7486296d569fc1f0805a8'
distributed\shuffle\_scheduler_plugin.py:190: KeyError
During handling of the above exception, another exception occurred:
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and …n _refresh
result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed\shuffle\_worker_plugin.py:190: in _fetch
response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed\core.py:1259: in send_recv_from_rpc
return await send_recv(comm=comm, op=key, **kwargs)
distributed\core.py:1043: in send_recv
raise exc.with_traceback(tb)
distributed\core.py:832: in _handle_comm
result = handler(**msg)
distributed\shuffle\_scheduler_plugin.py:234: in get_or_create
run_spec = self._create(shuffle_id, key, worker)
distributed\shuffle\_scheduler_plugin.py:205: in _create
spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
> assert isinstance(barrier_task_spec, P2PBarrierTask)
E AssertionError
distributed\shuffle\_scheduler_plugin.py:196: AssertionError
The above exception was the direct cause of the following exception:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:64647', workers: 0, cores: 0, tasks: 0>
x = dask.array<ones_like, shape=(100, 10), dtype=float64, chunksize=(5, 10), chunktype=numpy.ndarray>
chunks = (None, 5)
ws = (<Worker 'tcp://127.0.0.1:64648', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:64651', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
dd = <module 'dask.dataframe' from 'C:\\Users\\runneradmin\\miniconda3\\envs\\dask-distributed\\Lib\\site-packages\\dask\\dataframe\\__init__.py'>
@pytest.mark.parametrize(
"x, chunks",
[
(da.ones(shape=(50, 10), chunks=(25, 10)), (None, 5)),
(da.ones(shape=(50, 10), chunks=(25, 10)), {1: 5}),
(da.ones(shape=(50, 10), chunks=(25, 10)), (None, (5, 5))),
(da.ones(shape=(100, 10), chunks=(5, 10)), (None, 5)),
(da.ones(shape=(100, 10), chunks=(5, 10)), {1: 5}),
(da.ones(shape=(100, 10), chunks=(5, 10)), (None, (5, 5))),
(da.ones(shape=(10, 10), chunks=(10, 10)), (None, 5)),
(da.ones(shape=(10, 10), chunks=(10, 10)), {1: 5}),
(da.ones(shape=(10, 10), chunks=(10, 10)), (None, (5, 5))),
(da.ones(shape=(10, 10), chunks=(10, 2)), (None, 5)),
(da.ones(shape=(10, 10), chunks=(10, 2)), {1: 5}),
(da.ones(shape=(10, 10), chunks=(10, 2)), (None, (5, 5))),
],
)
@gen_cluster(client=True)
async def test_rechunk_with_fully_unknown_dimension(c, s, *ws, x, chunks):
"""
See Also
--------
dask.array.tests.test_rechunk.test_rechunk_with_fully_unknown_dimension
"""
pytest.importorskip("pandas")
dd = pytest.importorskip("dask.dataframe")
y = dd.from_array(x).values
result = y.rechunk(chunks, method="p2p")
expected = x.rechunk(chunks, method="p2p")
assert_chunks_match(result.chunks, expected.chunks)
> assert_eq(await c.compute(result), await c.compute(expected))
distributed\shuffle\tests\test_rechunk.py:757:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed\client.py:410: in _result
raise exc.with_traceback(tb)
distributed\shuffle\_rechunk.py:169: in rechunk_transfer
with handle_transfer_errors(id):
C:\Users\runneradmin\miniconda3\envs\dask-distributed\Lib\contextlib.py:158: in __exit__
self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and in the "p2p" contexts
self.validate_data(data)
with self._capture_metrics("foreground"):
with (
context_meter.meter("p2p-shard-partition-noncpu"),
context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
):
shards = self._shard_partition(data, partition_id)
sync(self._loop, self._write_to_comm, shards)
return self.run_id
@abc.abstractmethod
def _shard_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> dict[str, tuple[_T_partition_id, Any]]:
"""Shard an input partition by the assigned output workers"""
def get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
self.raise_if_closed()
sync(self._loop, self._ensure_output_worker, partition_id, key)
if not self.transferred:
raise RuntimeError("`get_output_partition` called before barrier task")
sync(self._loop, self.flush_receive)
with (
# Log metrics both in the "execute" and in the "p2p" contexts
self._capture_metrics("foreground"),
context_meter.meter("p2p-get-output-noncpu"),
context_meter.meter("p2p-get-output-cpu", func=thread_time),
):
return self._get_output_partition(partition_id, key, **kwargs)
@abc.abstractmethod
def _get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
"""Get an output partition to the shuffle run"""
@abc.abstractmethod
def read(self, path: Path) -> tuple[Any, int]:
"""Read shards from disk"""
@abc.abstractmethod
def deserialize(self, buffer: Any) -> Any:
"""Deserialize shards"""
def validate_data(self, data: Any) -> None:
"""Validate payload data before shuffling"""
def get_worker_plugin() -> ShuffleWorkerPlugin:
from distributed import get_worker
try:
worker = get_worker()
except ValueError as e:
raise RuntimeError(
"`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
"please confirm that you've created a distributed Client and are submitting this computation through it."
) from e
try:
return worker.plugins["shuffle"] # type: ignore
except KeyError as e:
raise RuntimeError(
f"The worker {worker.address} does not have a P2P shuffle plugin."
) from e
_BARRIER_PREFIX = "shuffle-barrier-"
def barrier_key(shuffle_id: ShuffleId) -> str:
return _BARRIER_PREFIX + shuffle_id
def id_from_key(key: Key) -> ShuffleId | None:
if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
return None
return ShuffleId(key[len(_BARRIER_PREFIX) :])
class ShuffleType(Enum):
DATAFRAME = "DataFrameShuffle"
ARRAY_RECHUNK = "ArrayRechunk"
@dataclass(frozen=True)
class ShuffleRunSpec(Generic[_T_partition_id]):
run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
spec: ShuffleSpec
worker_for: dict[_T_partition_id, str]
span_id: str | None
@property
def id(self) -> ShuffleId:
return self.spec.id
@dataclass(frozen=True)
class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
id: ShuffleId
disk: bool
@property
@abc.abstractmethod
def output_partitions(self) -> Generator[_T_partition_id]:
"""Output partitions"""
@abc.abstractmethod
def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
"""Pick a worker for a partition"""
def create_new_run(
self,
worker_for: dict[_T_partition_id, str],
span_id: str | None,
) -> SchedulerShuffleState:
return SchedulerShuffleState(
run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
participating_workers=set(worker_for.values()),
)
@abc.abstractmethod
def create_run_on_worker(
self,
run_id: int,
span_id: str | None,
worker_for: dict[_T_partition_id, str],
plugin: ShuffleWorkerPlugin,
) -> ShuffleRun:
"""Create the new shuffle run on the worker."""
@dataclass(eq=False)
class SchedulerShuffleState(Generic[_T_partition_id]):
run_spec: ShuffleRunSpec
participating_workers: set[str]
_archived_by: str | None = field(default=None, init=False)
_failed: bool = False
@property
def id(self) -> ShuffleId:
return self.run_spec.id
@property
def run_id(self) -> int:
return self.run_spec.run_id
@property
def archived(self) -> bool:
return self._archived_by is not None
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
def __hash__(self) -> int:
return hash(self.run_id)
@contextlib.contextmanager
def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
try:
yield
except ShuffleClosedError:
raise Reschedule()
except P2PConsistencyError:
raise
except P2POutOfDiskError:
raise
except Exception as e:
> raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E RuntimeError: P2P 103aea8fc8e7486296d569fc1f0805a8 failed during transfer phase
distributed\shuffle\_core.py:531: RuntimeError
Check warning on line 0 in distributed.shuffle.tests.test_rechunk
github-actions / Unit Test Results
10 out of 11 runs failed: test_rechunk_with_fully_unknown_dimension[x4-chunks4] (distributed.shuffle.tests.test_rechunk)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 103aea8fc8e7486296d569fc1f0805a8 failed during transfer phase
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
assert isinstance(barrier_task_spec, P2PBarrierTask)
return barrier_task_spec.spec
def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
# FIXME: The current implementation relies on the barrier task to be
# known by its name. If the name has been mangled, we cannot guarantee
# that the shuffle works as intended and should fail instead.
self._raise_if_barrier_unknown(shuffle_id)
self._raise_if_task_not_processing(key)
spec = self._retrieve_spec(shuffle_id)
worker_for = self._calculate_worker_for(spec)
self._ensure_output_tasks_are_non_rootish(spec)
state = spec.create_new_run(
worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
)
self.active_shuffles[shuffle_id] = state
self._shuffles[shuffle_id].add(state)
state.participating_workers.add(worker)
logger.warning(
"Shuffle %s initialized by task %r executed on worker %s",
shuffle_id,
key,
worker,
)
return state.run_spec
def get_or_create(
self,
shuffle_id: ShuffleId,
key: Key,
worker: str,
) -> RunSpecMessage | ErrorMessage:
try:
> run_spec = self._get(shuffle_id, worker)
distributed\shuffle\_scheduler_plugin.py:229:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
> state = self.active_shuffles[id]
E KeyError: '103aea8fc8e7486296d569fc1f0805a8'
distributed\shuffle\_scheduler_plugin.py:190: KeyError
During handling of the above exception, another exception occurred:
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
    """Forward context_meter metrics to ``digest_metric`` while the context is active.

    Metrics are recorded as
    {('p2p', <span id>, 'foreground|background...', label, unit): value}

    **Note 1:** When the metric is not logged by a background task
    (where='foreground'), this produces a duplicated metric under
    {('execute', <span id>, <task prefix>, label, unit): value}
    This is by design so that one can have a holistic view of the whole shuffle
    process.

    **Note 2:** We're immediately writing to Worker.digests.
    We don't temporarily store metrics under ShuffleRun as we would lose those
    recorded between the heartbeat and when the ShuffleRun object is deleted at the
    end of a run.
    """
    prefix = "p2p-"

    def callback(label: Hashable, value: float, unit: str) -> None:
        # Normalise the label to a tuple so it can be spliced into the metric name.
        parts = label if isinstance(label, tuple) else (label,)
        head = parts[0]
        if isinstance(head, str) and head.startswith(prefix):
            # The metric name already starts with "p2p"; drop the redundant prefix.
            parts = (head[len(prefix) :], *parts[1:])
        self.digest_metric(("p2p", self.span_id, where, *parts, unit), value)

    offloadable = "background" in where
    with context_meter.add_callback(callback, allow_offload=offloadable):
        yield
async def barrier(self, run_ids: Sequence[int]) -> int:
    """Report the barrier to the scheduler and return this run's id.

    ``consistent`` is True only when every entry of ``run_ids`` equals
    ``self.run_id``. Raises whatever ``raise_if_closed`` raises.
    """
    self.raise_if_closed()
    consistent = True
    for run_id in run_ids:
        if run_id != self.run_id:
            consistent = False
            break
    # TODO: Consider broadcast pinging once when the shuffle starts to warm
    # up the comm pool on scheduler side
    await self.scheduler.shuffle_barrier(
        id=self.id, run_id=self.run_id, consistent=consistent
    )
    return self.run_id
async def _send(
    self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
    """Call ``shuffle_receive`` on the peer at ``address`` and return its reply.

    ``shards`` is wrapped with ``to_serialize`` and sent with this run's
    shuffle id and run id. Raises whatever ``raise_if_closed`` raises.
    """
    self.raise_if_closed()
    remote_receive = self.rpc(address).shuffle_receive
    payload = to_serialize(shards)
    return await remote_receive(
        data=payload,
        shuffle_id=self.id,
        run_id=self.run_id,
    )
async def send(
    self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
    """Send ``shards`` to ``address`` through ``retry`` using the configured limits.

    When the mean shard size is below 65536 the shards are pickled into a
    single bytes blob first; otherwise they are passed through unchanged.
    """
    shards_or_bytes: list | bytes
    if _mean_shard_size(shards) >= 65536:
        shards_or_bytes = shards
    else:
        # Don't send buffers individually over the tcp comms.
        # Instead, merge everything into an opaque bytes blob, send it all at once,
        # and unpickle it on the other side.
        # Performance tests informing the size threshold:
        # https://github.com/dask/distributed/pull/8318
        shards_or_bytes = pickle.dumps(shards)

    def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
        return self._send(address, shards_or_bytes)

    retry_options = {
        "count": self.RETRY_COUNT,
        "delay_min": self.RETRY_DELAY_MIN,
        "delay_max": self.RETRY_DELAY_MAX,
    }
    return await retry(_send, **retry_options)
async def offload(
    self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
    """Run ``func(*args, **kwargs)`` on ``self.executor`` under the "offload" meter.

    Raises whatever ``raise_if_closed`` raises.
    """
    self.raise_if_closed()
    with context_meter.meter("offload"):
        outcome = await run_in_executor_with_context(
            self.executor, func, *args, **kwargs
        )
    return outcome
def heartbeat(self) -> dict[str, Any]:
    """Return the disk and comm buffer heartbeats plus the run's start time.

    The comm heartbeat gets an extra ``"read"`` entry set to ``total_recvd``.
    """
    comm_stats = self._comm_buffer.heartbeat()
    comm_stats["read"] = self.total_recvd
    disk_stats = self._disk_buffer.heartbeat()
    return {"disk": disk_stats, "comm": comm_stats, "start": self.start_time}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
    """Raise if the run is closed; do nothing otherwise.

    Re-raises the stored exception when one is set, else ``ShuffleClosedError``.
    """
    if not self.closed:
        return
    if self._exception:
        raise self._exception
    raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
    """Mark the run as transferred, flush the comm buffer and surface its errors.

    An exception raised by the comm buffer is stored on ``_exception`` and
    re-raised.
    """
    self.raise_if_closed()
    self.transferred = True
    await self._flush_comm()
    comm_buffer = self._comm_buffer
    try:
        comm_buffer.raise_on_exception()
    except Exception as exc:
        self._exception = exc
        raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
    """Flush the disk buffer after checking the run is still open."""
    self.raise_if_closed()
    disk_buffer = self._disk_buffer
    await disk_buffer.flush()
async def close(self) -> None:
    """Close both buffers and signal completion through ``_closed_event``.

    A call made while the run is already marked closed waits for the first
    call to finish and then returns.

    The disk buffer is closed even if closing the comm buffer raises, and
    ``_closed_event`` is always set, so that concurrent callers waiting on
    the event are never left hanging by a failed close. Any exception from
    the buffers still propagates to the caller.
    """
    if self.closed:  # pragma: no cover
        await self._closed_event.wait()
        return
    self.closed = True
    try:
        try:
            await self._comm_buffer.close()
        finally:
            # Always release the disk buffer, even after a comm failure.
            await self._disk_buffer.close()
    finally:
        # Wake up any concurrent close() callers regardless of the outcome.
        self._closed_event.set()
def fail(self, exception: Exception) -> None:
    """Store ``exception`` on the run unless the run is already closed."""
    if self.closed:
        return
    self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
    self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
    """Hand incoming shards to ``_receive`` and return a status message.

    Returns ``{"status": "OK"}`` on success; a ``P2PConsistencyError`` is
    converted into an error message instead of being raised.
    """
    try:
        if isinstance(data, bytes):
            # Unpack opaque blob. See send()
            # NOTE(review): this unpickles bytes received from a peer —
            # presumably only trusted workers can reach this; confirm.
            data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
        await self._receive(data)
    except P2PConsistencyError as exc:
        return error_message(exc)
    return {"status": "OK"}
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and …: in _refresh
result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed\shuffle\_worker_plugin.py:190: in _fetch
response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed\core.py:1259: in send_recv_from_rpc
return await send_recv(comm=comm, op=key, **kwargs)
distributed\core.py:1043: in send_recv
raise exc.with_traceback(tb)
distributed\core.py:832: in _handle_comm
result = handler(**msg)
distributed\shuffle\_scheduler_plugin.py:234: in get_or_create
run_spec = self._create(shuffle_id, key, worker)
distributed\shuffle\_scheduler_plugin.py:205: in _create
spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
    """Register a pickled ``ShuffleWorkerPlugin`` under the name "shuffle"."""
    serialized_plugin = dumps(ShuffleWorkerPlugin())
    await self.scheduler.register_worker_plugin(
        None, serialized_plugin, name="shuffle", idempotent=False
    )
def shuffle_ids(self) -> set[ShuffleId]:
    """Return a new set holding the ids of all active shuffles."""
    active = self.active_shuffles
    return set(active)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
    """Broadcast ``shuffle_inputs_done`` to all participating workers.

    Parameters
    ----------
    id : shuffle to act on; must be a key of ``active_shuffles``.
    run_id : run the caller belongs to; must match the active run.
    consistent : when False, the shuffle is restarted instead of broadcasting.

    Raises
    ------
    ValueError
        If ``run_id`` does not match the active shuffle.
    P2PIllegalStateError
        If a failing worker is unknown to the scheduler and the shuffle is
        not archived.
    RuntimeError
        On a non-``OSError`` broadcast result, when a worker left, or after
        three retries without progress.
    """
    shuffle = self.active_shuffles[id]
    if shuffle.run_id != run_id:
        raise ValueError(f"{run_id=} does not match {shuffle}")
    if not consistent:
        logger.warning(
            "Shuffle %s restarted due to data inconsistency during barrier",
            shuffle.id,
        )
        return self._restart_shuffle(
            shuffle.id,
            self.scheduler,
            stimulus_id=f"p2p-barrier-inconsistent-{time()}",
        )
    msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
    workers = list(shuffle.participating_workers)
    no_progress = 0
    while workers:
        res = await self.scheduler.broadcast(
            msg=msg,
            workers=workers,
            on_error="return",
        )
        before = len(workers)
        # Keep only the workers whose broadcast failed with an OSError; these
        # are retried. Any other non-None result is fatal. (The original
        # rebuilt this same list a second time right after the loop.)
        workers = []
        for w, r in res.items():
            if r is None:
                continue
            if isinstance(r, OSError):
                workers.append(w)
            else:
                raise RuntimeError(
                    f"Unexpected error encountered during P2P barrier: {r!r}"
                )
        if workers:
            logger.warning(
                "Failure during broadcast of %s, retrying.",
                shuffle.id,
            )
            if any(w not in self.scheduler.workers for w in workers):
                if not shuffle.archived:
                    # If the shuffle is not yet archived, this could mean that the barrier task fails
                    # before the P2P restarting mechanism can kick in.
                    raise P2PIllegalStateError(
                        "Expected shuffle to be archived if participating worker is not known by scheduler"
                    )
                raise RuntimeError(
                    f"Worker {workers} left during shuffle {shuffle}"
                )
            await asyncio.sleep(0.1)
            if len(workers) == before:
                no_progress += 1
                if no_progress >= 3:
                    raise RuntimeError(
                        f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
                    )
def restrict_task(
    self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
    """Restrict task ``key`` to ``worker`` if ``run_id`` matches the active run.

    Returns ``{"status": "OK"}`` on success. A run id mismatch in either
    direction is returned as an error message rather than raised.
    """
    try:
        shuffle = self.active_shuffles[id]
        active_run_id = shuffle.run_id
        if active_run_id > run_id:
            raise P2PConsistencyError(
                f"Request stale, expected {run_id=} for {shuffle}"
            )
        if active_run_id < run_id:
            raise P2PConsistencyError(
                f"Request invalid, expected {run_id=} for {shuffle}"
            )
        task_state = self.scheduler.tasks[key]
        self._set_restriction(task_state, worker)
        return {"status": "OK"}
    except P2PConsistencyError as exc:
        return error_message(exc)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
    """Merge a worker's per-shuffle heartbeat payload into ``self.heartbeats``.

    Entries for shuffle ids that are not currently active are ignored.
    """
    # Build the set of active ids once; the original rebuilt it for every
    # entry in ``data``.
    active_ids = self.shuffle_ids()
    for shuffle_id, d in data.items():
        if shuffle_id in active_ids:
            self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
    """Return the active run spec for ``id`` wrapped in an OK message.

    A missing shuffle (``KeyError``) or any ``P2PConsistencyError`` is
    returned as an error message rather than raised.
    """
    try:
        try:
            spec = self._get(id, worker)
            return {"status": "OK", "run_spec": ToPickle(spec)}
        except KeyError as missing:
            raise P2PConsistencyError(
                f"No active shuffle with {id=!r} found"
            ) from missing
    except P2PConsistencyError as exc:
        return error_message(exc)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
> assert isinstance(barrier_task_spec, P2PBarrierTask)
E AssertionError
distributed\shuffle\_scheduler_plugin.py:196: AssertionError
The above exception was the direct cause of the following exception:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:64662', workers: 0, cores: 0, tasks: 0>
x = dask.array<ones_like, shape=(100, 10), dtype=float64, chunksize=(5, 10), chunktype=numpy.ndarray>
chunks = {1: 5}
ws = (<Worker 'tcp://127.0.0.1:64663', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:64666', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
dd = <module 'dask.dataframe' from 'C:\\Users\\runneradmin\\miniconda3\\envs\\dask-distributed\\Lib\\site-packages\\dask\\dataframe\\__init__.py'>
@pytest.mark.parametrize(
"x, chunks",
[
(da.ones(shape=(50, 10), chunks=(25, 10)), (None, 5)),
(da.ones(shape=(50, 10), chunks=(25, 10)), {1: 5}),
(da.ones(shape=(50, 10), chunks=(25, 10)), (None, (5, 5))),
(da.ones(shape=(100, 10), chunks=(5, 10)), (None, 5)),
(da.ones(shape=(100, 10), chunks=(5, 10)), {1: 5}),
(da.ones(shape=(100, 10), chunks=(5, 10)), (None, (5, 5))),
(da.ones(shape=(10, 10), chunks=(10, 10)), (None, 5)),
(da.ones(shape=(10, 10), chunks=(10, 10)), {1: 5}),
(da.ones(shape=(10, 10), chunks=(10, 10)), (None, (5, 5))),
(da.ones(shape=(10, 10), chunks=(10, 2)), (None, 5)),
(da.ones(shape=(10, 10), chunks=(10, 2)), {1: 5}),
(da.ones(shape=(10, 10), chunks=(10, 2)), (None, (5, 5))),
],
)
@gen_cluster(client=True)
async def test_rechunk_with_fully_unknown_dimension(c, s, *ws, x, chunks):
"""
See Also
--------
dask.array.tests.test_rechunk.test_rechunk_with_fully_unknown_dimension
"""
pytest.importorskip("pandas")
dd = pytest.importorskip("dask.dataframe")
y = dd.from_array(x).values
result = y.rechunk(chunks, method="p2p")
expected = x.rechunk(chunks, method="p2p")
assert_chunks_match(result.chunks, expected.chunks)
> assert_eq(await c.compute(result), await c.compute(expected))
distributed\shuffle\tests\test_rechunk.py:757:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed\client.py:410: in _result
raise exc.with_traceback(tb)
distributed\shuffle\_rechunk.py:169: in rechunk_transfer
with handle_transfer_errors(id):
C:\Users\runneradmin\miniconda3\envs\dask-distributed\Lib\contextlib.py:158: in __exit__
self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
    """OK message that additionally carries a ``run_spec`` field."""

    # Either the run spec itself or the run spec wrapped in ``ToPickle``.
    run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
    """State and operations of a single P2P shuffle run.

    An instance owns a disk (or in-memory) shard buffer and a comm shard
    buffer. Subclasses supply the abstract hooks: ``_get_assigned_worker``,
    ``_receive``, ``_shard_partition``, ``_get_output_partition``, ``read``
    and ``deserialize``.
    """

    id: ShuffleId
    run_id: int
    span_id: str | None
    local_address: str
    executor: ThreadPoolExecutor
    rpc: Callable[[str], PooledRPCCall]
    digest_metric: Callable[[Hashable, float], None]
    scheduler: PooledRPCCall
    closed: bool
    _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
    _comm_buffer: CommShardsBuffer
    received: set[_T_partition_id]
    total_recvd: int
    start_time: float
    _exception: Exception | None
    _closed_event: asyncio.Event
    _loop: IOLoop

    RETRY_COUNT: int
    RETRY_DELAY_MIN: float
    RETRY_DELAY_MAX: float

    def __init__(
        self,
        id: ShuffleId,
        run_id: int,
        span_id: str | None,
        local_address: str,
        directory: str,
        executor: ThreadPoolExecutor,
        rpc: Callable[[str], PooledRPCCall],
        digest_metric: Callable[[Hashable, float], None],
        scheduler: PooledRPCCall,
        memory_limiter_disk: ResourceLimiter,
        memory_limiter_comms: ResourceLimiter,
        disk: bool,
        loop: IOLoop,
    ):
        """Store the run parameters and create the disk and comm buffers.

        ``disk`` selects a ``DiskShardsBuffer`` (rooted at ``directory``)
        over a ``MemoryShardsBuffer``. Retry settings are read from the
        ``distributed.p2p.comm.retry`` configuration.
        """
        self.id = id
        self.run_id = run_id
        self.span_id = span_id
        self.local_address = local_address
        self.executor = executor
        self.rpc = rpc
        self.digest_metric = digest_metric
        self.scheduler = scheduler
        self.closed = False

        # Initialize buffers and start background tasks
        # Don't log metrics issued by the background tasks onto the dask task that
        # spawned this object
        with context_meter.clear_callbacks():
            with self._capture_metrics("background-disk"):
                if disk:
                    self._disk_buffer = DiskShardsBuffer(
                        directory=directory,
                        read=self.read,
                        memory_limiter=memory_limiter_disk,
                    )
                else:
                    self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)

            with self._capture_metrics("background-comms"):
                max_message_size = parse_bytes(
                    dask.config.get("distributed.p2p.comm.message-bytes-limit")
                )
                concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                self._comm_buffer = CommShardsBuffer(
                    send=self.send,
                    max_message_size=max_message_size,
                    memory_limiter=memory_limiter_comms,
                    concurrency_limit=concurrency_limit,
                )

        # TODO: reduce number of connections to number of workers
        # MultiComm.max_connections = min(10, n_workers)

        self.transferred = False
        self.received = set()
        self.total_recvd = 0
        self.start_time = time.time()
        self._exception = None
        self._closed_event = asyncio.Event()
        self._loop = loop

        self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
        self.RETRY_DELAY_MIN = parse_timedelta(
            dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
        )
        self.RETRY_DELAY_MAX = parse_timedelta(
            dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
        )

    def __repr__(self) -> str:
        """Return a debugging representation with id, run id, address and state flags."""
        return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"

    def __str__(self) -> str:
        """Return a short label: class name, shuffle id, run id and local address."""
        return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"

    def __hash__(self) -> int:
        """Use the run id itself as the hash value."""
        return self.run_id

    @contextlib.contextmanager
    def _capture_metrics(self, where: str) -> Iterator[None]:
        """Capture context_meter metrics as

            {('p2p', <span id>, 'foreground|background...', label, unit): value}

        **Note 1:** When the metric is not logged by a background task
        (where='foreground'), this produces a duplicated metric under

            {('execute', <span id>, <task prefix>, label, unit): value}

        This is by design so that one can have a holistic view of the whole shuffle
        process.

        **Note 2:** We're immediately writing to Worker.digests.
        We don't temporarily store metrics under ShuffleRun as we would lose those
        recorded between the heartbeat and when the ShuffleRun object is deleted at the
        end of a run.
        """

        def callback(label: Hashable, value: float, unit: str) -> None:
            # Normalise the label to a tuple so it can be spliced into the name.
            if not isinstance(label, tuple):
                label = (label,)
            # The metric name already starts with "p2p"; drop the redundant prefix.
            if isinstance(label[0], str) and label[0].startswith("p2p-"):
                label = (label[0][len("p2p-") :], *label[1:])
            name = ("p2p", self.span_id, where, *label, unit)

            self.digest_metric(name, value)

        with context_meter.add_callback(callback, allow_offload="background" in where):
            yield

    async def barrier(self, run_ids: Sequence[int]) -> int:
        """Report the barrier to the scheduler and return this run's id.

        ``consistent`` is True only when every entry of ``run_ids`` equals
        ``self.run_id``.
        """
        self.raise_if_closed()
        consistent = all(run_id == self.run_id for run_id in run_ids)
        # TODO: Consider broadcast pinging once when the shuffle starts to warm
        # up the comm pool on scheduler side
        await self.scheduler.shuffle_barrier(
            id=self.id, run_id=self.run_id, consistent=consistent
        )
        return self.run_id

    async def _send(
        self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
    ) -> OKMessage | ErrorMessage:
        """Call ``shuffle_receive`` on the peer at ``address`` and return its reply."""
        self.raise_if_closed()
        return await self.rpc(address).shuffle_receive(
            data=to_serialize(shards),
            shuffle_id=self.id,
            run_id=self.run_id,
        )

    async def send(
        self, address: str, shards: list[tuple[_T_partition_id, Any]]
    ) -> OKMessage | ErrorMessage:
        """Send ``shards`` to ``address`` through ``retry`` using the configured limits.

        When the mean shard size is below 65536 the shards are pickled into a
        single bytes blob first; otherwise they are passed through unchanged.
        """
        if _mean_shard_size(shards) < 65536:
            # Don't send buffers individually over the tcp comms.
            # Instead, merge everything into an opaque bytes blob, send it all at once,
            # and unpickle it on the other side.
            # Performance tests informing the size threshold:
            # https://github.com/dask/distributed/pull/8318
            shards_or_bytes: list | bytes = pickle.dumps(shards)
        else:
            shards_or_bytes = shards

        def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
            return self._send(address, shards_or_bytes)

        return await retry(
            _send,
            count=self.RETRY_COUNT,
            delay_min=self.RETRY_DELAY_MIN,
            delay_max=self.RETRY_DELAY_MAX,
        )

    async def offload(
        self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
    ) -> _T:
        """Run ``func(*args, **kwargs)`` on ``self.executor`` under the "offload" meter."""
        self.raise_if_closed()
        with context_meter.meter("offload"):
            return await run_in_executor_with_context(
                self.executor, func, *args, **kwargs
            )

    def heartbeat(self) -> dict[str, Any]:
        """Return the disk and comm buffer heartbeats plus the run's start time.

        The comm heartbeat gets an extra ``"read"`` entry set to ``total_recvd``.
        """
        comm_heartbeat = self._comm_buffer.heartbeat()
        comm_heartbeat["read"] = self.total_recvd
        return {
            "disk": self._disk_buffer.heartbeat(),
            "comm": comm_heartbeat,
            "start": self.start_time,
        }

    async def _write_to_comm(
        self, data: dict[str, tuple[_T_partition_id, Any]]
    ) -> None:
        """Write ``data`` to the comm buffer after checking the run is still open."""
        self.raise_if_closed()
        await self._comm_buffer.write(data)

    async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
        """Write ``data`` to the disk buffer, keyed by the index joined with ``"_"``."""
        self.raise_if_closed()
        await self._disk_buffer.write(
            {"_".join(str(i) for i in k): v for k, v in data.items()}
        )

    def raise_if_closed(self) -> None:
        """Raise if the run is closed; do nothing otherwise.

        Re-raises the stored exception when one is set, else
        ``ShuffleClosedError``.
        """
        if self.closed:
            if self._exception:
                raise self._exception
            raise ShuffleClosedError(f"{self} has already been closed")

    async def inputs_done(self) -> None:
        """Mark the run as transferred, flush the comm buffer and surface its errors.

        An exception raised by the comm buffer is stored on ``_exception``
        and re-raised.
        """
        self.raise_if_closed()
        self.transferred = True
        await self._flush_comm()
        try:
            self._comm_buffer.raise_on_exception()
        except Exception as e:
            self._exception = e
            raise

    async def _flush_comm(self) -> None:
        """Flush the comm buffer after checking the run is still open."""
        self.raise_if_closed()
        await self._comm_buffer.flush()

    async def flush_receive(self) -> None:
        """Flush the disk buffer after checking the run is still open."""
        self.raise_if_closed()
        await self._disk_buffer.flush()

    async def close(self) -> None:
        """Close both buffers and signal completion through ``_closed_event``.

        A call made while the run is already marked closed waits for the
        first call to finish and then returns.

        The disk buffer is closed even if closing the comm buffer raises, and
        ``_closed_event`` is always set, so that concurrent callers waiting on
        the event are never left hanging by a failed close. Any exception
        from the buffers still propagates to the caller.
        """
        if self.closed:  # pragma: no cover
            await self._closed_event.wait()
            return

        self.closed = True
        try:
            try:
                await self._comm_buffer.close()
            finally:
                # Always release the disk buffer, even after a comm failure.
                await self._disk_buffer.close()
        finally:
            # Wake up any concurrent close() callers regardless of the outcome.
            self._closed_event.set()

    def fail(self, exception: Exception) -> None:
        """Store ``exception`` on the run unless the run is already closed."""
        if not self.closed:
            self._exception = exception

    def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
        """Read from the disk buffer the entry keyed by ``id`` joined with ``"_"``."""
        self.raise_if_closed()
        return self._disk_buffer.read("_".join(str(i) for i in id))

    async def receive(
        self, data: list[tuple[_T_partition_id, Any]] | bytes
    ) -> OKMessage | ErrorMessage:
        """Hand incoming shards to ``_receive`` and return a status message.

        Returns ``{"status": "OK"}`` on success; a ``P2PConsistencyError`` is
        converted into an error message instead of being raised.
        """
        try:
            if isinstance(data, bytes):
                # Unpack opaque blob. See send()
                # NOTE(review): this unpickles bytes received from a peer —
                # presumably only trusted workers can reach this; confirm.
                data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
            await self._receive(data)
            return {"status": "OK"}
        except P2PConsistencyError as e:
            return error_message(e)

    async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
        """Return if partition ``i`` is assigned to this worker; otherwise reschedule.

        When another worker is assigned, the scheduler is asked to restrict
        ``key`` to that worker. An error reply raises ``RuntimeError``; an OK
        reply raises ``Reschedule``.
        """
        assigned_worker = self._get_assigned_worker(i)

        if assigned_worker != self.local_address:
            result = await self.scheduler.shuffle_restrict_task(
                id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
            )
            if result["status"] == "error":
                raise RuntimeError(result["message"])
            assert result["status"] == "OK"
            raise Reschedule()

    @abc.abstractmethod
    def _get_assigned_worker(self, i: _T_partition_id) -> str:
        """Get the address of the worker assigned to the output partition"""

    @abc.abstractmethod
    async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
        """Receive shards belonging to output partitions of this shuffle run"""

    def add_partition(
        self, data: _T_partition_type, partition_id: _T_partition_id
    ) -> int:
        """Shard an input partition, write the shards to the comm buffer, return the run id.

        Raises ``RuntimeError`` if the run has already been marked as
        transferred.
        """
        self.raise_if_closed()
        if self.transferred:
            raise RuntimeError(f"Cannot add more partitions to {self}")
        # Log metrics both in the "execute" and in the "p2p" contexts
        self.validate_data(data)
        with self._capture_metrics("foreground"):
            with (
                context_meter.meter("p2p-shard-partition-noncpu"),
                context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
            ):
                shards = self._shard_partition(data, partition_id)
            sync(self._loop, self._write_to_comm, shards)
        return self.run_id

    @abc.abstractmethod
    def _shard_partition(
        self, data: _T_partition_type, partition_id: _T_partition_id
    ) -> dict[str, tuple[_T_partition_id, Any]]:
        """Shard an input partition by the assigned output workers"""

    def get_output_partition(
        self, partition_id: _T_partition_id, key: Key, **kwargs: Any
    ) -> _T_partition_type:
        """Return the output partition ``partition_id`` for task ``key``.

        First ensures the task runs on the assigned worker, then flushes the
        disk buffer. Raises ``RuntimeError`` if the run has not been marked
        as transferred yet.
        """
        self.raise_if_closed()
        sync(self._loop, self._ensure_output_worker, partition_id, key)
        if not self.transferred:
            raise RuntimeError("`get_output_partition` called before barrier task")
        sync(self._loop, self.flush_receive)
        with (
            # Log metrics both in the "execute" and in the "p2p" contexts
            self._capture_metrics("foreground"),
            context_meter.meter("p2p-get-output-noncpu"),
            context_meter.meter("p2p-get-output-cpu", func=thread_time),
        ):
            return self._get_output_partition(partition_id, key, **kwargs)

    @abc.abstractmethod
    def _get_output_partition(
        self, partition_id: _T_partition_id, key: Key, **kwargs: Any
    ) -> _T_partition_type:
        """Get an output partition to the shuffle run"""

    @abc.abstractmethod
    def read(self, path: Path) -> tuple[Any, int]:
        """Read shards from disk"""

    @abc.abstractmethod
    def deserialize(self, buffer: Any) -> Any:
        """Deserialize shards"""

    def validate_data(self, data: Any) -> None:
        """Validate payload data before shuffling"""
def get_worker_plugin() -> ShuffleWorkerPlugin:
    """Return the "shuffle" plugin of the worker running the current task.

    Raises ``RuntimeError`` when ``get_worker`` raises ``ValueError`` or when
    the worker has no plugin registered under "shuffle".
    """
    from distributed import get_worker

    try:
        worker = get_worker()
    except ValueError as no_worker:
        message = (
            "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
            "please confirm that you've created a distributed Client and are submitting this computation through it."
        )
        raise RuntimeError(message) from no_worker
    try:
        plugin = worker.plugins["shuffle"]
    except KeyError as no_plugin:
        raise RuntimeError(
            f"The worker {worker.address} does not have a P2P shuffle plugin."
        ) from no_plugin
    return plugin  # type: ignore
# Prefix shared by the keys of all shuffle barrier tasks.
_BARRIER_PREFIX = "shuffle-barrier-"


def barrier_key(shuffle_id: ShuffleId) -> str:
    """Return the barrier task key for ``shuffle_id`` (prefix + id)."""
    return "".join((_BARRIER_PREFIX, shuffle_id))
def id_from_key(key: Key) -> ShuffleId | None:
    """Return the shuffle id encoded in a barrier key, or None for any other key."""
    if isinstance(key, str) and key.startswith(_BARRIER_PREFIX):
        return ShuffleId(key[len(_BARRIER_PREFIX) :])
    return None
class ShuffleType(Enum):
    """Kinds of P2P shuffle, identified by a string value."""

    DATAFRAME = "DataFrameShuffle"
    ARRAY_RECHUNK = "ArrayRechunk"
@dataclass(frozen=True)
class ShuffleRunSpec(Generic[_T_partition_id]):
    """Immutable description of one run of a shuffle."""

    # Not an __init__ argument: every new instance takes the next value of a
    # shared counter that starts at 1.
    run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
    spec: ShuffleSpec
    worker_for: dict[_T_partition_id, str]
    span_id: str | None

    @property
    def id(self) -> ShuffleId:
        """Shuffle id, taken from the underlying spec."""
        return self.spec.id
@dataclass(frozen=True)
class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
    """Abstract, immutable specification of a shuffle.

    Subclasses provide ``output_partitions``, ``pick_worker`` and
    ``create_run_on_worker``.
    """

    id: ShuffleId
    disk: bool

    @property
    @abc.abstractmethod
    def output_partitions(self) -> Generator[_T_partition_id]:
        """Output partitions"""

    @abc.abstractmethod
    def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
        """Pick a worker for a partition"""

    def create_new_run(
        self,
        worker_for: dict[_T_partition_id, str],
        span_id: str | None,
    ) -> SchedulerShuffleState:
        """Build the scheduler-side state for a new run of this spec.

        The participating workers start out as the values of ``worker_for``.
        """
        new_run_spec = ShuffleRunSpec(
            spec=self, worker_for=worker_for, span_id=span_id
        )
        return SchedulerShuffleState(
            run_spec=new_run_spec,
            participating_workers=set(worker_for.values()),
        )

    @abc.abstractmethod
    def create_run_on_worker(
        self,
        run_id: int,
        span_id: str | None,
        worker_for: dict[_T_partition_id, str],
        plugin: ShuffleWorkerPlugin,
    ) -> ShuffleRun:
        """Create the new shuffle run on the worker."""
@dataclass(eq=False)
class SchedulerShuffleState(Generic[_T_partition_id]):
    """Mutable scheduler-side record of a shuffle run and its participants."""

    run_spec: ShuffleRunSpec
    participating_workers: set[str]
    # Not an __init__ argument; ``archived`` is True once this is not None.
    _archived_by: str | None = field(default=None, init=False)
    _failed: bool = False

    @property
    def id(self) -> ShuffleId:
        """Shuffle id, taken from the run spec."""
        return self.run_spec.id

    @property
    def run_id(self) -> int:
        """Run id, taken from the run spec."""
        return self.run_spec.run_id

    @property
    def archived(self) -> bool:
        """Whether ``_archived_by`` has been set."""
        return self._archived_by is not None

    def __str__(self) -> str:
        """Return ``ClassName<id[run_id]>``."""
        kind = self.__class__.__name__
        return f"{kind}<{self.id}[{self.run_id}]>"

    def __hash__(self) -> int:
        """Hash by run id."""
        return hash(self.run_id)
@contextlib.contextmanager
def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
try:
yield
except ShuffleClosedError:
raise Reschedule()
except P2PConsistencyError:
raise
except P2POutOfDiskError:
raise
except Exception as e:
> raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E RuntimeError: P2P 103aea8fc8e7486296d569fc1f0805a8 failed during transfer phase
distributed\shuffle\_core.py:531: RuntimeError
Check warning on line 0 in distributed.shuffle.tests.test_rechunk
github-actions / Unit Test Results
10 out of 11 runs failed: test_rechunk_with_fully_unknown_dimension[x5-chunks5] (distributed.shuffle.tests.test_rechunk)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P ba094fb6021231c810a48e86cbce8897 failed during transfer phase
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
assert isinstance(barrier_task_spec, P2PBarrierTask)
return barrier_task_spec.spec
def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
# FIXME: The current implementation relies on the barrier task to be
# known by its name. If the name has been mangled, we cannot guarantee
# that the shuffle works as intended and should fail instead.
self._raise_if_barrier_unknown(shuffle_id)
self._raise_if_task_not_processing(key)
spec = self._retrieve_spec(shuffle_id)
worker_for = self._calculate_worker_for(spec)
self._ensure_output_tasks_are_non_rootish(spec)
state = spec.create_new_run(
worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
)
self.active_shuffles[shuffle_id] = state
self._shuffles[shuffle_id].add(state)
state.participating_workers.add(worker)
logger.warning(
"Shuffle %s initialized by task %r executed on worker %s",
shuffle_id,
key,
worker,
)
return state.run_spec
def get_or_create(
self,
shuffle_id: ShuffleId,
key: Key,
worker: str,
) -> RunSpecMessage | ErrorMessage:
try:
> run_spec = self._get(shuffle_id, worker)
distributed\shuffle\_scheduler_plugin.py:229:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
> state = self.active_shuffles[id]
E KeyError: 'ba094fb6021231c810a48e86cbce8897'
distributed\shuffle\_scheduler_plugin.py:190: KeyError
During handling of the above exception, another exception occurred:
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and …fresh
result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed\shuffle\_worker_plugin.py:190: in _fetch
response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed\core.py:1259: in send_recv_from_rpc
return await send_recv(comm=comm, op=key, **kwargs)
distributed\core.py:1043: in send_recv
raise exc.with_traceback(tb)
distributed\core.py:832: in _handle_comm
result = handler(**msg)
distributed\shuffle\_scheduler_plugin.py:234: in get_or_create
run_spec = self._create(shuffle_id, key, worker)
distributed\shuffle\_scheduler_plugin.py:205: in _create
spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
> assert isinstance(barrier_task_spec, P2PBarrierTask)
E AssertionError
distributed\shuffle\_scheduler_plugin.py:196: AssertionError
The above exception was the direct cause of the following exception:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:64677', workers: 0, cores: 0, tasks: 0>
x = dask.array<ones_like, shape=(100, 10), dtype=float64, chunksize=(5, 10), chunktype=numpy.ndarray>
chunks = (None, (5, 5))
ws = (<Worker 'tcp://127.0.0.1:64678', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:64681', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
dd = <module 'dask.dataframe' from 'C:\\Users\\runneradmin\\miniconda3\\envs\\dask-distributed\\Lib\\site-packages\\dask\\dataframe\\__init__.py'>
@pytest.mark.parametrize(
"x, chunks",
[
(da.ones(shape=(50, 10), chunks=(25, 10)), (None, 5)),
(da.ones(shape=(50, 10), chunks=(25, 10)), {1: 5}),
(da.ones(shape=(50, 10), chunks=(25, 10)), (None, (5, 5))),
(da.ones(shape=(100, 10), chunks=(5, 10)), (None, 5)),
(da.ones(shape=(100, 10), chunks=(5, 10)), {1: 5}),
(da.ones(shape=(100, 10), chunks=(5, 10)), (None, (5, 5))),
(da.ones(shape=(10, 10), chunks=(10, 10)), (None, 5)),
(da.ones(shape=(10, 10), chunks=(10, 10)), {1: 5}),
(da.ones(shape=(10, 10), chunks=(10, 10)), (None, (5, 5))),
(da.ones(shape=(10, 10), chunks=(10, 2)), (None, 5)),
(da.ones(shape=(10, 10), chunks=(10, 2)), {1: 5}),
(da.ones(shape=(10, 10), chunks=(10, 2)), (None, (5, 5))),
],
)
@gen_cluster(client=True)
async def test_rechunk_with_fully_unknown_dimension(c, s, *ws, x, chunks):
"""
See Also
--------
dask.array.tests.test_rechunk.test_rechunk_with_fully_unknown_dimension
"""
pytest.importorskip("pandas")
dd = pytest.importorskip("dask.dataframe")
y = dd.from_array(x).values
result = y.rechunk(chunks, method="p2p")
expected = x.rechunk(chunks, method="p2p")
assert_chunks_match(result.chunks, expected.chunks)
> assert_eq(await c.compute(result), await c.compute(expected))
distributed\shuffle\tests\test_rechunk.py:757:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed\client.py:410: in _result
raise exc.with_traceback(tb)
distributed\shuffle\_rechunk.py:169: in rechunk_transfer
with handle_transfer_errors(id):
C:\Users\runneradmin\miniconda3\envs\dask-distributed\Lib\contextlib.py:158: in __exit__
self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
    """Forward context_meter metrics to ``self.digest_metric`` while active.

    Each metric is recorded under the key
        ('p2p', <span id>, where, *label, unit)
    A leading ``"p2p-"`` on the first label element is stripped first.

    **Note 1:** When the metric is not logged by a background task
    (where='foreground'), this produces a duplicated metric under
        {('execute', <span id>, <task prefix>, label, unit): value}
    This is by design, so that one can have a holistic view of the whole
    shuffle process.

    **Note 2:** Metrics are written to Worker.digests immediately. They are
    not stored temporarily on the ShuffleRun, as those recorded between the
    heartbeat and the deletion of the ShuffleRun object at the end of a run
    would be lost.
    """
    prefix = "p2p-"

    def callback(label: Hashable, value: float, unit: str) -> None:
        # Normalise the label to a tuple so it can be spliced into the key.
        parts = label if isinstance(label, tuple) else (label,)
        head = parts[0]
        if isinstance(head, str) and head.startswith(prefix):
            parts = (head[len(prefix) :], *parts[1:])
        self.digest_metric(("p2p", self.span_id, where, *parts, unit), value)

    offloadable = "background" in where
    with context_meter.add_callback(callback, allow_offload=offloadable):
        yield
async def barrier(self, run_ids: Sequence[int]) -> int:
    """Report to the scheduler whether all ``run_ids`` match this run.

    Raises via ``raise_if_closed`` if the run is closed, then calls the
    scheduler's ``shuffle_barrier`` with the consistency flag and returns
    this run's ``run_id``.
    """
    self.raise_if_closed()
    # Consistent only if no reported run id differs from ours.
    consistent = not any(run_id != self.run_id for run_id in run_ids)
    # TODO: Consider broadcast pinging once when the shuffle starts to warm
    # up the comm pool on scheduler side
    await self.scheduler.shuffle_barrier(
        id=self.id,
        run_id=self.run_id,
        consistent=consistent,
    )
    return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
    self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
    """Send ``shards`` to ``address`` through ``self._send``, with retries.

    Retry count and delays come from ``RETRY_COUNT``, ``RETRY_DELAY_MIN``
    and ``RETRY_DELAY_MAX``.
    """
    # Performance tests informing the size threshold:
    # https://github.com/dask/distributed/pull/8318
    is_small = _mean_shard_size(shards) < 65536
    # For small shards, don't send buffers individually over the tcp comms.
    # Instead, merge everything into an opaque bytes blob, send it all at once,
    # and unpickle it on the other side.
    shards_or_bytes: list | bytes = pickle.dumps(shards) if is_small else shards

    def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
        return self._send(address, shards_or_bytes)

    return await retry(
        _send,
        count=self.RETRY_COUNT,
        delay_min=self.RETRY_DELAY_MIN,
        delay_max=self.RETRY_DELAY_MAX,
    )
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
    """Return the heartbeat payloads of both buffers plus the start time.

    The comm buffer's heartbeat dict gets an extra ``"read"`` entry holding
    ``self.total_recvd``.
    """
    comm = self._comm_buffer.heartbeat()
    comm["read"] = self.total_recvd
    disk = self._disk_buffer.heartbeat()
    return {"disk": disk, "comm": comm, "start": self.start_time}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
    """Raise if this run is closed; otherwise do nothing.

    A stored ``_exception`` takes precedence over the generic
    ``ShuffleClosedError``.
    """
    if not self.closed:
        return
    if self._exception:
        raise self._exception
    raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
    """Mark the run as transferred and flush the comm buffer.

    An exception raised by the comm buffer's ``raise_on_exception`` is stored
    in ``self._exception`` and then re-raised.
    """
    self.raise_if_closed()
    self.transferred = True
    await self._flush_comm()
    comm_buffer = self._comm_buffer
    try:
        comm_buffer.raise_on_exception()
    except Exception as exc:
        self._exception = exc
        raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
    """Close the comm buffer, then the disk buffer, and signal completion.

    If the run is already marked closed, wait on ``_closed_event`` instead of
    closing the buffers again.
    """
    if not self.closed:
        self.closed = True
        await self._comm_buffer.close()
        await self._disk_buffer.close()
        self._closed_event.set()
    else:  # pragma: no cover
        await self._closed_event.wait()
def fail(self, exception: Exception) -> None:
    """Store ``exception`` on the run, unless the run is already closed."""
    if self.closed:
        return
    self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
    self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
    """Hand received shards to ``_receive`` and report the outcome.

    Returns ``{"status": "OK"}`` on success. A ``P2PConsistencyError`` is
    converted into an error message; other exceptions propagate.
    """
    try:
        shards = data
        if isinstance(shards, bytes):
            # Unpack opaque blob. See send()
            # NOTE(review): this unpickles bytes received from another
            # process; presumably only trusted peers send here - confirm.
            shards = cast(list[tuple[_T_partition_id, Any]], pickle.loads(shards))
        await self._receive(shards)
    except P2PConsistencyError as e:
        return error_message(e)
    return {"status": "OK"}
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and in the "p2p" contexts
self.validate_data(data)
with self._capture_metrics("foreground"):
with (
context_meter.meter("p2p-shard-partition-noncpu"),
context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
):
shards = self._shard_partition(data, partition_id)
sync(self._loop, self._write_to_comm, shards)
return self.run_id
@abc.abstractmethod
def _shard_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> dict[str, tuple[_T_partition_id, Any]]:
"""Shard an input partition by the assigned output workers"""
def get_output_partition(
    self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
    """Return the output partition ``partition_id`` for task ``key``.

    Steps, in order: check the run is not closed; run
    ``_ensure_output_worker`` on ``self._loop``; require that
    ``self.transferred`` is set; run ``flush_receive`` on ``self._loop``;
    delegate to ``_get_output_partition`` while metering the call.

    Raises
    ------
    RuntimeError
        If ``self.transferred`` is not set when this is called.
    """
    self.raise_if_closed()
    # May raise instead of returning (see ``_ensure_output_worker``).
    sync(self._loop, self._ensure_output_worker, partition_id, key)
    if not self.transferred:
        raise RuntimeError("`get_output_partition` called before barrier task")
    sync(self._loop, self.flush_receive)
    with (
        # Log metrics both in the "execute" and in the "p2p" contexts
        self._capture_metrics("foreground"),
        context_meter.meter("p2p-get-output-noncpu"),
        context_meter.meter("p2p-get-output-cpu", func=thread_time),
    ):
        return self._get_output_partition(partition_id, key, **kwargs)
@abc.abstractmethod
def _get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
"""Get an output partition to the shuffle run"""
@abc.abstractmethod
def read(self, path: Path) -> tuple[Any, int]:
"""Read shards from disk"""
@abc.abstractmethod
def deserialize(self, buffer: Any) -> Any:
"""Deserialize shards"""
def validate_data(self, data: Any) -> None:
"""Validate payload data before shuffling"""
def get_worker_plugin() -> ShuffleWorkerPlugin:
    """Return the ``"shuffle"`` plugin of the worker returned by ``get_worker``.

    Raises
    ------
    RuntimeError
        If ``get_worker()`` raises ``ValueError``, or if the worker has no
        plugin registered under the name ``"shuffle"``.
    """
    # Imported inside the function rather than at module level.
    from distributed import get_worker

    try:
        worker = get_worker()
    except ValueError as e:
        raise RuntimeError(
            "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
            "please confirm that you've created a distributed Client and are submitting this computation through it."
        ) from e
    try:
        return worker.plugins["shuffle"]  # type: ignore
    except KeyError as e:
        raise RuntimeError(
            f"The worker {worker.address} does not have a P2P shuffle plugin."
        ) from e
# Prefix shared by the keys of all shuffle barrier tasks.
_BARRIER_PREFIX = "shuffle-barrier-"


def barrier_key(shuffle_id: ShuffleId) -> str:
    """Return the barrier task key for ``shuffle_id``."""
    return _BARRIER_PREFIX + shuffle_id


def id_from_key(key: Key) -> ShuffleId | None:
    """Return the shuffle id encoded in a barrier key.

    Returns ``None`` when ``key`` is not a string or does not start with the
    barrier prefix.
    """
    if isinstance(key, str) and key.startswith(_BARRIER_PREFIX):
        return ShuffleId(key[len(_BARRIER_PREFIX) :])
    return None
class ShuffleType(Enum):
    """Enumeration of shuffle kinds; each member's value is a string label."""

    DATAFRAME = "DataFrameShuffle"
    ARRAY_RECHUNK = "ArrayRechunk"
@dataclass(frozen=True)
class ShuffleRunSpec(Generic[_T_partition_id]):
    """Frozen description of a single run of a shuffle.

    ``run_id`` is not accepted by ``__init__``; every new instance takes the
    next value from one counter that starts at 1.
    """

    # ``init=False``: filled in by the default factory, which advances a
    # single ``itertools.count(1)`` created when the class body is evaluated.
    run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
    spec: ShuffleSpec
    # Maps a partition id to a string; ``ShuffleSpec.create_new_run`` uses the
    # values of this mapping as the set of participating workers.
    worker_for: dict[_T_partition_id, str]
    span_id: str | None

    @property
    def id(self) -> ShuffleId:
        """The id of the wrapped ``spec``."""
        return self.spec.id
@dataclass(frozen=True)
class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
    """Abstract, frozen specification of a shuffle.

    Subclasses implement ``output_partitions``, ``pick_worker`` and
    ``create_run_on_worker``.
    """

    id: ShuffleId
    # Passed as ``disk=`` to the run; ``ShuffleRun.__init__`` selects a
    # ``DiskShardsBuffer`` when true and a ``MemoryShardsBuffer`` otherwise.
    disk: bool

    @property
    @abc.abstractmethod
    def output_partitions(self) -> Generator[_T_partition_id]:
        """Output partitions"""

    @abc.abstractmethod
    def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
        """Pick a worker for a partition"""

    def create_new_run(
        self,
        worker_for: dict[_T_partition_id, str],
        span_id: str | None,
    ) -> SchedulerShuffleState:
        """Create the scheduler-side state for a new run of this spec.

        The participating workers start out as the set of values of
        ``worker_for``.
        """
        return SchedulerShuffleState(
            run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
            participating_workers=set(worker_for.values()),
        )

    @abc.abstractmethod
    def create_run_on_worker(
        self,
        run_id: int,
        span_id: str | None,
        worker_for: dict[_T_partition_id, str],
        plugin: ShuffleWorkerPlugin,
    ) -> ShuffleRun:
        """Create the new shuffle run on the worker."""
@dataclass(eq=False)
class SchedulerShuffleState(Generic[_T_partition_id]):
    """Mutable scheduler-side state of one shuffle run.

    The dataclass is declared with ``eq=False`` and defines ``__hash__`` on
    the run id.
    """

    run_spec: ShuffleRunSpec
    participating_workers: set[str]
    # ``None`` until the state is archived; see the ``archived`` property.
    _archived_by: str | None = field(default=None, init=False)
    _failed: bool = False

    @property
    def id(self) -> ShuffleId:
        """Shuffle id, taken from ``run_spec``."""
        return self.run_spec.id

    @property
    def run_id(self) -> int:
        """Run id, taken from ``run_spec``."""
        return self.run_spec.run_id

    @property
    def archived(self) -> bool:
        """Whether ``_archived_by`` has been set."""
        return self._archived_by is not None

    def __str__(self) -> str:
        return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"

    def __hash__(self) -> int:
        return hash(self.run_id)
@contextlib.contextmanager
def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
try:
yield
except ShuffleClosedError:
raise Reschedule()
except P2PConsistencyError:
raise
except P2POutOfDiskError:
raise
except Exception as e:
> raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E RuntimeError: P2P ba094fb6021231c810a48e86cbce8897 failed during transfer phase
distributed\shuffle\_core.py:531: RuntimeError
Check warning on line 0 in distributed.shuffle.tests.test_rechunk
github-actions / Unit Test Results
10 out of 11 runs failed: test_rechunk_with_fully_unknown_dimension[x6-chunks6] (distributed.shuffle.tests.test_rechunk)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P b5a56a723b010b3ef762d5c14fe583f9 failed during transfer phase
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
    """Attach the plugin to ``scheduler``.

    Registers the four ``shuffle_*`` handlers on the scheduler, adds this
    object as the scheduler plugin named ``"shuffle"`` and creates the
    bookkeeping containers.
    """
    self.scheduler = scheduler
    self.scheduler.handlers.update(
        {
            "shuffle_barrier": self.barrier,
            "shuffle_get": self.get,
            "shuffle_get_or_create": self.get_or_create,
            "shuffle_restrict_task": self.restrict_task,
        }
    )
    # Nested mapping, indexed as ``heartbeats[shuffle_id][worker address]``
    # by ``heartbeat``.
    self.heartbeats = defaultdict(lambda: defaultdict(dict))
    self.active_shuffles = {}
    self.scheduler.add_plugin(self, name="shuffle")
    self._shuffles = defaultdict(set)
    self._archived_by_stimulus = defaultdict(set)
    self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
    """Register a serialized ``ShuffleWorkerPlugin`` under the name ``"shuffle"``."""
    worker_plugin = ShuffleWorkerPlugin()
    register = self.scheduler.register_worker_plugin
    serialized = dumps(worker_plugin)
    await register(None, serialized, name="shuffle", idempotent=False)
def shuffle_ids(self) -> set[ShuffleId]:
    """Return a new set holding the ids of all active shuffles."""
    return {*self.active_shuffles}
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
    self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
    """Restrict task ``key`` to ``worker`` if ``run_id`` matches the active run.

    Returns ``{"status": "OK"}`` on success. A mismatch between ``run_id``
    and the active run's id is reported as an error message instead of
    being raised.
    """
    try:
        shuffle = self.active_shuffles[id]
        if shuffle.run_id > run_id:
            raise P2PConsistencyError(
                f"Request stale, expected {run_id=} for {shuffle}"
            )
        if shuffle.run_id < run_id:
            raise P2PConsistencyError(
                f"Request invalid, expected {run_id=} for {shuffle}"
            )
        self._set_restriction(self.scheduler.tasks[key], worker)
        return {"status": "OK"}
    except P2PConsistencyError as e:
        return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
    """Merge heartbeat payloads reported by worker ``ws`` into ``heartbeats``.

    Parameters
    ----------
    ws
        The reporting worker; only its ``address`` is used.
    data
        Mapping of shuffle id to a dict payload. Entries whose id is not in
        ``self.shuffle_ids()`` are ignored.
    """
    if not data:
        return
    # ``shuffle_ids()`` builds a new set on every call, so compute it once
    # here instead of once per reported shuffle.
    active = self.shuffle_ids()
    for shuffle_id, d in data.items():
        if shuffle_id in active:
            self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
    """Return the run spec of the active shuffle ``id``, wrapped in ``ToPickle``.

    A missing shuffle (``KeyError`` from ``_get``) or any other
    ``P2PConsistencyError`` is returned as an error message.
    """
    try:
        try:
            run_spec = self._get(id, worker)
        except KeyError as e:
            raise P2PConsistencyError(
                f"No active shuffle with {id=!r} found"
            ) from e
        return {"status": "OK", "run_spec": ToPickle(run_spec)}
    except P2PConsistencyError as e:
        return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
    """Return the ``spec`` stored on the barrier task of ``shuffle_id``.

    The barrier task is looked up in ``self.scheduler.tasks`` by its barrier
    key; its ``run_spec`` is asserted to be a ``P2PBarrierTask``.
    """
    key = barrier_key(shuffle_id)
    barrier_task_spec = self.scheduler.tasks[key].run_spec
    assert isinstance(barrier_task_spec, P2PBarrierTask)
    return barrier_task_spec.spec
def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
    """Create a new run for ``shuffle_id`` and register it as the active one.

    The new state is stored in ``active_shuffles`` and ``_shuffles``, and
    ``worker`` is added to its participating workers. Returns the run spec
    of the new state.
    """
    # FIXME: The current implementation relies on the barrier task to be
    # known by its name. If the name has been mangled, we cannot guarantee
    # that the shuffle works as intended and should fail instead.
    self._raise_if_barrier_unknown(shuffle_id)
    self._raise_if_task_not_processing(key)
    spec = self._retrieve_spec(shuffle_id)
    worker_for = self._calculate_worker_for(spec)
    self._ensure_output_tasks_are_non_rootish(spec)
    # The span id is taken from the task group of the task named by ``key``.
    state = spec.create_new_run(
        worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
    )
    self.active_shuffles[shuffle_id] = state
    self._shuffles[shuffle_id].add(state)
    state.participating_workers.add(worker)
    logger.warning(
        "Shuffle %s initialized by task %r executed on worker %s",
        shuffle_id,
        key,
        worker,
    )
    return state.run_spec
def get_or_create(
self,
shuffle_id: ShuffleId,
key: Key,
worker: str,
) -> RunSpecMessage | ErrorMessage:
try:
> run_spec = self._get(shuffle_id, worker)
distributed\shuffle\_scheduler_plugin.py:229:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
    """Return a new set holding the ids of all active shuffles."""
    return {*self.active_shuffles}
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
    self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
    """Restrict task ``key`` to ``worker`` if ``run_id`` matches the active run.

    Returns ``{"status": "OK"}`` on success. A mismatch between ``run_id``
    and the active run's id is reported as an error message instead of
    being raised.
    """
    try:
        shuffle = self.active_shuffles[id]
        if shuffle.run_id > run_id:
            raise P2PConsistencyError(
                f"Request stale, expected {run_id=} for {shuffle}"
            )
        if shuffle.run_id < run_id:
            raise P2PConsistencyError(
                f"Request invalid, expected {run_id=} for {shuffle}"
            )
        self._set_restriction(self.scheduler.tasks[key], worker)
        return {"status": "OK"}
    except P2PConsistencyError as e:
        return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
    """Merge heartbeat payloads reported by worker ``ws`` into ``heartbeats``.

    Parameters
    ----------
    ws
        The reporting worker; only its ``address`` is used.
    data
        Mapping of shuffle id to a dict payload. Entries whose id is not in
        ``self.shuffle_ids()`` are ignored.
    """
    if not data:
        return
    # ``shuffle_ids()`` builds a new set on every call, so compute it once
    # here instead of once per reported shuffle.
    active = self.shuffle_ids()
    for shuffle_id, d in data.items():
        if shuffle_id in active:
            self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
    """Return the run spec of the active shuffle ``id``, wrapped in ``ToPickle``.

    A missing shuffle (``KeyError`` from ``_get``) or any other
    ``P2PConsistencyError`` is returned as an error message.
    """
    try:
        try:
            run_spec = self._get(id, worker)
        except KeyError as e:
            raise P2PConsistencyError(
                f"No active shuffle with {id=!r} found"
            ) from e
        return {"status": "OK", "run_spec": ToPickle(run_spec)}
    except P2PConsistencyError as e:
        return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
> state = self.active_shuffles[id]
E KeyError: 'b5a56a723b010b3ef762d5c14fe583f9'
distributed\shuffle\_scheduler_plugin.py:190: KeyError
During handling of the above exception, another exception occurred:
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
    """Report to the scheduler whether all ``run_ids`` match this run.

    Raises via ``raise_if_closed`` if the run is closed, then calls the
    scheduler's ``shuffle_barrier`` with the consistency flag and returns
    this run's ``run_id``.
    """
    self.raise_if_closed()
    # Consistent only if no reported run id differs from ours.
    consistent = not any(run_id != self.run_id for run_id in run_ids)
    # TODO: Consider broadcast pinging once when the shuffle starts to warm
    # up the comm pool on scheduler side
    await self.scheduler.shuffle_barrier(
        id=self.id,
        run_id=self.run_id,
        consistent=consistent,
    )
    return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
    self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
    """Send ``shards`` to ``address``, retrying according to the RETRY_* settings."""
    # Don't send buffers individually over the tcp comms when the mean shard
    # size is below the threshold. Instead, merge everything into an opaque
    # bytes blob, send it all at once, and unpickle it on the other side.
    # Performance tests informing the size threshold:
    # https://github.com/dask/distributed/pull/8318
    use_blob = _mean_shard_size(shards) < 65536
    shards_or_bytes: list | bytes = pickle.dumps(shards) if use_blob else shards

    def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
        return self._send(address, shards_or_bytes)

    return await retry(
        _send,
        count=self.RETRY_COUNT,
        delay_min=self.RETRY_DELAY_MIN,
        delay_max=self.RETRY_DELAY_MAX,
    )
async def offload(
    self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
    """Run ``func(*args, **kwargs)`` on ``self.executor``, metered as "offload"."""
    self.raise_if_closed()
    with context_meter.meter("offload"):
        outcome = await run_in_executor_with_context(
            self.executor, func, *args, **kwargs
        )
    return outcome
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
    self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
    """Pass incoming shards to ``_receive`` and report the outcome.

    Returns ``{"status": "OK"}`` on success. A ``P2PConsistencyError`` is
    converted into an error message instead of being raised.
    """
    try:
        if isinstance(data, bytes):
            # Unpack opaque blob. See send()
            # NOTE(review): this unpickles whatever bytes arrive here;
            # presumably only peers of the same cluster send them - confirm.
            data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
        await self._receive(data)
    except P2PConsistencyError as inconsistency:
        return error_message(inconsistency)
    return {"status": "OK"}
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
    """Make sure output partition ``i`` is computed on its assigned worker.

    Nothing happens when the assigned worker is this one. Otherwise the
    scheduler is asked to restrict ``key`` to the assigned worker and
    ``Reschedule`` is raised; an error reply raises ``RuntimeError``.
    """
    target = self._get_assigned_worker(i)
    if target == self.local_address:
        return

    result = await self.scheduler.shuffle_restrict_task(
        id=self.id, run_id=self.run_id, key=key, worker=target
    )
    if result["status"] == "error":
        raise RuntimeError(result["message"])
    assert result["status"] == "OK"
    raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
    """Return the address of the worker assigned to output partition ``i``."""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
    """Take in shards that belong to output partitions of this shuffle run."""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and …n _refresh
result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed\shuffle\_worker_plugin.py:190: in _fetch
response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed\core.py:1259: in send_recv_from_rpc
return await send_recv(comm=comm, op=key, **kwargs)
distributed\core.py:1043: in send_recv
raise exc.with_traceback(tb)
distributed\core.py:832: in _handle_comm
result = handler(**msg)
distributed\shuffle\_scheduler_plugin.py:234: in get_or_create
run_spec = self._create(shuffle_id, key, worker)
distributed\shuffle\_scheduler_plugin.py:205: in _create
spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
    """Register the shuffle handlers and this plugin on ``scheduler``."""
    self.scheduler = scheduler
    shuffle_handlers = {
        "shuffle_barrier": self.barrier,
        "shuffle_get": self.get,
        "shuffle_get_or_create": self.get_or_create,
        "shuffle_restrict_task": self.restrict_task,
    }
    self.scheduler.handlers.update(shuffle_handlers)
    self.heartbeats = defaultdict(lambda: defaultdict(dict))
    self.active_shuffles = {}
    # The plugin is registered under the name "shuffle" before the remaining
    # bookkeeping attributes are created, as in the original order.
    self.scheduler.add_plugin(self, name="shuffle")
    self._shuffles = defaultdict(set)
    self._archived_by_stimulus = defaultdict(set)
    self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
    """Register a pickled ``ShuffleWorkerPlugin`` under the name "shuffle"."""
    pickled_plugin = dumps(ShuffleWorkerPlugin())
    await self.scheduler.register_worker_plugin(
        None, pickled_plugin, name="shuffle", idempotent=False
    )
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
# Handler registered as "shuffle_barrier" in __init__.
# Checks that the request refers to the active run, restarts the shuffle when
# the caller reports inconsistent run ids, and otherwise broadcasts
# "shuffle_inputs_done" to the participating workers, retrying the workers
# whose reply was an OSError.
# Raises ValueError on a run id mismatch, and RuntimeError or
# P2PIllegalStateError when the broadcast cannot complete.
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
# A barrier for a different run of this shuffle is rejected outright.
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
# NOTE(review): the signature says `-> None` but this returns the result of
# `_restart_shuffle` - presumably that helper returns None; confirm.
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
# Counts retry rounds in which the number of pending workers did not change.
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
# A reply of None needs no retry, an OSError reply is retried, and any other
# reply aborts the barrier.
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
# NOTE(review): this recomputes `workers`; since the loop above raises for
# every non-OSError reply it appears to produce the same list - confirm
# before removing.
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
# Short pause before the next broadcast attempt.
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
# Give up once three rounds have been counted as making no progress.
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
    self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
    """Restrict task ``key`` to ``worker`` for the active run of shuffle ``id``.

    Returns ``{"status": "OK"}`` on success. A request whose ``run_id``
    differs from the active run yields an error message instead.
    """
    try:
        shuffle = self.active_shuffles[id]
        active_run_id = shuffle.run_id
        if active_run_id > run_id:
            raise P2PConsistencyError(
                f"Request stale, expected {run_id=} for {shuffle}"
            )
        if active_run_id < run_id:
            raise P2PConsistencyError(
                f"Request invalid, expected {run_id=} for {shuffle}"
            )
        self._set_restriction(self.scheduler.tasks[key], worker)
    except P2PConsistencyError as inconsistency:
        return error_message(inconsistency)
    return {"status": "OK"}
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
    """Return the run spec of the active shuffle ``id`` wrapped in ``ToPickle``.

    A missing shuffle (``KeyError``) is turned into a ``P2PConsistencyError``,
    and any ``P2PConsistencyError`` is returned as an error message.
    """
    try:
        try:
            run_spec = self._get(id, worker)
        except KeyError as missing:
            raise P2PConsistencyError(
                f"No active shuffle with {id=!r} found"
            ) from missing
        return {"status": "OK", "run_spec": ToPickle(run_spec)}
    except P2PConsistencyError as inconsistency:
        return error_message(inconsistency)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
> assert isinstance(barrier_task_spec, P2PBarrierTask)
E AssertionError
distributed\shuffle\_scheduler_plugin.py:196: AssertionError
The above exception was the direct cause of the following exception:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:64697', workers: 0, cores: 0, tasks: 0>
x = dask.array<ones_like, shape=(10, 10), dtype=float64, chunksize=(10, 10), chunktype=numpy.ndarray>
chunks = (None, 5)
ws = (<Worker 'tcp://127.0.0.1:64698', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:64701', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
dd = <module 'dask.dataframe' from 'C:\\Users\\runneradmin\\miniconda3\\envs\\dask-distributed\\Lib\\site-packages\\dask\\dataframe\\__init__.py'>
@pytest.mark.parametrize(
"x, chunks",
[
(da.ones(shape=(50, 10), chunks=(25, 10)), (None, 5)),
(da.ones(shape=(50, 10), chunks=(25, 10)), {1: 5}),
(da.ones(shape=(50, 10), chunks=(25, 10)), (None, (5, 5))),
(da.ones(shape=(100, 10), chunks=(5, 10)), (None, 5)),
(da.ones(shape=(100, 10), chunks=(5, 10)), {1: 5}),
(da.ones(shape=(100, 10), chunks=(5, 10)), (None, (5, 5))),
(da.ones(shape=(10, 10), chunks=(10, 10)), (None, 5)),
(da.ones(shape=(10, 10), chunks=(10, 10)), {1: 5}),
(da.ones(shape=(10, 10), chunks=(10, 10)), (None, (5, 5))),
(da.ones(shape=(10, 10), chunks=(10, 2)), (None, 5)),
(da.ones(shape=(10, 10), chunks=(10, 2)), {1: 5}),
(da.ones(shape=(10, 10), chunks=(10, 2)), (None, (5, 5))),
],
)
@gen_cluster(client=True)
async def test_rechunk_with_fully_unknown_dimension(c, s, *ws, x, chunks):
"""
See Also
--------
dask.array.tests.test_rechunk.test_rechunk_with_fully_unknown_dimension
"""
pytest.importorskip("pandas")
dd = pytest.importorskip("dask.dataframe")
y = dd.from_array(x).values
result = y.rechunk(chunks, method="p2p")
expected = x.rechunk(chunks, method="p2p")
assert_chunks_match(result.chunks, expected.chunks)
> assert_eq(await c.compute(result), await c.compute(expected))
distributed\shuffle\tests\test_rechunk.py:757:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed\client.py:410: in _result
raise exc.with_traceback(tb)
distributed\shuffle\_rechunk.py:169: in rechunk_transfer
with handle_transfer_errors(id):
C:\Users\runneradmin\miniconda3\envs\dask-distributed\Lib\contextlib.py:158: in __exit__
self.gen.throw(typ, value, traceback)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
    """An ``OKMessage`` that additionally carries a shuffle run spec."""

    # Either the spec itself or the spec wrapped in ToPickle.
    run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
    """State and helpers for a single run of a P2P shuffle at ``local_address``.

    The run owns a comm buffer for outgoing shards and a disk (or in-memory)
    buffer for received shards, and talks to the scheduler through
    ``scheduler``. Subclasses provide the abstract sharding, receiving and
    output hooks.
    """

    id: ShuffleId
    run_id: int
    span_id: str | None
    local_address: str
    executor: ThreadPoolExecutor
    rpc: Callable[[str], PooledRPCCall]
    digest_metric: Callable[[Hashable, float], None]
    scheduler: PooledRPCCall
    closed: bool
    _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
    _comm_buffer: CommShardsBuffer
    received: set[_T_partition_id]
    total_recvd: int
    start_time: float
    _exception: Exception | None
    _closed_event: asyncio.Event
    _loop: IOLoop

    RETRY_COUNT: int
    RETRY_DELAY_MIN: float
    RETRY_DELAY_MAX: float

    def __init__(
        self,
        id: ShuffleId,
        run_id: int,
        span_id: str | None,
        local_address: str,
        directory: str,
        executor: ThreadPoolExecutor,
        rpc: Callable[[str], PooledRPCCall],
        digest_metric: Callable[[Hashable, float], None],
        scheduler: PooledRPCCall,
        memory_limiter_disk: ResourceLimiter,
        memory_limiter_comms: ResourceLimiter,
        disk: bool,
        loop: IOLoop,
    ):
        """Store the collaborators, create both buffers and read the retry config.

        ``disk`` selects a ``DiskShardsBuffer`` rooted at ``directory``;
        otherwise a ``MemoryShardsBuffer`` is used.
        """
        self.id = id
        self.run_id = run_id
        self.span_id = span_id
        self.local_address = local_address
        self.executor = executor
        self.rpc = rpc
        self.digest_metric = digest_metric
        self.scheduler = scheduler
        self.closed = False

        # Initialize buffers and start background tasks
        # Don't log metrics issued by the background tasks onto the dask task that
        # spawned this object
        with context_meter.clear_callbacks():
            with self._capture_metrics("background-disk"):
                if disk:
                    self._disk_buffer = DiskShardsBuffer(
                        directory=directory,
                        read=self.read,
                        memory_limiter=memory_limiter_disk,
                    )
                else:
                    self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)

            with self._capture_metrics("background-comms"):
                max_message_size = parse_bytes(
                    dask.config.get("distributed.p2p.comm.message-bytes-limit")
                )
                concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                self._comm_buffer = CommShardsBuffer(
                    send=self.send,
                    max_message_size=max_message_size,
                    memory_limiter=memory_limiter_comms,
                    concurrency_limit=concurrency_limit,
                )

        # TODO: reduce number of connections to number of workers
        # MultiComm.max_connections = min(10, n_workers)
        self.transferred = False
        self.received = set()
        self.total_recvd = 0
        self.start_time = time.time()
        self._exception = None
        self._closed_event = asyncio.Event()
        self._loop = loop

        self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
        self.RETRY_DELAY_MIN = parse_timedelta(
            dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
        )
        self.RETRY_DELAY_MAX = parse_timedelta(
            dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
        )

    def __repr__(self) -> str:
        """Return a debug representation including the closed/transferred flags."""
        return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"

    def __str__(self) -> str:
        """Return a compact label built from the class name, id, run id and address."""
        return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"

    def __hash__(self) -> int:
        """Hash a run by its ``run_id``."""
        return self.run_id

    @contextlib.contextmanager
    def _capture_metrics(self, where: str) -> Iterator[None]:
        """Capture context_meter metrics as

            {('p2p', <span id>, 'foreground|background...', label, unit): value}

        **Note 1:** When the metric is not logged by a background task
        (where='foreground'), this produces a duplicated metric under

            {('execute', <span id>, <task prefix>, label, unit): value}

        This is by design so that one can have a holistic view of the whole shuffle
        process.

        **Note 2:** We're immediately writing to Worker.digests.
        We don't temporarily store metrics under ShuffleRun as we would lose those
        recorded between the heartbeat and when the ShuffleRun object is deleted at the
        end of a run.
        """

        def callback(label: Hashable, value: float, unit: str) -> None:
            if not isinstance(label, tuple):
                label = (label,)
            if isinstance(label[0], str) and label[0].startswith("p2p-"):
                # The metric name already starts with "p2p"; drop the prefix.
                label = (label[0][len("p2p-") :], *label[1:])
            name = ("p2p", self.span_id, where, *label, unit)

            self.digest_metric(name, value)

        with context_meter.add_callback(callback, allow_offload="background" in where):
            yield

    async def barrier(self, run_ids: Sequence[int]) -> int:
        """Report to the scheduler whether all ``run_ids`` match this run.

        Returns this run's ``run_id`` once the scheduler call has completed.
        """
        self.raise_if_closed()
        consistent = all(run_id == self.run_id for run_id in run_ids)
        # TODO: Consider broadcast pinging once when the shuffle starts to warm
        # up the comm pool on scheduler side
        await self.scheduler.shuffle_barrier(
            id=self.id, run_id=self.run_id, consistent=consistent
        )
        return self.run_id

    async def _send(
        self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
    ) -> OKMessage | ErrorMessage:
        """Deliver ``shards`` to ``address`` with a single ``shuffle_receive`` call."""
        self.raise_if_closed()
        return await self.rpc(address).shuffle_receive(
            data=to_serialize(shards),
            shuffle_id=self.id,
            run_id=self.run_id,
        )

    async def send(
        self, address: str, shards: list[tuple[_T_partition_id, Any]]
    ) -> OKMessage | ErrorMessage:
        """Send ``shards`` to ``address``, retrying according to the RETRY_* settings."""
        if _mean_shard_size(shards) < 65536:
            # Don't send buffers individually over the tcp comms.
            # Instead, merge everything into an opaque bytes blob, send it all at once,
            # and unpickle it on the other side.
            # Performance tests informing the size threshold:
            # https://github.com/dask/distributed/pull/8318
            shards_or_bytes: list | bytes = pickle.dumps(shards)
        else:
            shards_or_bytes = shards

        def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
            return self._send(address, shards_or_bytes)

        return await retry(
            _send,
            count=self.RETRY_COUNT,
            delay_min=self.RETRY_DELAY_MIN,
            delay_max=self.RETRY_DELAY_MAX,
        )

    async def offload(
        self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
    ) -> _T:
        """Run ``func(*args, **kwargs)`` on ``self.executor``, metered as "offload"."""
        self.raise_if_closed()
        with context_meter.meter("offload"):
            return await run_in_executor_with_context(
                self.executor, func, *args, **kwargs
            )

    def heartbeat(self) -> dict[str, Any]:
        """Collect the buffers' heartbeat data together with the run's start time."""
        comm_heartbeat = self._comm_buffer.heartbeat()
        comm_heartbeat["read"] = self.total_recvd
        return {
            "disk": self._disk_buffer.heartbeat(),
            "comm": comm_heartbeat,
            "start": self.start_time,
        }

    async def _write_to_comm(
        self, data: dict[str, tuple[_T_partition_id, Any]]
    ) -> None:
        """Hand ``data`` to the comm buffer; raises if the run is closed."""
        self.raise_if_closed()
        await self._comm_buffer.write(data)

    async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
        """Write ``data`` to the disk buffer, keyed by the "_"-joined index."""
        self.raise_if_closed()
        await self._disk_buffer.write(
            {"_".join(str(i) for i in k): v for k, v in data.items()}
        )

    def raise_if_closed(self) -> None:
        """Raise the stored exception, or ``ShuffleClosedError``, if the run is closed."""
        if self.closed:
            if self._exception:
                raise self._exception
            raise ShuffleClosedError(f"{self} has already been closed")

    async def inputs_done(self) -> None:
        """Mark the run as transferred and flush the comm buffer.

        Any exception surfaced by the comm buffer afterwards is remembered in
        ``_exception`` and re-raised.
        """
        self.raise_if_closed()
        self.transferred = True
        await self._flush_comm()
        try:
            self._comm_buffer.raise_on_exception()
        except Exception as e:
            self._exception = e
            raise

    async def _flush_comm(self) -> None:
        """Flush the comm buffer; raises if the run is closed."""
        self.raise_if_closed()
        await self._comm_buffer.flush()

    async def flush_receive(self) -> None:
        """Flush the disk buffer; raises if the run is closed."""
        self.raise_if_closed()
        await self._disk_buffer.flush()

    async def close(self) -> None:
        """Close both buffers and signal that the shutdown has finished.

        When the run is already closed, this waits for the first ``close()``
        call to finish instead of closing again.

        The disk buffer is closed and ``_closed_event`` is set even if closing
        a buffer raises; otherwise the disk buffer would stay open and every
        later ``close()`` call would wait on the event forever. The exception
        still propagates to the caller.
        """
        if self.closed:  # pragma: no cover
            await self._closed_event.wait()
            return

        self.closed = True
        try:
            try:
                await self._comm_buffer.close()
            finally:
                await self._disk_buffer.close()
        finally:
            self._closed_event.set()

    def fail(self, exception: Exception) -> None:
        """Remember ``exception`` for this run unless the run is already closed."""
        if not self.closed:
            self._exception = exception

    def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
        """Read what the disk buffer holds under the "_"-joined index ``id``."""
        self.raise_if_closed()
        return self._disk_buffer.read("_".join(str(i) for i in id))

    async def receive(
        self, data: list[tuple[_T_partition_id, Any]] | bytes
    ) -> OKMessage | ErrorMessage:
        """Pass incoming shards to ``_receive`` and report the outcome.

        Returns ``{"status": "OK"}`` on success. A ``P2PConsistencyError`` is
        converted into an error message instead of being raised.
        """
        try:
            if isinstance(data, bytes):
                # Unpack opaque blob. See send()
                data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
            await self._receive(data)
            return {"status": "OK"}
        except P2PConsistencyError as e:
            return error_message(e)

    async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
        """Make sure output partition ``i`` is computed on its assigned worker.

        Nothing happens when the assigned worker is this one. Otherwise the
        scheduler is asked to restrict ``key`` to the assigned worker and
        ``Reschedule`` is raised; an error reply raises ``RuntimeError``.
        """
        assigned_worker = self._get_assigned_worker(i)

        if assigned_worker != self.local_address:
            result = await self.scheduler.shuffle_restrict_task(
                id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
            )
            if result["status"] == "error":
                raise RuntimeError(result["message"])
            assert result["status"] == "OK"
            raise Reschedule()

    @abc.abstractmethod
    def _get_assigned_worker(self, i: _T_partition_id) -> str:
        """Get the address of the worker assigned to the output partition"""

    @abc.abstractmethod
    async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
        """Receive shards belonging to output partitions of this shuffle run"""

    def add_partition(
        self, data: _T_partition_type, partition_id: _T_partition_id
    ) -> int:
        """Shard one input partition and queue the shards on the comm buffer.

        Raises ``RuntimeError`` once the run has been marked as transferred.
        Returns this run's ``run_id``.
        """
        self.raise_if_closed()
        if self.transferred:
            raise RuntimeError(f"Cannot add more partitions to {self}")
        # Log metrics both in the "execute" and in the "p2p" contexts
        self.validate_data(data)
        with self._capture_metrics("foreground"):
            with (
                context_meter.meter("p2p-shard-partition-noncpu"),
                context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
            ):
                shards = self._shard_partition(data, partition_id)
            sync(self._loop, self._write_to_comm, shards)
        return self.run_id

    @abc.abstractmethod
    def _shard_partition(
        self, data: _T_partition_type, partition_id: _T_partition_id
    ) -> dict[str, tuple[_T_partition_id, Any]]:
        """Shard an input partition by the assigned output workers"""

    def get_output_partition(
        self, partition_id: _T_partition_id, key: Key, **kwargs: Any
    ) -> _T_partition_type:
        """Return output partition ``partition_id`` for task ``key``.

        Raises ``RuntimeError`` when called before the run has been marked as
        transferred.
        """
        self.raise_if_closed()
        sync(self._loop, self._ensure_output_worker, partition_id, key)
        if not self.transferred:
            raise RuntimeError("`get_output_partition` called before barrier task")
        sync(self._loop, self.flush_receive)
        with (
            # Log metrics both in the "execute" and in the "p2p" contexts
            self._capture_metrics("foreground"),
            context_meter.meter("p2p-get-output-noncpu"),
            context_meter.meter("p2p-get-output-cpu", func=thread_time),
        ):
            return self._get_output_partition(partition_id, key, **kwargs)

    @abc.abstractmethod
    def _get_output_partition(
        self, partition_id: _T_partition_id, key: Key, **kwargs: Any
    ) -> _T_partition_type:
        """Get an output partition to the shuffle run"""

    @abc.abstractmethod
    def read(self, path: Path) -> tuple[Any, int]:
        """Read shards from disk"""

    @abc.abstractmethod
    def deserialize(self, buffer: Any) -> Any:
        """Deserialize shards"""

    def validate_data(self, data: Any) -> None:
        """Validate payload data before shuffling"""
def get_worker_plugin() -> ShuffleWorkerPlugin:
    """Return the "shuffle" plugin of the worker this code is running on.

    Raises ``RuntimeError`` when not running on a worker or when the worker
    has no plugin registered under "shuffle".
    """
    from distributed import get_worker

    try:
        worker = get_worker()
    except ValueError as not_on_worker:
        message = (
            "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
            "please confirm that you've created a distributed Client and are submitting this computation through it."
        )
        raise RuntimeError(message) from not_on_worker

    try:
        plugin = worker.plugins["shuffle"]
    except KeyError as missing_plugin:
        raise RuntimeError(
            f"The worker {worker.address} does not have a P2P shuffle plugin."
        ) from missing_plugin
    return plugin  # type: ignore
_BARRIER_PREFIX = "shuffle-barrier-"
def barrier_key(shuffle_id: ShuffleId) -> str:
return _BARRIER_PREFIX + shuffle_id
def id_from_key(key: Key) -> ShuffleId | None:
    """Extract the shuffle id from a barrier key.

    Returns ``None`` when ``key`` is not a string starting with the barrier
    prefix.
    """
    if isinstance(key, str) and key.startswith(_BARRIER_PREFIX):
        return ShuffleId(key[len(_BARRIER_PREFIX) :])
    return None
class ShuffleType(Enum):
    """The kinds of P2P shuffle named in this module."""

    DATAFRAME = "DataFrameShuffle"
    ARRAY_RECHUNK = "ArrayRechunk"
@dataclass(frozen=True)
class ShuffleRunSpec(Generic[_T_partition_id]):
    """Immutable description of one run of a shuffle."""

    # Not an __init__ argument: every new instance takes the next value of a
    # counter that starts at 1 and is created once with the class.
    run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
    spec: ShuffleSpec
    # Worker address for each output partition id.
    worker_for: dict[_T_partition_id, str]
    span_id: str | None

    @property
    def id(self) -> ShuffleId:
        """Id of the shuffle, taken from ``spec``."""
        shuffle_id = self.spec.id
        return shuffle_id
@dataclass(frozen=True)
class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
    """Abstract, immutable specification of a shuffle."""

    id: ShuffleId
    # Passed on to the run as its ``disk`` setting.
    disk: bool

    @property
    @abc.abstractmethod
    def output_partitions(self) -> Generator[_T_partition_id]:
        """Output partitions"""

    @abc.abstractmethod
    def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
        """Pick a worker for a partition"""

    def create_new_run(
        self,
        worker_for: dict[_T_partition_id, str],
        span_id: str | None,
    ) -> SchedulerShuffleState:
        """Create the scheduler-side state for a new run of this spec.

        The participating workers start out as the set of all addresses in
        ``worker_for``.
        """
        run_spec = ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id)
        participants = set(worker_for.values())
        return SchedulerShuffleState(
            run_spec=run_spec,
            participating_workers=participants,
        )

    @abc.abstractmethod
    def create_run_on_worker(
        self,
        run_id: int,
        span_id: str | None,
        worker_for: dict[_T_partition_id, str],
        plugin: ShuffleWorkerPlugin,
    ) -> ShuffleRun:
        """Create the new shuffle run on the worker."""
@dataclass(eq=False)
class SchedulerShuffleState(Generic[_T_partition_id]):
    """Mutable scheduler-side state of one shuffle run."""

    run_spec: ShuffleRunSpec
    participating_workers: set[str]
    # None until set; ``archived`` reports whether it has been set.
    _archived_by: str | None = field(default=None, init=False)
    _failed: bool = False

    @property
    def id(self) -> ShuffleId:
        """Id of the shuffle, taken from ``run_spec``."""
        return self.run_spec.id

    @property
    def run_id(self) -> int:
        """Run id, taken from ``run_spec``."""
        return self.run_spec.run_id

    @property
    def archived(self) -> bool:
        """Whether ``_archived_by`` has been set."""
        archived_by = self._archived_by
        return archived_by is not None

    def __str__(self) -> str:
        """Return a compact label built from the class name, id and run id."""
        cls_name = type(self).__name__
        return f"{cls_name}<{self.id}[{self.run_id}]>"

    def __hash__(self) -> int:
        """Hash the state by its run id."""
        return hash(self.run_id)
@contextlib.contextmanager
def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
try:
yield
except ShuffleClosedError:
raise Reschedule()
except P2PConsistencyError:
raise
except P2POutOfDiskError:
raise
except Exception as e:
> raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E RuntimeError: P2P b5a56a723b010b3ef762d5c14fe583f9 failed during transfer phase
distributed\shuffle\_core.py:531: RuntimeError