Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(fix): optimize subsetting dask array #1432

Merged
merged 5 commits into from
Mar 22, 2024
Merged

Conversation

ilan-gold
Copy link
Contributor

@ilan-gold ilan-gold commented Mar 21, 2024

@ilan-gold
Copy link
Contributor Author

ilan-gold commented Mar 21, 2024

MVCE as motivation for doing this

import dask.array as da
import numpy as np
import scipy as sp

# DENSE
arr = np.random.randn(chunksize, chunksize)
X = da.map_blocks(lambda block_id: arr, dtype=arr.dtype, meta=arr, chunks=((chunksize, ) * (size // chunksize),) * 2)

%timeit X.vindex[np.ix_(index_0, index_1)] # slow
%timeit X[index_0, :][:, index_1] # fast

np.array_equal(X[index_0, :][:, index_1].compute(), X.vindex[np.ix_(index_0, index_1)].compute())

# SPARSE
arr = sp.sparse.random(chunksize, chunksize, format="csr", density=.1)
X = da.map_blocks(lambda block_id: arr, dtype=arr.dtype, meta=arr, chunks=((chunksize, ) * (size // chunksize),) * 2)

%timeit X.vindex[np.ix_(index_0, index_1)]
%timeit X[index_0, :][:, index_1]

np.array_equal(X[index_0, :][:, index_1].compute().toarray(), X.vindex[np.ix_(index_0, index_1)].compute().toarray())

I don't understand why the sparse one doesn't compute in the last step but in any case, the non vindex op is much faster and also works so I'm not sure it really matters why vindex doesn't compute since where are removing this anyway....

@ilan-gold ilan-gold added this to the 0.10.7 milestone Mar 21, 2024
Copy link

codecov bot commented Mar 21, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 84.08%. Comparing base (98d33da) to head (34e3476).

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1432      +/-   ##
==========================================
- Coverage   86.26%   84.08%   -2.18%     
==========================================
  Files          36       36              
  Lines        5612     5599      -13     
==========================================
- Hits         4841     4708     -133     
- Misses        771      891     +120     
Flag Coverage Δ
gpu-tests ?

Flags with carried forward coverage won't be shown. Click here to find out more.

Files Coverage Δ
anndata/_core/index.py 93.19% <100.00%> (-0.14%) ⬇️

... and 10 files with indirect coverage changes

# TODO: this may have been working for some cases?
subset_idx = np.ix_(*subset_idx)
return a.vindex[subset_idx]
return a[subset_idx[0], :][:, subset_idx[1]]
Copy link
Contributor Author

@ilan-gold ilan-gold Mar 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dask does not support a[subset_idx] when subset_idx has more than one entry

anndata/_core/anndata.py:1506: in copy
    X=_subset(self._adata_ref.X, (self._oidx, self._vidx)).copy()
/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/functools.py:909: in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
anndata/_core/index.py:155: in _subset_dask
    return a[subset_idx]
venv/lib/python3.11/site-packages/dask/array/core.py:1994: in __getitem__
    dsk, chunks = slice_array(out, self.name, self.chunks, index2, self.itemsize)
venv/lib/python3.11/site-packages/dask/array/slicing.py:176: in slice_array
    dsk_out, bd_out = slice_with_newaxes(out_name, in_name, blockdims, index, itemsize)
venv/lib/python3.11/site-packages/dask/array/slicing.py:198: in slice_with_newaxes
    dsk, blockdims2 = slice_wrap_lists(out_name, in_name, blockdims, index2, itemsize)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

out_name = 'getitem-32365ec69f5d5f165e6565bb934d931b', in_name = 'array-780508e68d811416a0a1a22cb32db79f', blockdims = ((30,), (15,))
index = (array([ 0,  2,  4,  9, 11, 12, 13, 14, 16, 17, 20, 21, 22, 25, 27, 28, 29]), array([ 3,  6, 10])), itemsize = 4

    def slice_wrap_lists(out_name, in_name, blockdims, index, itemsize):
        """
        Fancy indexing along blocked array dasks

        Handles index of type list.  Calls slice_slices_and_integers for the rest

        See Also
        --------

        take : handle slicing with lists ("fancy" indexing)
        slice_slices_and_integers : handle slicing with slices and integers
        """
        assert all(isinstance(i, (slice, list, Integral)) or is_arraylike(i) for i in index)
        if not len(blockdims) == len(index):
            raise IndexError("Too many indices for array")

        # Do we have more than one list in the index?
        where_list = [
            i for i, ind in enumerate(index) if is_arraylike(ind) and ind.ndim > 0
        ]
        if len(where_list) > 1:
>           raise NotImplementedError("Don't yet support nd fancy indexing")
E           NotImplementedError: Don't yet support nd fancy indexing

venv/lib/python3.11/site-packages/dask/array/slicing.py:244: NotImplementedError

@ilan-gold
Copy link
Contributor Author

ilan-gold commented Mar 21, 2024

/Users/ilangold/Projects/Theis/anndata/anndata/_core/index.py:153: PerformanceWarning: Slicing is producing a large chunk. To accept the large
chunk and silence this warning, set the option
    >>> with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    ...     array[indexer]

To avoid creating the large chunks, set the option
    >>> with dask.config.set(**{'array.slicing.split_large_chunks': True}):
    ...     array[indexer]
  return a[subset_idx[0], :][:, subset_idx[1]]

When using this with large datasets via filter_cells for example....

UPDATE:

Ok the warning arises from

Dask warns when indexing like this produces a chunk that’s 5x larger than the array.chunk-size config option. You have two options to deal with that warning:

    Set dask.config.set({"array.slicing.split_large_chunks": False}) to allow the large chunk and silence the warning.

    Set dask.config.set({"array.slicing.split_large_chunks": True}) to avoid creating the large chunk in the first place.

The right choice will depend on your downstream operations. See [Chunks](https://docs.dask.org/en/latest/array-chunks.html#array-chunks) for more on choosing chunk sizes.

In other words, this is a global warning and the default for the array.chunk-size is 128 MB. So this is not a bug

@ilan-gold ilan-gold marked this pull request as ready for review March 22, 2024 11:55
@ilan-gold ilan-gold enabled auto-merge (squash) March 22, 2024 11:57
@flying-sheep
Copy link
Member

With my change, the dependencies changed like this:

--- 2024-03-21.txt      2024-03-22 15:46:21.047726071 +0100
+++ 2024-03-22.txt      2024-03-22 15:46:49.230697833 +0100
@@ -1,12 +1,12 @@
-anndata           0.11.0.dev90+g8f4c755
+anndata           0.11.0.dev95+g9db28d5
 anyio             4.3.0
-array-api-compat  1.4.1
+array-api-compat  1.5.1
 asciitree         0.3.3
 memray            1.11.0
 msgpack           1.0.8
 natsort           8.4.0
 networkx          3.2.1
-numba             0.59.0
+numba             0.59.1
 numcodecs         0.12.1
 numpy             1.26.4
 numpy-groupies    0.10.2
@@ -26,7 +26,7 @@
 pytest            8.1.1
 pytest-cov        4.1.0
 pytest-memray     1.5.0
-pytest-mock       3.12.0
+pytest-mock       3.14.0
 pytest-xdist      3.5.0
 python-dateutil   2.9.0.post0
 pytz              2024.1
@@ -34,7 +34,7 @@
 rich              13.7.1
 scanpy            1.10.0rc2
 scikit-learn      1.4.1.post1
-scipy             1.12.0
+scipy             1.13.0rc1
 seaborn           0.13.2
 session-info      1.0.0
 setuptools        69.2.0
@@ -46,7 +46,7 @@
 stdlib-list       0.10.0
 tblib             3.0.0
 textual           0.53.1
-threadpoolctl     3.3.0
+threadpoolctl     3.4.0
 toolz             0.12.1
 tornado           6.4
 tqdm              4.66.2
@@ -55,7 +55,7 @@
 uc-micro-py       1.0.3
 umap-learn        0.5.5
 urllib3           2.2.1
-uv                0.1.22
+uv                0.1.23
 zarr              2.17.1
 zict              3.0.0

@ilan-gold ilan-gold merged commit 507444a into main Mar 22, 2024
15 checks passed
@ilan-gold ilan-gold deleted the ig/subsetting_dask_array branch March 22, 2024 15:43
@ivirshup
Copy link
Member

@ilan-gold, did you find or open a bug on dask for this?

@ilan-gold
Copy link
Contributor Author

@ivirshup is this a bug? I guess it's a performance issue, I can open an issue.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants