(feat): pre-processing functions for dask with sparse chunks (#2856)
* (chore): add dask sparse chunks creation

* (feat): add dask summation

* (refactor): `materialize_as_ndarray` needs to operate on individual dask arrays

* (feat): `filter_genes` and `filter_cells`

* (feat): normalization

* (fix): `log1p` tests working

* (refactor): clean up writing test

* (refactor): use `da.chunk.sum`

* (fix): remove `Client`

* (refactor): remove unnecessary `count_nonzero`

* (fix): change expected fail on sparse normalization

* (fix): update comment

* (feat): `_get_mean_var` dask

* (feat): clean up tests for what should/should not work

* (refactor): `_compat.sum` to `_utils.elem_sum`

* (chore): add `elem_sum` test

* (refactor): `elem_sum` -> `axis_sum`

* (feat): add `scale` support

* (fix): maintain dtype

* (chore): add back condition

* (fix): use `sum` when needed

* (chore): release notes

* (fix): don't use `mean_func` name twice

* (chore): revert sparse-chunks-in-dask

* (chore): type hint

* (chore): check `test_compare_to_upstream`

* (chore): remove comment

* (chore): allow passing `dtype` arg in `axis_sum`

* (fix): revert fixture changes

* (refactor): cleaner with `array_type` conversion before if-then

* (chore): clarify hvg support

* (chore): handle array types better

* (chore): clean up `materialize_as_ndarray`

* (chore): fix typing/dispatch problem in 3.9

* (chore): `list` type -> `Callable`

* (feat): `row_divide` for better division handling

* (fix): use `tuple` for `ARRAY_TYPEXXXX`

* (refactor): `mean_func` -> `axis_mean` + types

* (chore): remove unnecessary aggregation

* (fix): raise `ValueError` for summing over more than one axis

* (fix): grammar

* (fix): better type hints

* (revert): use old `test_normalize_total` since we have `csr`

* (revert): extraneous diff

* (fix): try `Union`

* (chore): add column division ability

* (chore): add scale test

* (fix): duplicate in release note

* (refactor): guard clause + comments

* (chore): add `out` check for `dask`

* (chore): add `divisor` type hints

* (fix): remove some erroneous diffs

* (chore): `axis_{sum,mean}` type hint fixes

* (refactor): generalize to scaling

* (chore): remove erroneous comment

* (chore): remove non-public API

* (fix): import from `sc._utils`

* (fix): `inidices` -> `indices`

* (fix): remove erroneous `axis_sum` calls

* (fix): return statements for `axis_scale`

* (refactor): return out of `axis_sum` if `X._meta` is `np.ndarray`

* (chore): comment fix

* (fix): use `normalize_total` in HVG test for true reproducibility

* (refactor): separate out `out` test for dask

* (fix): correct chunking/rechunking behavior

* (chore): add guard clause for `sparse` `out != X != None` in scaling

* (fix): guard clause condition

* (fix): try finishing `|` typing for 3.9

* (fix): call `register` to allow unions?

* (fix): clarify warning

* (feat): test for `max_value`/`zero_center` combos

* (fix): allow settings of `X` in `scale_array`

* (chore): add tests for `normalize` correctness

* (fix): refactor for pure dask in `median`

* (refactor): add clarifying condition

* (chore): skip warning computations + tests

* (fix): actually skip computation in `normalize_total` condition

* (fix): actually skip in `filter_genes` + tests

* (fix): use all-in-one median implementation

* (refactor): remove erroneous dask warnings

* (chore): add note about `exclude_highly_expressed`

* (feat): `axis_scale` -> `axis_mul_or_truediv`

* (feat): `allow_divide_by_zero`

* (chore): add notes + type hints

* Have hvg compute earlier and only once

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* (refactor): make codecov better by removing dead code/refactoring

* (fix): `np.clip` in dask does not take min/max as `kwargs`

* Update docs/release-notes/1.11.0.md

Co-authored-by: Isaac Virshup

* (chore): move release note

* (chore): remove erroneous comment

---------

Co-authored-by: ilan-gold
Co-authored-by: Isaac Virshup
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
4 people authored Mar 22, 2024
1 parent 921fcca commit 4b757d8
Showing 13 changed files with 651 additions and 95 deletions.
1 change: 1 addition & 0 deletions docs/release-notes/1.10.0.md
@@ -27,6 +27,7 @@
 * `scanpy.pp.calculate_qc_metrics` now allows `qc_vars` to be passed as a string {pr}`2859` {smaller}`N Teyssier`
 * {func}`scanpy.tl.leiden` and {func}`scanpy.tl.louvain` now store clustering parameters in the key provided by the `key_added` parameter instead of always writing to (or overwriting) a default key {pr}`2864` {smaller}`J Fan`
 * {func}`scanpy.pp.scale` now clips `np.ndarray` also at `- max_value` for zero-centering {pr}`2913` {smaller}`S Dicks`
+* Support sparse chunks in dask {func}`~scanpy.pp.scale`, {func}`~scanpy.pp.normalize_total` and {func}`~scanpy.pp.highly_variable_genes` (`seurat` and `cell-ranger` tested) {pr}`2856` {smaller}`ilan-gold`

 ```{rubric} Docs
 ```
223 changes: 222 additions & 1 deletion scanpy/_utils/__init__.py
@@ -15,9 +15,18 @@
 from contextlib import contextmanager
 from enum import Enum
 from functools import partial, singledispatch, wraps
+from operator import mul, truediv
 from textwrap import dedent
 from types import MethodType, ModuleType
-from typing import TYPE_CHECKING, Any, Callable, Literal, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Literal,
+    TypeVar,
+    Union,
+    overload,
+)
 from weakref import WeakSet

 import numpy as np
@@ -542,6 +551,218 @@ def _elem_mul_dask(x: DaskArray, y: DaskArray) -> DaskArray:
     return da.map_blocks(elem_mul, x, y)


+Scaling_T = TypeVar("Scaling_T", DaskArray, np.ndarray)


+def broadcast_axis(divisor: Scaling_T, axis: Literal[0, 1]) -> Scaling_T:
+    divisor = np.ravel(divisor)
+    if axis:
+        return divisor[None, :]
+    return divisor[:, None]
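
Review note: a rough NumPy-only illustration of what `broadcast_axis` appears to be for. The helper name `broadcast_axis_sketch` and the sample values are made up for this note and are not part of the commit. A flat vector of per-row or per-column factors is reshaped so that NumPy broadcasting lines it up with the intended axis.

```python
import numpy as np


def broadcast_axis_sketch(divisor, axis):
    # Hypothetical stand-in that mirrors the helper added in this diff.
    divisor = np.ravel(divisor)
    if axis:
        return divisor[None, :]  # shape (1, n_cols): one factor per column
    return divisor[:, None]  # shape (n_rows, 1): one factor per row


X = np.arange(6, dtype=float).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]
row_factors = np.array([1.0, 2.0])
col_factors = np.array([1.0, 2.0, 4.0])

# Each row is divided by its own factor.
rows_scaled = X / broadcast_axis_sketch(row_factors, 0)
# Each column is divided by its own factor.
cols_scaled = X / broadcast_axis_sketch(col_factors, 1)
```

With `axis=0` the factors vary down the rows, and with `axis=1` they vary across the columns, which matches how `axis` is used in the dense `axis_mul_or_truediv` implementation in this diff.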


+def check_op(op):
+    if op not in {truediv, mul}:
+        raise ValueError(f"{op} not one of truediv or mul")


+@singledispatch
+def axis_mul_or_truediv(
+    X: sparse.spmatrix,
+    scaling_array,
+    axis: Literal[0, 1],
+    op: Callable[[Any, Any], Any],
+    *,
+    allow_divide_by_zero: bool = True,
+    out: sparse.spmatrix | None = None,
+) -> sparse.spmatrix:
+    check_op(op)
+    if out is not None:
+        if X.data is not out.data:
+            raise ValueError(
+                "`out` argument provided but not equal to X. This behavior is not supported for sparse matrix scaling."
+            )
+    if not allow_divide_by_zero and op is truediv:
+        scaling_array = scaling_array.copy() + (scaling_array == 0)

+    row_scale = axis == 0
+    column_scale = axis == 1
+    if row_scale:

+        def new_data_op(x):
+            return op(x.data, np.repeat(scaling_array, np.diff(x.indptr)))

+    elif column_scale:

+        def new_data_op(x):
+            return op(x.data, scaling_array.take(x.indices, mode="clip"))

+    if X.format == "csr":
+        indices = X.indices
+        indptr = X.indptr
+        if out is not None:
+            X.data = new_data_op(X)
+            return X
+        return sparse.csr_matrix(
+            (new_data_op(X), indices.copy(), indptr.copy()), shape=X.shape
+        )
+    transposed = X.T
+    return axis_mul_or_truediv(
+        transposed,
+        scaling_array,
+        op=op,
+        axis=1 - axis,
+        out=transposed,
+        allow_divide_by_zero=allow_divide_by_zero,
+    ).T
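
Review note: the two `new_data_op` closures above only touch the stored values of a CSR matrix. The sketch below tries to show that idea with plain NumPy arrays standing in for the CSR `data`, `indices` and `indptr` attributes. The 3x3 matrix, the factor vectors and the `to_dense` helper are assumptions made up for this note, not code from the commit.

```python
import numpy as np

# CSR pieces for the 3x3 matrix
# [[1, 0, 2],
#  [0, 0, 3],
#  [4, 5, 0]]
data = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
indices = np.array([0, 2, 2, 0, 1])  # column index of each stored value
indptr = np.array([0, 2, 3, 5])  # row i owns data[indptr[i]:indptr[i + 1]]

row_factors = np.array([1.0, 3.0, 2.0])
col_factors = np.array([2.0, 5.0, 1.0])

# axis=0 in the diff: repeat each row's factor once per stored value in that row.
row_scaled = data / np.repeat(row_factors, np.diff(indptr))

# axis=1 in the diff: look up each stored value's column factor via `indices`.
col_scaled = data / col_factors.take(indices, mode="clip")


def to_dense(values):
    # Checking helper (hypothetical): expand the CSR arrays to a dense matrix.
    dense = np.zeros((3, 3))
    rows = np.repeat(np.arange(3), np.diff(indptr))
    dense[rows, indices] = values
    return dense
```

Because implicit zeros stay zero under multiplication or division by a non-zero factor, scaling only `data` gives the same result as scaling the dense matrix, without changing the sparsity pattern.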


+@axis_mul_or_truediv.register(np.ndarray)
+def _(
+    X: np.ndarray,
+    scaling_array: np.ndarray,
+    axis: Literal[0, 1],
+    op: Callable[[Any, Any], Any],
+    *,
+    allow_divide_by_zero: bool = True,
+    out: np.ndarray | None = None,
+) -> np.ndarray:
+    check_op(op)
+    scaling_array = broadcast_axis(scaling_array, axis)
+    if op is mul:
+        return np.multiply(X, scaling_array, out=out)
+    if not allow_divide_by_zero:
+        scaling_array = scaling_array.copy() + (scaling_array == 0)
+    return np.true_divide(X, scaling_array, out=out)
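
Review note: the `scaling_array.copy() + (scaling_array == 0)` expression is the part that implements `allow_divide_by_zero=False`. A minimal NumPy sketch of that trick follows; the `counts` values are invented for illustration.

```python
import numpy as np

counts = np.array([[2.0, 4.0], [0.0, 0.0], [3.0, 9.0]])
totals = counts.sum(axis=1)  # the all-zero row gives a total of 0

# `totals == 0` is a boolean mask; adding it turns every 0 into 1
# and leaves non-zero entries unchanged.
safe_totals = totals + (totals == 0)

# Rows with a zero total are divided by 1 and therefore stay all-zero.
normalized = counts / safe_totals[:, None]
```

The effect is that empty rows (or columns) pass through unchanged instead of producing `nan` or `inf`.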


+def make_axis_chunks(
+    X: DaskArray, axis: Literal[0, 1], pad=True
+) -> tuple[tuple[int], tuple[int]]:
+    if axis == 0:
+        return (X.chunks[axis], (1,))
+    return ((1,), X.chunks[axis])


+@axis_mul_or_truediv.register(DaskArray)
+def _(
+    X: DaskArray,
+    scaling_array: Scaling_T,
+    axis: Literal[0, 1],
+    op: Callable[[Any, Any], Any],
+    *,
+    allow_divide_by_zero: bool = True,
+    out: None = None,
+) -> DaskArray:
+    check_op(op)
+    if out is not None:
+        raise TypeError(
+            "`out` is not `None`. Do not do in-place modifications on dask arrays."
+        )

+    import dask.array as da

+    scaling_array = broadcast_axis(scaling_array, axis)
+    row_scale = axis == 0
+    column_scale = axis == 1

+    if isinstance(scaling_array, DaskArray):
+        if (row_scale and not X.chunksize[0] == scaling_array.chunksize[0]) or (
+            column_scale
+            and (
+                (
+                    len(scaling_array.chunksize) == 1
+                    and X.chunksize[1] != scaling_array.chunksize[0]
+                )
+                or (
+                    len(scaling_array.chunksize) == 2
+                    and X.chunksize[1] != scaling_array.chunksize[1]
+                )
+            )
+        ):
+            warnings.warn("Rechunking scaling_array in user operation", UserWarning)
+            scaling_array = scaling_array.rechunk(make_axis_chunks(X, axis))
+    else:
+        scaling_array = da.from_array(
+            scaling_array,
+            chunks=make_axis_chunks(X, axis),
+        )
+    return da.map_blocks(
+        axis_mul_or_truediv,
+        X,
+        scaling_array,
+        axis,
+        op,
+        meta=X._meta,
+        out=out,
+        allow_divide_by_zero=allow_divide_by_zero,
+    )


+@overload
+def axis_sum(
+    X: sparse.spmatrix,
+    *,
+    axis: tuple[Literal[0, 1], ...] | Literal[0, 1] | None = None,
+    dtype: np.typing.DTypeLike | None = None,
+) -> np.matrix: ...


+@singledispatch
+def axis_sum(
+    X: np.ndarray,
+    *,
+    axis: tuple[Literal[0, 1], ...] | Literal[0, 1] | None = None,
+    dtype: np.typing.DTypeLike | None = None,
+) -> np.ndarray:
+    return np.sum(X, axis=axis, dtype=dtype)


+@axis_sum.register(DaskArray)
+def _(
+    X: DaskArray,
+    *,
+    axis: tuple[Literal[0, 1], ...] | Literal[0, 1] | None = None,
+    dtype: np.typing.DTypeLike | None = None,
+) -> DaskArray:
+    import dask.array as da

+    if dtype is None:
+        dtype = getattr(np.zeros(1, dtype=X.dtype).sum(), "dtype", object)

+    if isinstance(X._meta, np.ndarray) and not isinstance(X._meta, np.matrix):
+        return X.sum(axis=axis, dtype=dtype)

+    def sum_drop_keepdims(*args, **kwargs):
+        kwargs.pop("computing_meta", None)
+        # masked operations on sparse produce which numpy matrices gives the same API issues handled here
+        if isinstance(X._meta, (sparse.spmatrix, np.matrix)) or isinstance(
+            args[0], (sparse.spmatrix, np.matrix)
+        ):
+            kwargs.pop("keepdims", None)
+            axis = kwargs["axis"]
+            if isinstance(axis, tuple):
+                if len(axis) != 1:
+                    raise ValueError(
+                        f"`axis_sum` can only sum over one axis when `axis` arg is provided but got {axis} instead"
+                    )
+                kwargs["axis"] = axis[0]
+        # returns a np.matrix normally, which is undesireable
+        return np.array(np.sum(*args, dtype=dtype, **kwargs))

+    def aggregate_sum(*args, **kwargs):
+        return np.sum(args[0], dtype=dtype, **kwargs)

+    return da.reduction(
+        X,
+        sum_drop_keepdims,
+        aggregate_sum,
+        axis=axis,
+        dtype=dtype,
+        meta=np.array([], dtype=dtype),
+    )
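
Review note: `da.reduction` is given a per-chunk function (`sum_drop_keepdims`) and an aggregate function (`aggregate_sum`). The sketch below imitates that two-stage pattern with plain NumPy, assuming an array split into three row chunks. The array, the chunking and the explicit `float64` accumulator are choices made for this illustration only; they are not taken from the commit.

```python
import numpy as np

X = np.arange(24, dtype=np.float32).reshape(6, 4)

# Pretend X is stored as three row chunks of shape (2, 4).
chunks = np.split(X, 3, axis=0)

# Stage 1 (per chunk): sum over axis 0, keeping a 2D shape so results can be stacked.
partials = [np.sum(c, axis=0, keepdims=True, dtype=np.float64) for c in chunks]

# Stage 2 (aggregate): combine the partial sums along the same axis.
total = np.sum(np.concatenate(partials, axis=0), axis=0, dtype=np.float64)
```

Each chunk contributes a small dense partial result, so the aggregate step works on regular `ndarray`s. This seems to be why the chunk function in the diff converts its result with `np.array(...)`.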


 @singledispatch
 def check_nonnegative_integers(X: _SupportedArray) -> bool | DaskArray:
     """Checks values of X to ensure it is count data"""
4 changes: 3 additions & 1 deletion scanpy/preprocessing/_distributed.py
@@ -31,9 +31,11 @@ def materialize_as_ndarray(


 def materialize_as_ndarray(
-    a: ArrayLike | tuple[ArrayLike | ZappyArray | DaskArray, ...],
+    a: DaskArray | ArrayLike | tuple[ArrayLike | ZappyArray | DaskArray, ...],
 ) -> tuple[np.ndarray] | np.ndarray:
     """Compute distributed arrays and convert them to numpy ndarrays."""
+    if isinstance(a, DaskArray):
+        return a.compute()
     if not isinstance(a, tuple):
         return np.asarray(a)

6 changes: 2 additions & 4 deletions scanpy/preprocessing/_highly_variable_genes.py
@@ -267,7 +267,7 @@ def _highly_variable_genes_single_batch(
         else:
             X = np.expm1(X)

-    mean, var = _get_mean_var(X)
+    mean, var = materialize_as_ndarray(_get_mean_var(X))
     # now actually compute the dispersion
     mean[mean == 0] = 1e-12  # set entries equal to zero to small value
     dispersion = var / mean
@@ -277,9 +277,7 @@ def _highly_variable_genes_single_batch(
         mean = np.log1p(mean)

     # all of the following quantities are "per-gene" here
-    df = pd.DataFrame(
-        dict(zip(["means", "dispersions"], materialize_as_ndarray((mean, dispersion))))
-    )
+    df = pd.DataFrame(dict(zip(["means", "dispersions"], (mean, dispersion))))
     df["mean_bin"] = _get_mean_bins(df["means"], flavor, n_bins)
     disp_stats = _get_disp_stats(df, flavor)
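
Review note: the hunks above materialize the per-gene mean and variance once, so the later steps run on in-memory arrays. A small NumPy sketch of the dispersion steps visible in the context lines follows. The input matrix is invented, and using the sample variance (`ddof=1`) is an assumption about what `_get_mean_var` returns.

```python
import numpy as np

# Three cells by three genes; the middle gene has no counts at all.
X = np.array([[1.0, 0.0, 4.0], [3.0, 0.0, 4.0], [5.0, 0.0, 4.0]])

mean = X.mean(axis=0)
var = X.var(axis=0, ddof=1)  # assumption: sample variance

mean[mean == 0] = 1e-12  # avoid 0 / 0 for genes with no counts
dispersion = var / mean

# `seurat` flavor branch from the diff context: log-transform, marking zero dispersion as missing.
dispersion[dispersion == 0] = np.nan
dispersion = np.log(dispersion)
mean = np.log1p(mean)
```

These in-place assignments (`mean[mean == 0] = ...`) are straightforward on NumPy arrays, which may be part of the reason the result of `_get_mean_var` is computed before this point.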
