(feat): pre-processing functions for dask with sparse chunks #2856

Merged: 96 commits, Mar 22, 2024 (diff shown from 88 commits).

Commits:
- `d3163cc` (chore): add dask sparse chunks creation (ilan-gold, Feb 27, 2024)
- `2a7a54c` (feat): add dask summation (ilan-gold, Feb 27, 2024)
- `8dd9a7a` (refactor): `materialize_as_ndarray` needs to operate on indidiual da… (ilan-gold, Feb 27, 2024)
- `d55b6a4` (feat): `filter_genes` and `filter_cells` (ilan-gold, Feb 27, 2024)
- `41a5f15` (feat): normalization (ilan-gold, Feb 27, 2024)
- `e36699e` (fix) `lop1p` tests working (ilan-gold, Feb 27, 2024)
- `da6eff0` (refactor): clean up writing test (ilan-gold, Feb 27, 2024)
- `63ca2f0` (refactor): us `da.chunk.sum` (ilan-gold, Feb 27, 2024)
- `fd22a19` (fix): remove `Client` (ilan-gold, Feb 27, 2024)
- `8b9a792` (refactor): remove unnecessary `count_nonzero` (ilan-gold, Feb 27, 2024)
- `1592571` (fix): change expected fail on sparse normalization (ilan-gold, Feb 27, 2024)
- `6ac32e5` (fix): update comment (ilan-gold, Feb 28, 2024)
- `78a3ab6` (feat): `_get_mean_var` dask (ilan-gold, Feb 28, 2024)
- `60bbdb8` (feat): clean up tests for what should/should not work (ilan-gold, Feb 28, 2024)
- `0c1f254` (refactor): `_compat.sum` to `_utils.elem_sum` (ilan-gold, Feb 28, 2024)
- `2f4d11a` (chore): add `elem_sum` test (ilan-gold, Feb 28, 2024)
- `12502e8` (refactor): `elem_sum` -> `axis_sum` (ilan-gold, Feb 29, 2024)
- `b3bb95a` (feat): add `scale` support (ilan-gold, Feb 29, 2024)
- `2b6f717` (fix): maintain dtype (ilan-gold, Feb 29, 2024)
- `448dc40` (chore): add back condition (ilan-gold, Feb 29, 2024)
- `7226bf0` (fix): use `sum` when needed (ilan-gold, Feb 29, 2024)
- `2bc7c3a` (chore): release notes (ilan-gold, Feb 29, 2024)
- `62c75fc` (fx): don't use `mean_func` name twice (ilan-gold, Feb 29, 2024)
- `b987a68` (chore): revert sparse-chunks-in-dask (ilan-gold, Feb 29, 2024)
- `902238a` (chore): type hint (ilan-gold, Feb 29, 2024)
- `0abbab5` (chore): check `test_compare_to_upstream` (ilan-gold, Feb 29, 2024)
- `a8606ae` (chore): remove comment (ilan-gold, Feb 29, 2024)
- `feac6bc` (chore): allow passing `dtype` arg in `axis_sum` (ilan-gold, Feb 29, 2024)
- `bcdeddb` (fix): revert fixture changes (ilan-gold, Feb 29, 2024)
- `4716d8f` (refactor): cleaner with `array_type` conversion before if-then (ilan-gold, Feb 29, 2024)
- `3912b63` (chore): clarify hvg support (ilan-gold, Feb 29, 2024)
- `c884c63` (chore): handle array types better (ilan-gold, Mar 1, 2024)
- `af351d4` (chore): clean up `materialize_as_ndarray` (ilan-gold, Mar 1, 2024)
- `da22953` (chore): fix typing/dispatch problem in 3.9 (ilan-gold, Mar 1, 2024)
- `dbbc6a2` (chore): `list` type -> `Callable` (ilan-gold, Mar 1, 2024)
- `6a4f0c5` (feat): `row_divide` for better division handling (ilan-gold, Mar 1, 2024)
- `c0182cb` (fix): use `tuple` for `ARRAY_TYPEXXXX` (ilan-gold, Mar 1, 2024)
- `a065a78` (refactor): `mean_func` -> `axis_mean` + types (ilan-gold, Mar 1, 2024)
- `c3ee138` (chore): remove unnecessary aggregation (ilan-gold, Mar 1, 2024)
- `c246a41` (fix): raise `ValueError` for summing over more than one axis (ilan-gold, Mar 1, 2024)
- `743c327` (fix): grammar (ilan-gold, Mar 1, 2024)
- `db88560` (fix): better type hints (ilan-gold, Mar 1, 2024)
- `2a5faa6` (revert): use old `test_normalize_total` siince we have `csr` (ilan-gold, Mar 1, 2024)
- `d6ceb4c` (revert): extraneous diff (ilan-gold, Mar 1, 2024)
- `48a1a1e` (fix): try `Union` (ilan-gold, Mar 1, 2024)
- `07fc5ba` (chore): add column division ability (ilan-gold, Mar 1, 2024)
- `3cc4be2` (chore): add scale test (ilan-gold, Mar 1, 2024)
- `4cc9eef` (fix): duplicate in release note (ilan-gold, Mar 1, 2024)
- `d8afe5c` (refactor): guard clause + comments (ilan-gold, Mar 1, 2024)
- `271d5d8` (chore): add `out` check for `dask` (ilan-gold, Mar 1, 2024)
- `c61324b` (chore): add `divisor` type hints (ilan-gold, Mar 1, 2024)
- `c688aff` (fix): remove some erroneous diffs (ilan-gold, Mar 1, 2024)
- `02be7a7` (chore): `axis_{sum,mean}` type hint fixes (ilan-gold, Mar 1, 2024)
- `6acc08c` (refactor): generalize to scaling (ilan-gold, Mar 4, 2024)
- `0944429` (chore): remove erroneous comment (ilan-gold, Mar 4, 2024)
- `3538572` (chore): remove non-public API (ilan-gold, Mar 4, 2024)
- `5ef1487` (fix): import from `sc._utils` (ilan-gold, Mar 4, 2024)
- `0f43362` (fix): `inidices` -> `indices` (ilan-gold, Mar 4, 2024)
- `c100a8f` (fix): remove erroneous `axis_sum` calls (ilan-gold, Mar 5, 2024)
- `22b4e90` (fix): return statements for `axis_scale` (ilan-gold, Mar 5, 2024)
- `4fef58e` (refactor): return out of `axis_sum` if `X._meta` is `np.ndarray` (ilan-gold, Mar 5, 2024)
- `ce574e3` (core): comment fix (ilan-gold, Mar 5, 2024)
- `e5a82fc` (fix): use `normalize_total` in HVG test for true reproducibility (ilan-gold, Mar 5, 2024)
- `a4e53a6` (refactor): separate out `out` test for dask (ilan-gold, Mar 6, 2024)
- `f0b2d97` (fix): correct chunking/rechunking behavior (ilan-gold, Mar 6, 2024)
- `f9ea93d` (chore): add guard clause for `sparse` `out != X != None` in scaling (ilan-gold, Mar 6, 2024)
- `66f04b6` (fix): guard clause condition (ilan-gold, Mar 6, 2024)
- `daca210` (fix): try finishing `|` typing for 3.9 (ilan-gold, Mar 6, 2024)
- `036391e` (fix): call `register` to allow unions? (ilan-gold, Mar 6, 2024)
- `cac4160` (fix): clarify warning (ilan-gold, Mar 6, 2024)
- `9ec6935` (feat): test for `max_value`/`zero_center` combos (ilan-gold, Mar 6, 2024)
- `0ae76ee` (fix): allow settings of `X` in `scale_array` (ilan-gold, Mar 6, 2024)
- `2367f46` (chore): add tests for `normalize` correctness (ilan-gold, Mar 6, 2024)
- `b2c3a96` (fix): refactor for pure dask in `median` (ilan-gold, Mar 6, 2024)
- `340894b` Merge branch 'main' into dask-sparse-mean-var (ilan-gold, Mar 6, 2024)
- `fa66f58` (refactor): add clarifying condition (ilan-gold, Mar 6, 2024)
- `2601fe8` Merge branch 'dask-sparse-mean-var' of github.com:scverse/scanpy into… (ilan-gold, Mar 6, 2024)
- `750af59` (chore): skip warning computations + tests (ilan-gold, Mar 6, 2024)
- `25fe1f9` (fix): actually skip computation in `normalize_total` condition (ilan-gold, Mar 6, 2024)
- `57c8389` (fix): actually skip in `filter_genes` + tests (ilan-gold, Mar 6, 2024)
- `69ebf98` (fix): use all-in-one median implemetation (ilan-gold, Mar 7, 2024)
- `67f47f4` (refactor): remove erreous dask warnings (ilan-gold, Mar 7, 2024)
- `e328eb5` (chore): add note about `exclude_highly_expressed` (ilan-gold, Mar 7, 2024)
- `0aafabd` (feat): `axis_scale` -> `axis_mul_or_truediv` (ilan-gold, Mar 7, 2024)
- `be988c9` (feat): `allow_divide_by_zero` (ilan-gold, Mar 7, 2024)
- `3166909` (chore): add notes + type hints (ilan-gold, Mar 7, 2024)
- `6552324` Have hvg compute earlier and only once (ivirshup, Mar 20, 2024)
- `936eb87` Merge branch 'main' into dask-sparse-mean-var (ilan-gold, Mar 21, 2024)
- `37eb1a0` [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Mar 21, 2024)
- `0cecb94` (refactor): make codecov better by removing dead code/refactoring (ilan-gold, Mar 22, 2024)
- `21faa0d` (fix): `np.clip` in dask does not take min/max as `kwargs` (ilan-gold, Mar 22, 2024)
- `5998ae8` Update docs/release-notes/1.11.0.md (ilan-gold, Mar 22, 2024)
- `f49b929` (chore): move release note (ilan-gold, Mar 22, 2024)
- `937c6db` Merge branch 'main' into dask-sparse-mean-var (ilan-gold, Mar 22, 2024)
- `ba445f8` (chore): remove erroneous comment (ilan-gold, Mar 22, 2024)
- `b3581ea` Merge branch 'dask-sparse-mean-var' of github.com:scverse/scanpy into… (ilan-gold, Mar 22, 2024)
docs/release-notes/1.11.0.md (1 addition, 0 deletions)

@@ -2,6 +2,7 @@

```{rubric} Features
```
* Support sparse chunks in dask {func}`~scanpy.pp.scale`, {func}`~scanpy.pp.filter_cells`, {func}`~scanpy.pp.filter_genes`, {func}`~scanpy.pp.normalize_total` and {func}`~scanpy.pp.highly_variable_genes` (`seurat` and `cell-ranger` tested) {pr}`2856` {smaller}`ilan-gold`

```{rubric} Docs
```
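
For a sense of what this PR enables end to end, here is a minimal sketch: an `AnnData` backed by a dask array whose chunks are CSR matrices, run through the preprocessing functions listed in the release note. The data, chunk sizes, and the `da.from_array`-over-CSR construction are illustrative assumptions, not taken from this diff.

```python
import dask.array as da
import scanpy as sc
from anndata import AnnData
from scipy import sparse

# A dask array with sparse (CSR) chunks: rows chunked, columns kept whole.
counts = sparse.random(400, 200, density=0.1, format="csr", random_state=0)
counts.data *= 100  # pseudo count-like values
X = da.from_array(counts, chunks=(100, 200))

adata = AnnData(X)
sc.pp.filter_cells(adata, min_genes=1)
sc.pp.filter_genes(adata, min_cells=1)
sc.pp.normalize_total(adata)
sc.pp.log1p(adata)
sc.pp.highly_variable_genes(adata, flavor="seurat")
```
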
scanpy/_utils/__init__.py (226 changes: 225 additions, 1 deletion)

@@ -15,9 +15,18 @@
 from contextlib import contextmanager
 from enum import Enum
 from functools import partial, singledispatch, wraps
+from operator import mul, truediv
 from textwrap import dedent
 from types import MethodType, ModuleType
-from typing import TYPE_CHECKING, Any, Callable, Literal, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Literal,
+    TypeVar,
+    Union,
+    overload,
+)
 from weakref import WeakSet

 import numpy as np
@@ -561,6 +570,221 @@ def _elem_mul_dask(x: DaskArray, y: DaskArray) -> DaskArray:
    return da.map_blocks(elem_mul, x, y)


Scaling_T = TypeVar("Scaling_T", DaskArray, np.ndarray)


def broadcast_axis(divisor: Scaling_T, axis: Literal[0, 1]) -> Scaling_T:
    divisor = np.ravel(divisor)
    if axis:
        return divisor[None, :]
    return divisor[:, None]


@singledispatch
def axis_mul_or_truediv(
    X: sparse.spmatrix,
    scaling_array,
    axis: Literal[0, 1],
    op: Callable[[Any, Any], Any],
    *,
    allow_divide_by_zero: bool = True,
    out: sparse.spmatrix | None = None,
) -> sparse.spmatrix:
    if op not in {truediv, mul}:
        raise ValueError(f"{op} not one of truediv or mul")
    if out is not None:
        if X.data is not out.data:
            raise ValueError(
                "`out` argument provided but not equal to X. This behavior is not supported for sparse matrix scaling."
            )
    if not allow_divide_by_zero and op is truediv:
        scaling_array = scaling_array.copy() + (scaling_array == 0)

    row_scale = axis == 0
    column_scale = axis == 1
    if row_scale:

        def new_data_op(x):
            return op(x.data, np.repeat(scaling_array, np.diff(x.indptr)))

    elif column_scale:

        def new_data_op(x):
            return op(x.data, scaling_array.take(x.indices, mode="clip"))

    if X.format == "csr":
        indices = X.indices
        indptr = X.indptr
        if out is not None:
            X.data = new_data_op(X)
            return X
        return sparse.csr_matrix(
            (new_data_op(X), indices.copy(), indptr.copy()), shape=X.shape
        )
    transposed = X.T
    return axis_mul_or_truediv(
        transposed,
        scaling_array,
        op=op,
        axis=1 - axis,
        out=transposed,
        allow_divide_by_zero=allow_divide_by_zero,
    ).T


@axis_mul_or_truediv.register(np.ndarray)
def _(
    X: np.ndarray,
    scaling_array: np.ndarray,
    axis: Literal[0, 1],
    op: Callable[[Any, Any], Any],
    *,
    allow_divide_by_zero: bool = True,
    out: np.ndarray | None = None,
) -> np.ndarray:
    if op not in {truediv, mul}:
        raise ValueError(f"{op} not one of truediv or mul")
    scaling_array = broadcast_axis(scaling_array, axis)
    if op is mul:
        return np.multiply(X, scaling_array, out=out)
    if not allow_divide_by_zero:
        scaling_array = scaling_array.copy() + (scaling_array == 0)
    return np.true_divide(X, scaling_array, out=out)
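
A small usage sketch of the helper defined above, dividing each row of a CSR matrix by its row sum (illustrative, not part of the diff):

```python
from operator import truediv

import numpy as np
from scipy import sparse

from scanpy._utils import axis_mul_or_truediv

X = sparse.csr_matrix(np.array([[1.0, 2.0], [3.0, 4.0]]))
row_sums = np.asarray(X.sum(axis=1)).ravel()  # [3., 7.]

# axis=0 scales per row; truediv divides each row by its scaling value
Y = axis_mul_or_truediv(X, row_sums, 0, truediv)
print(Y.toarray())  # [[0.3333 0.6667], [0.4286 0.5714]] -- rows now sum to 1
```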


def make_axis_chunks(X: DaskArray, axis: Literal[0, 1], pad=True) -> tuple[tuple[int]]:
    if axis == 0:
        if pad:
            return (X.chunks[axis], (1,))
        return X.chunks[axis]
    if pad:
        return ((1,), X.chunks[axis])
    return X.chunks[axis]


@axis_mul_or_truediv.register(DaskArray)
def _(
    X: DaskArray,
    scaling_array: Scaling_T,
    axis: Literal[0, 1],
    op: Callable[[Any, Any], Any],
    *,
    allow_divide_by_zero: bool = True,
    out: None = None,
) -> DaskArray:
    if op not in {truediv, mul}:
        raise ValueError(f"{op} not one of truediv or mul")
    if out is not None:
        raise TypeError(
            "`out` is not `None`. Do not do in-place modifications on dask arrays."
        )

    import dask.array as da

    scaling_array = broadcast_axis(scaling_array, axis)
    row_scale = axis == 0
    column_scale = axis == 1

    if isinstance(scaling_array, DaskArray):
        if (row_scale and not X.chunksize[0] == scaling_array.chunksize[0]) or (
            column_scale
            and (
                (
                    len(scaling_array.chunksize) == 1
                    and X.chunksize[1] != scaling_array.chunksize[0]
                )
                or (
                    len(scaling_array.chunksize) == 2
                    and X.chunksize[1] != scaling_array.chunksize[1]
                )
            )
        ):
            warnings.warn("Rechunking scaling_array in user operation", UserWarning)
            scaling_array = scaling_array.rechunk(
                make_axis_chunks(X, axis, pad=len(scaling_array.shape) == 2)
            )
    else:
        scaling_array = da.from_array(
            scaling_array,
            chunks=make_axis_chunks(X, axis, pad=len(scaling_array.shape) == 2),
        )
    return da.map_blocks(
        axis_mul_or_truediv,
        X,
        scaling_array,
        axis,
        op,
        meta=X._meta,
        out=out,
        allow_divide_by_zero=allow_divide_by_zero,
    )
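
A sketch of the chunk-alignment path above: when a dask `scaling_array`'s chunks disagree with `X`'s along the scaled axis, it is rechunked with a warning. Dense chunks are used here for brevity; values are illustrative:

```python
import warnings
from operator import mul

import dask.array as da
import numpy as np

from scanpy._utils import axis_mul_or_truediv

X = da.ones((6, 4), chunks=(2, 4))
s = da.from_array(np.arange(6.0), chunks=3)  # row chunks (3, 3) vs X's (2, 2, 2)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    Y = axis_mul_or_truediv(X, s, 0, mul)  # multiply row i by i
print(caught[0].message)  # Rechunking scaling_array in user operation
print(Y.compute()[:, 0])  # [0. 1. 2. 3. 4. 5.]
```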


@overload
def axis_sum(
    X: sparse.spmatrix,
    *,
    axis: tuple[Literal[0, 1], ...] | Literal[0, 1] | None = None,
    dtype: np.typing.DTypeLike | None = None,
) -> np.matrix:
    ...


@singledispatch
def axis_sum(
    X: np.ndarray,
    *,
    axis: tuple[Literal[0, 1], ...] | Literal[0, 1] | None = None,
    dtype: np.typing.DTypeLike | None = None,
) -> np.ndarray:
    return np.sum(X, axis=axis, dtype=dtype)


@axis_sum.register(DaskArray)
def _(
    X: DaskArray,
    *,
    axis: tuple[Literal[0, 1], ...] | Literal[0, 1] | None = None,
    dtype: np.typing.DTypeLike | None = None,
) -> DaskArray:
    import dask.array as da

    if dtype is None:
        dtype = getattr(np.zeros(1, dtype=X.dtype).sum(), "dtype", object)

    if isinstance(X._meta, np.ndarray) and not isinstance(X._meta, np.matrix):
        return X.sum(axis=axis, dtype=dtype)

    def sum_drop_keepdims(*args, **kwargs):
        kwargs.pop("computing_meta", None)
> Review comment from ilan-gold (Feb 29, 2024), on the `computing_meta` pop: This one I am not so sure about. It doesn't seem to have an impact, and I'm not sure it's used, judging by a search of the dask docs: https://docs.dask.org/en/stable/search.html?q=computing_meta

        # masked operations on sparse produce np.matrix results, which give the same API issues handled here
        if isinstance(X._meta, (sparse.spmatrix, np.matrix)) or isinstance(
            args[0], (sparse.spmatrix, np.matrix)
        ):
            kwargs.pop("keepdims", None)
> Review comment: This is definitely something we don't want to run with the sparse matrices, and probably (due to interop) not with the dense either.

            axis = kwargs["axis"]
            if isinstance(axis, tuple):
                if len(axis) != 1:
                    raise ValueError(
                        f"`axis_sum` can only sum over one axis when `axis` arg is provided but got {axis} instead"
                    )
                kwargs["axis"] = axis[0]
        # returns an np.matrix normally, which is undesirable
        return np.array(np.sum(*args, dtype=dtype, **kwargs))

    def aggregate_sum(*args, **kwargs):
        return np.sum(args[0], dtype=dtype, **kwargs)

    return da.reduction(
        X,
        sum_drop_keepdims,
        aggregate_sum,
        axis=axis,
        dtype=dtype,
        meta=np.array([], dtype=dtype),
    )
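
The `np.matrix` quirks that `sum_drop_keepdims` works around, illustrated outside the diff (a hedged sketch of current numpy/scipy behavior, not exhaustive):

```python
import numpy as np
from scipy import sparse

m = sparse.csr_matrix(np.arange(6.0).reshape(2, 3))

s = m.sum(axis=0)          # sparse sums return np.matrix, which is always 2-D
print(type(s), s.shape)    # <class 'numpy.matrix'> (1, 3)
print(np.array(s).shape)   # wrapping in np.array() restores plain ndarray semantics

try:
    np.sum(m, axis=0, keepdims=True)  # scipy's .sum() signature has no `keepdims`
except TypeError as err:
    print(type(err).__name__, err)
```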


@singledispatch
def check_nonnegative_integers(X: _SupportedArray) -> bool | DaskArray:
"""Checks values of X to ensure it is count data"""
scanpy/preprocessing/_distributed.py (4 changes: 3 additions, 1 deletion)

@@ -31,9 +31,11 @@ def materialize_as_ndarray(


 def materialize_as_ndarray(
-    a: ArrayLike | tuple[ArrayLike | ZappyArray | DaskArray, ...],
+    a: DaskArray | ArrayLike | tuple[ArrayLike | ZappyArray | DaskArray, ...],
 ) -> tuple[np.ndarray] | np.ndarray:
     """Compute distributed arrays and convert them to numpy ndarrays."""
+    if isinstance(a, DaskArray):
+        return a.compute()
     if not isinstance(a, tuple):
         return np.asarray(a)

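A sketch of the new dispatch branch (illustrative; `materialize_as_ndarray` is private API, and the tuple behavior shown relies on the unchanged code below the fold):

```python
import dask.array as da

from scanpy.preprocessing._distributed import materialize_as_ndarray

d = da.ones((4, 3), chunks=(2, 3))

out = materialize_as_ndarray(d)  # new branch: a DaskArray is computed directly
print(type(out), out.shape)      # <class 'numpy.ndarray'> (4, 3)

mean, var = materialize_as_ndarray((d.mean(axis=0), d.var(axis=0)))
print(type(mean), type(var))     # both plain ndarrays
```
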
scanpy/preprocessing/_highly_variable_genes.py (6 changes: 2 additions, 4 deletions)

@@ -267,7 +267,7 @@ def _highly_variable_genes_single_batch(
     else:
         X = np.expm1(X)

-    mean, var = _get_mean_var(X)
+    mean, var = materialize_as_ndarray(_get_mean_var(X))
     # now actually compute the dispersion
     mean[mean == 0] = 1e-12  # set entries equal to zero to small value
     dispersion = var / mean
@@ -277,9 +277,7 @@
         mean = np.log1p(mean)

     # all of the following quantities are "per-gene" here
-    df = pd.DataFrame(
-        dict(zip(["means", "dispersions"], materialize_as_ndarray((mean, dispersion))))
-    )
+    df = pd.DataFrame(dict(zip(["means", "dispersions"], (mean, dispersion))))
     df["mean_bin"] = _get_mean_bins(df["means"], flavor, n_bins)
     disp_stats = _get_disp_stats(df, flavor)
