
(feat): pre-processing functions for dask with sparse chunks #2856

Merged
merged 96 commits into main from dask-sparse-mean-var on Mar 22, 2024
Changes from 86 commits
Commits
d3163cc
(chore): add dask sparse chunks creation
ilan-gold Feb 27, 2024
2a7a54c
(feat): add dask summation
ilan-gold Feb 27, 2024
8dd9a7a
(refactor): `materialize_as_ndarray` needs to operate on indidiual da…
ilan-gold Feb 27, 2024
d55b6a4
(feat): `filter_genes` and `filter_cells`
ilan-gold Feb 27, 2024
41a5f15
(feat): normalization
ilan-gold Feb 27, 2024
e36699e
(fix) `lop1p` tests working
ilan-gold Feb 27, 2024
da6eff0
(refactor): clean up writing test
ilan-gold Feb 27, 2024
63ca2f0
(refactor): us `da.chunk.sum`
ilan-gold Feb 27, 2024
fd22a19
(fix): remove `Client`
ilan-gold Feb 27, 2024
8b9a792
(refactor): remove unnecessary `count_nonzero`
ilan-gold Feb 27, 2024
1592571
(fix): change expected fail on sparse normalization
ilan-gold Feb 27, 2024
6ac32e5
(fix): update comment
ilan-gold Feb 28, 2024
78a3ab6
(feat): `_get_mean_var` dask
ilan-gold Feb 28, 2024
60bbdb8
(feat): clean up tests for what should/should not work
ilan-gold Feb 28, 2024
0c1f254
(refactor): `_compat.sum` to `_utils.elem_sum`
ilan-gold Feb 28, 2024
2f4d11a
(chore): add `elem_sum` test
ilan-gold Feb 28, 2024
12502e8
(refactor): `elem_sum` -> `axis_sum`
ilan-gold Feb 29, 2024
b3bb95a
(feat): add `scale` support
ilan-gold Feb 29, 2024
2b6f717
(fix): maintain dtype
ilan-gold Feb 29, 2024
448dc40
(chore): add back condition
ilan-gold Feb 29, 2024
7226bf0
(fix): use `sum` when needed
ilan-gold Feb 29, 2024
2bc7c3a
(chore): release notes
ilan-gold Feb 29, 2024
62c75fc
(fx): don't use `mean_func` name twice
ilan-gold Feb 29, 2024
b987a68
(chore): revert sparse-chunks-in-dask
ilan-gold Feb 29, 2024
902238a
(chore): type hint
ilan-gold Feb 29, 2024
0abbab5
(chore): check `test_compare_to_upstream`
ilan-gold Feb 29, 2024
a8606ae
(chore): remove comment
ilan-gold Feb 29, 2024
feac6bc
(chore): allow passing `dtype` arg in `axis_sum`
ilan-gold Feb 29, 2024
bcdeddb
(fix): revert fixture changes
ilan-gold Feb 29, 2024
4716d8f
(refactor): cleaner with `array_type` conversion before if-then
ilan-gold Feb 29, 2024
3912b63
(chore): clarify hvg support
ilan-gold Feb 29, 2024
c884c63
(chore): handle array types better
ilan-gold Mar 1, 2024
af351d4
(chore): clean up `materialize_as_ndarray`
ilan-gold Mar 1, 2024
da22953
(chore): fix typing/dispatch problem in 3.9
ilan-gold Mar 1, 2024
dbbc6a2
(chore): `list` type -> `Callable`
ilan-gold Mar 1, 2024
6a4f0c5
(feat): `row_divide` for better division handling
ilan-gold Mar 1, 2024
c0182cb
(fix): use `tuple` for `ARRAY_TYPEXXXX`
ilan-gold Mar 1, 2024
a065a78
(refactor): `mean_func` -> `axis_mean` + types
ilan-gold Mar 1, 2024
c3ee138
(chore): remove unnecessary aggregation
ilan-gold Mar 1, 2024
c246a41
(fix): raise `ValueError` for summing over more than one axis
ilan-gold Mar 1, 2024
743c327
(fix): grammar
ilan-gold Mar 1, 2024
db88560
(fix): better type hints
ilan-gold Mar 1, 2024
2a5faa6
(revert): use old `test_normalize_total` siince we have `csr`
ilan-gold Mar 1, 2024
d6ceb4c
(revert): extraneous diff
ilan-gold Mar 1, 2024
48a1a1e
(fix): try `Union`
ilan-gold Mar 1, 2024
07fc5ba
(chore): add column division ability
ilan-gold Mar 1, 2024
3cc4be2
(chore): add scale test
ilan-gold Mar 1, 2024
4cc9eef
(fix): duplicate in release note
ilan-gold Mar 1, 2024
d8afe5c
(refactor): guard clause + comments
ilan-gold Mar 1, 2024
271d5d8
(chore): add `out` check for `dask`
ilan-gold Mar 1, 2024
c61324b
(chore): add `divisor` type hints
ilan-gold Mar 1, 2024
c688aff
(fix): remove some erroneous diffs
ilan-gold Mar 1, 2024
02be7a7
(chore): `axis_{sum,mean}` type hint fixes
ilan-gold Mar 1, 2024
6acc08c
(refactor): generalize to scaling
ilan-gold Mar 4, 2024
0944429
(chore): remove erroneous comment
ilan-gold Mar 4, 2024
3538572
(chore): remove non-public API
ilan-gold Mar 4, 2024
5ef1487
(fix): import from `sc._utils`
ilan-gold Mar 4, 2024
0f43362
(fix): `inidices` -> `indices`
ilan-gold Mar 4, 2024
c100a8f
(fix): remove erroneous `axis_sum` calls
ilan-gold Mar 5, 2024
22b4e90
(fix): return statements for `axis_scale`
ilan-gold Mar 5, 2024
4fef58e
(refactor): return out of `axis_sum` if `X._meta` is `np.ndarray`
ilan-gold Mar 5, 2024
ce574e3
(core): comment fix
ilan-gold Mar 5, 2024
e5a82fc
(fix): use `normalize_total` in HVG test for true reproducibility
ilan-gold Mar 5, 2024
a4e53a6
(refactor): separate out `out` test for dask
ilan-gold Mar 6, 2024
f0b2d97
(fix): correct chunking/rechunking behavior
ilan-gold Mar 6, 2024
f9ea93d
(chore): add guard clause for `sparse` `out != X != None` in scaling
ilan-gold Mar 6, 2024
66f04b6
(fix): guard clause condition
ilan-gold Mar 6, 2024
daca210
(fix): try finishing `|` typing for 3.9
ilan-gold Mar 6, 2024
036391e
(fix): call `register` to allow unions?
ilan-gold Mar 6, 2024
cac4160
(fix): clarify warning
ilan-gold Mar 6, 2024
9ec6935
(feat): test for `max_value`/`zero_center` combos
ilan-gold Mar 6, 2024
0ae76ee
(fix): allow settings of `X` in `scale_array`
ilan-gold Mar 6, 2024
2367f46
(chore): add tests for `normalize` correctness
ilan-gold Mar 6, 2024
b2c3a96
(fix): refactor for pure dask in `median`
ilan-gold Mar 6, 2024
340894b
Merge branch 'main' into dask-sparse-mean-var
ilan-gold Mar 6, 2024
fa66f58
(refactor): add clarifying condition
ilan-gold Mar 6, 2024
2601fe8
Merge branch 'dask-sparse-mean-var' of github.com:scverse/scanpy into…
ilan-gold Mar 6, 2024
750af59
(chore): skip warning computations + tests
ilan-gold Mar 6, 2024
25fe1f9
(fix): actually skip computation in `normalize_total` condition
ilan-gold Mar 6, 2024
57c8389
(fix): actually skip in `filter_genes` + tests
ilan-gold Mar 6, 2024
69ebf98
(fix): use all-in-one median implemetation
ilan-gold Mar 7, 2024
67f47f4
(refactor): remove erreous dask warnings
ilan-gold Mar 7, 2024
e328eb5
(chore): add note about `exclude_highly_expressed`
ilan-gold Mar 7, 2024
0aafabd
(feat): `axis_scale` -> `axis_mul_or_truediv`
ilan-gold Mar 7, 2024
be988c9
(feat): `allow_divide_by_zero`
ilan-gold Mar 7, 2024
3166909
(chore): add notes + type hints
ilan-gold Mar 7, 2024
6552324
Have hvg compute earlier and only once
ivirshup Mar 20, 2024
936eb87
Merge branch 'main' into dask-sparse-mean-var
ilan-gold Mar 21, 2024
37eb1a0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 21, 2024
0cecb94
(refactor): make codecov better by removing dead code/refactoring
ilan-gold Mar 22, 2024
21faa0d
(fix): `np.clip` in dask does not take min/max as `kwargs`
ilan-gold Mar 22, 2024
5998ae8
Update docs/release-notes/1.11.0.md
ilan-gold Mar 22, 2024
f49b929
(chore): move release note
ilan-gold Mar 22, 2024
937c6db
Merge branch 'main' into dask-sparse-mean-var
ilan-gold Mar 22, 2024
ba445f8
(chore): remove erroneous comment
ilan-gold Mar 22, 2024
b3581ea
Merge branch 'dask-sparse-mean-var' of github.com:scverse/scanpy into…
ilan-gold Mar 22, 2024
1 change: 1 addition & 0 deletions docs/release-notes/1.11.0.md
@@ -2,6 +2,7 @@

```{rubric} Features
```
* Support sparse chunks in dask {func}`~scanpy.pp.scale`, {func}`~scanpy.pp.filter_cells`, {func}`~scanpy.pp.filter_genes`, {func}`~scanpy.pp.normalize_total` and {func}`~scanpy.pp.highly_variable_genes` (`seurat` and `cell-ranger` tested) {pr}`2856` {smaller}`ilan-gold`

```{rubric} Docs
```
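
To illustrate the entry above, here is a minimal sketch of how the new support might be exercised end to end. The sparse-chunk construction and parameter values are illustrative and are not taken from the PR's test fixtures:

```python
import anndata as ad
import dask.array as da
import numpy as np
import scanpy as sc
from scipy import sparse

# build a dask array whose blocks are CSR matrices (illustrative construction)
rng = np.random.default_rng(0)
counts = rng.poisson(1.0, size=(1000, 200)).astype(np.float32)
X = da.from_array(counts, chunks=(250, 200)).map_blocks(sparse.csr_matrix)

adata = ad.AnnData(X=X)
sc.pp.filter_cells(adata, min_counts=1)
sc.pp.filter_genes(adata, min_cells=1)
sc.pp.normalize_total(adata)
sc.pp.log1p(adata)
sc.pp.highly_variable_genes(adata, flavor="seurat")  # seurat and cell-ranger are the tested flavors
sc.pp.scale(adata)
```

Per the release note, the listed functions are expected to operate on the dask-backed `adata.X` without materializing it up front.
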
226 changes: 225 additions & 1 deletion scanpy/_utils/__init__.py
@@ -14,9 +14,18 @@
from contextlib import contextmanager
from enum import Enum
from functools import partial, singledispatch, wraps
from operator import mul, truediv
from textwrap import dedent
from types import MethodType, ModuleType
from typing import TYPE_CHECKING, Any, Callable, Literal, Union
from typing import (
TYPE_CHECKING,
Any,
Callable,
Literal,
TypeVar,
Union,
overload,
)
from weakref import WeakSet

import numpy as np
@@ -560,6 +569,221 @@
return da.map_blocks(elem_mul, x, y)


Scaling_T = TypeVar("Scaling_T", DaskArray, np.ndarray)


def broadcast_axis(divisor: Scaling_T, axis: Literal[0, 1]) -> Scaling_T:
divisor = np.ravel(divisor)
if axis:
return divisor[None, :]
return divisor[:, None]


@singledispatch
def axis_mul_or_truediv(
X: sparse.spmatrix,
scaling_array,
axis: Literal[0, 1],
op: Callable[[Any, Any], Any],
*,
allow_divide_by_zero: bool = True,
out: sparse.spmatrix | None = None,
) -> sparse.spmatrix:
if op not in {truediv, mul}:
raise ValueError(f"{op} not one of truediv or mul")

if out is not None:
if X.data is not out.data:
raise ValueError(

"`out` argument provided but not equal to X. This behavior is not supported for sparse matrix scaling."
)
if not allow_divide_by_zero and op is truediv:
scaling_array = scaling_array.copy() + (scaling_array == 0)

row_scale = axis == 0
column_scale = axis == 1
if row_scale:

def new_data_op(x):
return op(x.data, np.repeat(scaling_array, np.diff(x.indptr)))

elif column_scale:

def new_data_op(x):
return op(x.data, scaling_array.take(x.indices, mode="clip"))

if X.format == "csr":
indices = X.indices
indptr = X.indptr
if out is not None:
X.data = new_data_op(X)
return X
return sparse.csr_matrix(
(new_data_op(X), indices.copy(), indptr.copy()), shape=X.shape
)
transposed = X.T
return axis_mul_or_truediv(
transposed,
scaling_array,
op=op,
axis=1 - axis,
out=transposed,
allow_divide_by_zero=allow_divide_by_zero,
).T
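
A small usage sketch of the sparse path above. `axis_mul_or_truediv` is a private helper in `scanpy._utils`, so the import here is for illustration only:

```python
from operator import truediv

import numpy as np
from scipy import sparse

from scanpy._utils import axis_mul_or_truediv  # private helper added in this PR

X = sparse.random(4, 5, density=0.5, format="csr", dtype=np.float64)
row_sums = np.ravel(X.sum(axis=1))
# divide every row by its total; allow_divide_by_zero=False replaces zero
# divisors with one, so all-zero rows pass through unchanged
X_norm = axis_mul_or_truediv(
    X, row_sums, axis=0, op=truediv, allow_divide_by_zero=False
)
assert np.allclose(np.ravel(X_norm[row_sums > 0].sum(axis=1)), 1.0)
```
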


@axis_mul_or_truediv.register(np.ndarray)
def _(
X: np.ndarray,
scaling_array: np.ndarray,
axis: Literal[0, 1],
op: Callable[[Any, Any], Any],
*,
allow_divide_by_zero: bool = True,
out: np.ndarray | None = None,
) -> np.ndarray:
if op not in {truediv, mul}:
raise ValueError(f"{op} not one of truediv or mul")

scaling_array = broadcast_axis(scaling_array, axis)
if op is mul:
return np.multiply(X, scaling_array, out=out)
if not allow_divide_by_zero:
scaling_array = scaling_array.copy() + (scaling_array == 0)
return np.true_divide(X, scaling_array, out=out)


def make_axis_chunks(X: DaskArray, axis: Literal[0, 1], pad=True) -> tuple[tuple[int]]:
if axis == 0:
if pad:
return (X.chunks[axis], (1,))
return X.chunks[axis]
if pad:
return ((1,), X.chunks[axis])
return X.chunks[axis]
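
The chunk specs this helper builds are easiest to read off a toy array; a sketch, assuming the helper is imported from `scanpy._utils` where it is defined:

```python
import dask.array as da

from scanpy._utils import make_axis_chunks  # private helper added in this PR

X = da.zeros((6, 4), chunks=(3, 2))
make_axis_chunks(X, axis=0)             # ((3, 3), (1,)): row chunks, padded singleton column axis
make_axis_chunks(X, axis=1)             # ((1,), (2, 2)): column chunks, padded singleton row axis
make_axis_chunks(X, axis=1, pad=False)  # (2, 2)
```
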


@axis_mul_or_truediv.register(DaskArray)
def _(
X: DaskArray,
scaling_array: Scaling_T,
axis: Literal[0, 1],
op: Callable[[Any, Any], Any],
*,
allow_divide_by_zero: bool = True,
out: None = None,
) -> DaskArray:
if op not in {truediv, mul}:
raise ValueError(f"{op} not one of truediv or mul")
if out is not None:
raise TypeError(

"`out` is not `None`. Do not do in-place modifications on dask arrays."
)

import dask.array as da

scaling_array = broadcast_axis(scaling_array, axis)
row_scale = axis == 0
column_scale = axis == 1

if isinstance(scaling_array, DaskArray):
if (row_scale and not X.chunksize[0] == scaling_array.chunksize[0]) or (

column_scale
and (
(
len(scaling_array.chunksize) == 1
and X.chunksize[1] != scaling_array.chunksize[0]
)
or (
len(scaling_array.chunksize) == 2
and X.chunksize[1] != scaling_array.chunksize[1]
)
)
):
warnings.warn("Rechunking scaling_array in user operation", UserWarning)
scaling_array = scaling_array.rechunk(

make_axis_chunks(X, axis, pad=len(scaling_array.shape) == 2)
)
else:
scaling_array = da.from_array(

scaling_array,
chunks=make_axis_chunks(X, axis, pad=len(scaling_array.shape) == 2),
)
return da.map_blocks(

axis_mul_or_truediv,
X,
scaling_array,
axis,
op,
meta=X._meta,
out=out,
allow_divide_by_zero=allow_divide_by_zero,
)
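
A hedged sketch of the dask path: the divisor here is a plain ndarray, so it is wrapped with `da.from_array` using chunks from `make_axis_chunks`, and no rechunking warning should be emitted (the private import is again for illustration):

```python
from operator import mul

import dask.array as da
import numpy as np

from scanpy._utils import axis_mul_or_truediv  # private helper added in this PR

rng = np.random.default_rng(0)
dense = rng.random((1000, 50))
X = da.from_array(dense, chunks=(250, 50))
scale = np.arange(1, 1001, dtype=np.float64)  # one factor per row

# the ndarray divisor is chunked to match X's row chunks, so each block is
# scaled against the matching slice of the factor vector
Y = axis_mul_or_truediv(X, scale, axis=0, op=mul)
np.testing.assert_allclose(Y.compute(), dense * scale[:, None])
```
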


@overload
def axis_sum(
X: sparse.spmatrix,
*,
axis: tuple[Literal[0, 1], ...] | Literal[0, 1] | None = None,
dtype: np.typing.DTypeLike | None = None,
) -> np.matrix:
...


@singledispatch
def axis_sum(
X: np.ndarray,
*,
axis: tuple[Literal[0, 1], ...] | Literal[0, 1] | None = None,
dtype: np.typing.DTypeLike | None = None,
) -> np.ndarray:
return np.sum(X, axis=axis, dtype=dtype)


@axis_sum.register(DaskArray)
def _(
X: DaskArray,
*,
axis: tuple[Literal[0, 1], ...] | Literal[0, 1] | None = None,
dtype: np.typing.DTypeLike | None = None,
) -> DaskArray:
import dask.array as da

if dtype is None:
dtype = getattr(np.zeros(1, dtype=X.dtype).sum(), "dtype", object)

if isinstance(X._meta, np.ndarray) and not isinstance(X._meta, np.matrix):
return X.sum(axis=axis, dtype=dtype)

def sum_drop_keepdims(*args, **kwargs):
kwargs.pop("computing_meta", None)

ilan-gold (Feb 29, 2024):

This one I am not so sure about. It doesn't seem to have an impact and also I'm not sure it's used looking at the dask code: https://docs.dask.org/en/stable/search.html?q=computing_meta

# masked operations on sparse matrices produce numpy matrices, which give the same API issues handled here
if isinstance(X._meta, (sparse.spmatrix, np.matrix)) or isinstance(

args[0], (sparse.spmatrix, np.matrix)
):
kwargs.pop("keepdims", None)
This is definitely something we don't want to run with the sparse matrices and probably (due to interop) the dense either

axis = kwargs["axis"]
if isinstance(axis, tuple):
if len(axis) != 1:
raise ValueError(

f"`axis_sum` can only sum over one axis when `axis` arg is provided but got {axis} instead"
)
kwargs["axis"] = axis[0]

# returns a np.matrix normally, which is undesirable
return np.array(np.sum(*args, dtype=dtype, **kwargs))

def aggregate_sum(*args, **kwargs):
return np.sum(args[0], dtype=dtype, **kwargs)

return da.reduction(

X,
sum_drop_keepdims,
aggregate_sum,
axis=axis,
dtype=dtype,
meta=np.array([], dtype=dtype),
)
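
A sketch of the reduction path over sparse chunks; the way the sparse-chunked array is built here is illustrative, not the PR's test fixture:

```python
import dask.array as da
import numpy as np
from scipy import sparse

from scanpy._utils import axis_sum  # private helper reworked in this PR

rng = np.random.default_rng(0)
dense = rng.poisson(1.0, size=(400, 30)).astype(np.float64)
X = da.from_array(dense, chunks=(100, 30)).map_blocks(sparse.csr_matrix)

# per-gene sums computed block-wise; the np.matrix each sparse block returns
# is coerced back to a plain ndarray inside sum_drop_keepdims
gene_sums = axis_sum(X, axis=0)
np.testing.assert_allclose(np.ravel(gene_sums.compute()), dense.sum(axis=0))
```
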


@singledispatch
def check_nonnegative_integers(X: _SupportedArray) -> bool | DaskArray:
"""Checks values of X to ensure it is count data"""
4 changes: 3 additions & 1 deletion scanpy/preprocessing/_distributed.py
@@ -35,9 +35,11 @@


def materialize_as_ndarray(
a: ArrayLike | tuple[ArrayLike | ZappyArray | DaskArray, ...],
a: DaskArray | ArrayLike | tuple[ArrayLike | ZappyArray | DaskArray, ...],
) -> tuple[np.ndarray] | np.ndarray:
"""Compute distributed arrays and convert them to numpy ndarrays."""
if isinstance(a, DaskArray):
return a.compute()

if not isinstance(a, tuple):
return np.asarray(a)

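
A minimal sketch of the new `DaskArray` shortcut, using the private module path from the diff header:

```python
import dask.array as da
import numpy as np

from scanpy.preprocessing._distributed import materialize_as_ndarray

x = da.ones((10, 3), chunks=(5, 3))
arr = materialize_as_ndarray(x)  # hits the new DaskArray branch: triggers .compute()
assert isinstance(arr, np.ndarray) and arr.shape == (10, 3)
```
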
66 changes: 44 additions & 22 deletions scanpy/preprocessing/_normalization.py
@@ -1,17 +1,24 @@
from __future__ import annotations

from operator import truediv
from typing import TYPE_CHECKING, Literal
from warnings import warn

import numpy as np
from scipy.sparse import issparse
from sklearn.utils import sparsefuncs

from .. import logging as logg
from .._compat import DaskArray, old_positionals
from .._utils import view_to_actual
from .._utils import axis_mul_or_truediv, axis_sum, view_to_actual
from ..get import _get_obs_rep, _set_obs_rep

try:
import dask
import dask.array as da
except ImportError:
da = None
dask = None

if TYPE_CHECKING:
from collections.abc import Iterable

@@ -22,21 +29,30 @@
X = X.copy() if copy else X
if issubclass(X.dtype.type, (int, np.integer)):
X = X.astype(np.float32) # TODO: Check if float64 should be used
if isinstance(counts, DaskArray):
counts_greater_than_zero = counts[counts > 0].compute_chunk_sizes()
else:
counts_greater_than_zero = counts[counts > 0]
if after is None:
if isinstance(counts, DaskArray):

def nonzero_median(x):
return np.ma.median(np.ma.masked_array(x, x == 0)).item()

after = np.median(counts_greater_than_zero, axis=0) if after is None else after
counts += counts == 0
after = da.from_delayed(

dask.delayed(nonzero_median)(counts),
shape=(),
meta=counts._meta,
dtype=counts.dtype,
)
else:
counts_greater_than_zero = counts[counts > 0]
after = np.median(counts_greater_than_zero, axis=0)
counts = counts / after
if issparse(X):
sparsefuncs.inplace_row_scale(X, 1 / counts)
elif isinstance(counts, np.ndarray):
np.divide(X, counts[:, None], out=X)
else:
X = np.divide(X, counts[:, None]) # dask does not support kwarg "out"
return X
return axis_mul_or_truediv(
X,
counts,
op=truediv,
out=X if isinstance(X, np.ndarray) or issparse(X) else None,
allow_divide_by_zero=False,
axis=0,
)
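
The delayed `nonzero_median` helper above boils down to a masked median over the nonzero counts; a standalone sketch of just that arithmetic:

```python
import numpy as np

counts = np.array([0.0, 2.0, 0.0, 4.0, 6.0])
# mask zeros so the median is taken over nonzero counts only,
# mirroring what the delayed nonzero_median helper computes
after = np.ma.median(np.ma.masked_array(counts, counts == 0)).item()
assert after == 4.0
```
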


@old_positionals(
@@ -78,6 +94,11 @@
Similar functions are used, for example, by Seurat [Satija15]_, Cell Ranger
[Zheng17]_ or SPRING [Weinreb17]_.

.. note::
When used with a :class:`~dask.array.Array` in `adata.X`, this function will have to
call functions that trigger `.compute()` on the :class:`~dask.array.Array` if `exclude_highly_expressed`
is `True`, `layer_norm` is not `None`, or if `key_added` is not `None`.

Params
------
adata
@@ -92,7 +113,8 @@
normalization factor (size factor) for each cell. A gene is considered
highly expressed, if it has more than `max_fraction` of the total counts
in at least one cell. The not-excluded genes will sum up to
`target_sum`.
`target_sum`. Providing this argument when `adata.X` is a :class:`~dask.array.Array`
will incur blocking `.compute()` calls on the array.
max_fraction
If `exclude_highly_expressed=True`, consider cells as highly expressed
that have more counts than `max_fraction` of the original total counts
@@ -187,27 +209,27 @@

gene_subset = None
msg = "normalizing counts per cell"

counts_per_cell = axis_sum(X, axis=1)
if exclude_highly_expressed:
counts_per_cell = X.sum(1) # original counts per cell
counts_per_cell = np.ravel(counts_per_cell)

# at least one cell has more than max_fraction of counts per cell

gene_subset = (X > counts_per_cell[:, None] * max_fraction).sum(0)
gene_subset = axis_sum((X > counts_per_cell[:, None] * max_fraction), axis=0)
gene_subset = np.asarray(np.ravel(gene_subset) == 0)

msg += (
". The following highly-expressed genes are not considered during "
f"normalization factor computation:\n{adata.var_names[~gene_subset].tolist()}"
)
counts_per_cell = X[:, gene_subset].sum(1)
else:
counts_per_cell = X.sum(1)
counts_per_cell = axis_sum(X[:, gene_subset], axis=1)

start = logg.info(msg)
counts_per_cell = np.ravel(counts_per_cell)

cell_subset = counts_per_cell > 0
if not np.all(cell_subset):
if not isinstance(cell_subset, DaskArray) and not np.all(cell_subset):
warn(UserWarning("Some cells have zero counts"))

if inplace:
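
To make the reworked `exclude_highly_expressed` counting in the last hunk concrete, here is a small standalone sketch using the same `axis_sum` helper (the values are invented for illustration):

```python
import numpy as np

from scanpy._utils import axis_sum  # private helper used in the diff above

X = np.array([[1.0, 9.0], [1.0, 1.0]])
max_fraction = 0.5

counts_per_cell = np.ravel(axis_sum(X, axis=1))                     # [10., 2.]
# a gene is "highly expressed" if some cell spends > max_fraction of its counts on it
gene_subset = axis_sum(X > counts_per_cell[:, None] * max_fraction, axis=0)
gene_subset = np.asarray(np.ravel(gene_subset) == 0)                # [True, False]
# size factors are then computed from the remaining genes only
counts_per_cell = np.ravel(axis_sum(X[:, gene_subset], axis=1))     # [1., 1.]
```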