Backport PR scverse#2856: (feat): pre-processing functions for dask with sparse chunks
flying-sheep authored and meeseeksmachine committed Mar 22, 2024
1 parent 7c1e4cc commit 6f32147
Showing 13 changed files with 651 additions and 95 deletions.
1 change: 1 addition & 0 deletions docs/release-notes/1.10.0.md
@@ -27,6 +27,7 @@
* `scanpy.pp.calculate_qc_metrics` now allows `qc_vars` to be passed as a string {pr}`2859` {smaller}`N Teyssier`
* {func}`scanpy.tl.leiden` and {func}`scanpy.tl.louvain` now store clustering parameters in the key provided by the `key_added` parameter instead of always writing to (or overwriting) a default key {pr}`2864` {smaller}`J Fan`
* {func}`scanpy.pp.scale` now also clips `np.ndarray` inputs at `-max_value` when zero-centering {pr}`2913` {smaller}`S Dicks`
* Support sparse chunks in dask arrays for {func}`~scanpy.pp.scale`, {func}`~scanpy.pp.normalize_total` and {func}`~scanpy.pp.highly_variable_genes` (the `seurat` and `cell-ranger` flavors are tested) {pr}`2856` {smaller}`ilan-gold`
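
A rough usage sketch of what this enables (not part of the diff; the toy matrix, shapes, and parameter choices are illustrative and assume `dask` and `scipy` are installed):

```python
import dask.array as da
import numpy as np
import scanpy as sc
from scipy import sparse

# a dask array whose blocks are CSR matrices
X = da.from_array(
    sparse.random(1000, 200, density=0.1, format="csr", dtype=np.float64),
    chunks=(100, 200),
)
adata = sc.AnnData(X)

sc.pp.normalize_total(adata)  # per-cell scaling, evaluated lazily
sc.pp.highly_variable_genes(adata, flavor="seurat")
sc.pp.scale(adata, zero_center=False)  # zero-centering would densify sparse chunks
```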

```{rubric} Docs
```
223 changes: 222 additions & 1 deletion scanpy/_utils/__init__.py
@@ -15,9 +15,18 @@
from contextlib import contextmanager
from enum import Enum
from functools import partial, singledispatch, wraps
from operator import mul, truediv
from textwrap import dedent
from types import MethodType, ModuleType
from typing import TYPE_CHECKING, Any, Callable, Literal, Union
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Literal,
    TypeVar,
    Union,
    overload,
)
from weakref import WeakSet

import numpy as np
@@ -561,6 +570,218 @@ def _elem_mul_dask(x: DaskArray, y: DaskArray) -> DaskArray:
    return da.map_blocks(elem_mul, x, y)


Scaling_T = TypeVar("Scaling_T", DaskArray, np.ndarray)


def broadcast_axis(divisor: Scaling_T, axis: Literal[0, 1]) -> Scaling_T:
    divisor = np.ravel(divisor)
    if axis:
        return divisor[None, :]
    return divisor[:, None]
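

As a quick orientation (illustrative check, not part of the commit), the helper turns a 1-D divisor into a shape that broadcasts along the requested axis:

```python
import numpy as np

from scanpy._utils import broadcast_axis

v = np.array([1.0, 2.0, 3.0])
assert broadcast_axis(v, axis=0).shape == (3, 1)  # one factor per row
assert broadcast_axis(v, axis=1).shape == (1, 3)  # one factor per column
```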


def check_op(op):
    if op not in {truediv, mul}:
        raise ValueError(f"{op} not one of truediv or mul")


@singledispatch
def axis_mul_or_truediv(
    X: sparse.spmatrix,
    scaling_array,
    axis: Literal[0, 1],
    op: Callable[[Any, Any], Any],
    *,
    allow_divide_by_zero: bool = True,
    out: sparse.spmatrix | None = None,
) -> sparse.spmatrix:
    check_op(op)
    if out is not None:
        if X.data is not out.data:
            raise ValueError(
                "`out` was provided but does not share its data with `X`. "
                "This is not supported for sparse matrix scaling."
            )
    if not allow_divide_by_zero and op is truediv:
        scaling_array = scaling_array.copy() + (scaling_array == 0)

    row_scale = axis == 0
    column_scale = axis == 1
    if row_scale:

        def new_data_op(x):
            return op(x.data, np.repeat(scaling_array, np.diff(x.indptr)))

    elif column_scale:

        def new_data_op(x):
            return op(x.data, scaling_array.take(x.indices, mode="clip"))

    if X.format == "csr":
        indices = X.indices
        indptr = X.indptr
        if out is not None:
            X.data = new_data_op(X)
            return X
        return sparse.csr_matrix(
            (new_data_op(X), indices.copy(), indptr.copy()), shape=X.shape
        )
    transposed = X.T
    return axis_mul_or_truediv(
        transposed,
        scaling_array,
        op=op,
        axis=1 - axis,
        out=transposed,
        allow_divide_by_zero=allow_divide_by_zero,
    ).T
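

A minimal sketch of calling the sparse path (toy values; the import path follows this diff, which adds the function to `scanpy._utils`):

```python
from operator import truediv

import numpy as np
from scipy import sparse

from scanpy._utils import axis_mul_or_truediv

X = sparse.random(4, 3, density=0.5, format="csr")
row_sums = np.asarray(X.sum(axis=1)).ravel()
# divide every row by its sum; with allow_divide_by_zero=False,
# zero divisors are replaced by 1 before dividing
normalized = axis_mul_or_truediv(
    X, row_sums, axis=0, op=truediv, allow_divide_by_zero=False
)
```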


@axis_mul_or_truediv.register(np.ndarray)
def _(
    X: np.ndarray,
    scaling_array: np.ndarray,
    axis: Literal[0, 1],
    op: Callable[[Any, Any], Any],
    *,
    allow_divide_by_zero: bool = True,
    out: np.ndarray | None = None,
) -> np.ndarray:
    check_op(op)
    scaling_array = broadcast_axis(scaling_array, axis)
    if op is mul:
        return np.multiply(X, scaling_array, out=out)
    if not allow_divide_by_zero:
        scaling_array = scaling_array.copy() + (scaling_array == 0)
    return np.true_divide(X, scaling_array, out=out)


def make_axis_chunks(
    X: DaskArray, axis: Literal[0, 1], pad=True
) -> tuple[tuple[int, ...], tuple[int, ...]]:
    if axis == 0:
        return (X.chunks[axis], (1,))
    return ((1,), X.chunks[axis])


@axis_mul_or_truediv.register(DaskArray)
def _(
    X: DaskArray,
    scaling_array: Scaling_T,
    axis: Literal[0, 1],
    op: Callable[[Any, Any], Any],
    *,
    allow_divide_by_zero: bool = True,
    out: None = None,
) -> DaskArray:
    check_op(op)
    if out is not None:
        raise TypeError(
            "`out` is not `None`. Do not do in-place modifications on dask arrays."
        )

    import dask.array as da

    scaling_array = broadcast_axis(scaling_array, axis)
    row_scale = axis == 0
    column_scale = axis == 1

    if isinstance(scaling_array, DaskArray):
        if (row_scale and not X.chunksize[0] == scaling_array.chunksize[0]) or (
            column_scale
            and (
                (
                    len(scaling_array.chunksize) == 1
                    and X.chunksize[1] != scaling_array.chunksize[0]
                )
                or (
                    len(scaling_array.chunksize) == 2
                    and X.chunksize[1] != scaling_array.chunksize[1]
                )
            )
        ):
            warnings.warn("Rechunking scaling_array in user operation", UserWarning)
            scaling_array = scaling_array.rechunk(make_axis_chunks(X, axis))
    else:
        scaling_array = da.from_array(
            scaling_array,
            chunks=make_axis_chunks(X, axis),
        )
    return da.map_blocks(
        axis_mul_or_truediv,
        X,
        scaling_array,
        axis,
        op,
        meta=X._meta,
        out=out,
        allow_divide_by_zero=allow_divide_by_zero,
    )
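

And a sketch of the dask path (illustrative shapes): a plain 1-D scaling vector is wrapped in a dask array chunked to match `X`, and the dispatched function is mapped over blocks:

```python
from operator import mul

import dask.array as da
import numpy as np

from scanpy._utils import axis_mul_or_truediv

X = da.random.random((100, 20), chunks=(25, 20))
gene_weights = np.linspace(0.5, 1.5, num=20)

scaled = axis_mul_or_truediv(X, gene_weights, axis=1, op=mul)  # lazy
result = scaled.compute()  # (100, 20) ndarray; column j scaled by gene_weights[j]
```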


@overload
def axis_sum(
    X: sparse.spmatrix,
    *,
    axis: tuple[Literal[0, 1], ...] | Literal[0, 1] | None = None,
    dtype: np.typing.DTypeLike | None = None,
) -> np.matrix: ...


@singledispatch
def axis_sum(
    X: np.ndarray,
    *,
    axis: tuple[Literal[0, 1], ...] | Literal[0, 1] | None = None,
    dtype: np.typing.DTypeLike | None = None,
) -> np.ndarray:
    return np.sum(X, axis=axis, dtype=dtype)


@axis_sum.register(DaskArray)
def _(
    X: DaskArray,
    *,
    axis: tuple[Literal[0, 1], ...] | Literal[0, 1] | None = None,
    dtype: np.typing.DTypeLike | None = None,
) -> DaskArray:
    import dask.array as da

    if dtype is None:
        dtype = getattr(np.zeros(1, dtype=X.dtype).sum(), "dtype", object)

    if isinstance(X._meta, np.ndarray) and not isinstance(X._meta, np.matrix):
        return X.sum(axis=axis, dtype=dtype)

    def sum_drop_keepdims(*args, **kwargs):
        kwargs.pop("computing_meta", None)
        # masked operations on sparse blocks produce np.matrix results,
        # which have the same API issues handled here
        if isinstance(X._meta, (sparse.spmatrix, np.matrix)) or isinstance(
            args[0], (sparse.spmatrix, np.matrix)
        ):
            kwargs.pop("keepdims", None)
            axis = kwargs["axis"]
            if isinstance(axis, tuple):
                if len(axis) != 1:
                    raise ValueError(
                        f"`axis_sum` can only sum over one axis when `axis` is a tuple, but got {axis}"
                    )
                kwargs["axis"] = axis[0]
        # np.sum would normally return an np.matrix here, which is undesirable
        return np.array(np.sum(*args, dtype=dtype, **kwargs))

    def aggregate_sum(*args, **kwargs):
        return np.sum(args[0], dtype=dtype, **kwargs)

    return da.reduction(
        X,
        sum_drop_keepdims,
        aggregate_sum,
        axis=axis,
        dtype=dtype,
        meta=np.array([], dtype=dtype),
    )
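

A usage sketch for the sparse-chunked case (toy data; because the blocks are `csr_matrix`, this exercises the `da.reduction` path rather than plain `X.sum`):

```python
import dask.array as da
from scipy import sparse

from scanpy._utils import axis_sum

X = da.from_array(
    sparse.random(8, 6, density=0.3, format="csr"), chunks=(4, 6)
)
counts = axis_sum(X, axis=1)  # lazy 1-D DaskArray, not an np.matrix
print(counts.compute().shape)  # (8,)
```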


@singledispatch
def check_nonnegative_integers(X: _SupportedArray) -> bool | DaskArray:
"""Checks values of X to ensure it is count data"""
4 changes: 3 additions & 1 deletion scanpy/preprocessing/_distributed.py
@@ -31,9 +31,11 @@ def materialize_as_ndarray(


def materialize_as_ndarray(
    a: ArrayLike | tuple[ArrayLike | ZappyArray | DaskArray, ...],
    a: DaskArray | ArrayLike | tuple[ArrayLike | ZappyArray | DaskArray, ...],
) -> tuple[np.ndarray] | np.ndarray:
    """Compute distributed arrays and convert them to numpy ndarrays."""
    if isinstance(a, DaskArray):
        return a.compute()
    if not isinstance(a, tuple):
        return np.asarray(a)

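A sketch of the new branch (hypothetical toy array): a bare dask array is now computed directly instead of being handed to `np.asarray`:

```python
import dask.array as da

from scanpy.preprocessing._distributed import materialize_as_ndarray

mean = da.random.random(10, chunks=5)
arr = materialize_as_ndarray(mean)  # np.ndarray, via mean.compute()
```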
6 changes: 2 additions & 4 deletions scanpy/preprocessing/_highly_variable_genes.py
@@ -267,7 +267,7 @@ def _highly_variable_genes_single_batch(
    else:
        X = np.expm1(X)

    mean, var = _get_mean_var(X)
    mean, var = materialize_as_ndarray(_get_mean_var(X))
    # now actually compute the dispersion
    mean[mean == 0] = 1e-12  # replace zero entries with a small value
    dispersion = var / mean
@@ -277,9 +277,7 @@
        mean = np.log1p(mean)

    # all of the following quantities are "per-gene" here
    df = pd.DataFrame(
        dict(zip(["means", "dispersions"], materialize_as_ndarray((mean, dispersion))))
    )
    df = pd.DataFrame(dict(zip(["means", "dispersions"], (mean, dispersion))))
    df["mean_bin"] = _get_mean_bins(df["means"], flavor, n_bins)
    disp_stats = _get_disp_stats(df, flavor)
