(feat): add dask summation in filter_genes
ilan-gold committed Feb 27, 2024
1 parent 59e546a commit 987c7d5
Showing 5 changed files with 66 additions and 26 deletions.
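
In short: reductions that previously went through np.sum or the matrix .sum method are rerouted through a new singledispatch sum (plus a count_nonzero counterpart) in scanpy/_compat.py, which picks an implementation per array type and gives dask-backed matrices a lazy reduction path. The filter_genes change, paraphrased from the hunk further below, is representative:

    # before: always np.sum, regardless of the backing array type
    number_per_gene = np.sum(X if min_cells is None and max_cells is None else X > 0, axis=0)
    # after: dispatches on type(X); a dask X builds a lazy da.reduction instead
    number_per_gene = sum(X if min_cells is None and max_cells is None else X > 0, axis=0)
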
44 changes: 43 additions & 1 deletion scanpy/_compat.py
@@ -1,12 +1,15 @@
from __future__ import annotations

from dataclasses import dataclass, field
from functools import partial
from functools import partial, singledispatch
from pathlib import Path

import dask.array as da
import numpy as np
from legacy_api_wrap import legacy_api
from packaging import version
from scipy import sparse as sp

try:
from functools import cache
except ImportError: # Python < 3.9
@@ -81,3 +84,42 @@ def pkg_version(package):


old_positionals = partial(legacy_api, category=FutureWarning)


@singledispatch
def sum(X: np.ndarray | sp.spmatrix, axis=None):
return np.sum(X, axis=axis)


@sum.register
def _(X: da.Array, axis=None):
def sum_drop_keepdims(*args, **kwargs):
kwargs.pop("computing_meta", None)
        if isinstance(X._meta, sp.spmatrix):  # sparse chunks surface as np.matrix mid-reduction, so drop the kwargs they can't handle
kwargs.pop("keepdims", None)
if isinstance(kwargs["axis"], tuple):
kwargs["axis"] = kwargs["axis"][0]
return da.chunk.sum(*args, **kwargs)

dtype = getattr(np.zeros(1, dtype=X.dtype).sum(), "dtype", object)

# operates on `np.matrix` for some reason with sparse chunks in dask so need explicit casting
def aggregate_sum(*args, **kwargs):
return np.sum(np.array(args[0]), **kwargs)

return da.reduction(X, sum_drop_keepdims, aggregate_sum, axis=axis, dtype=dtype)


@singledispatch
def count_nonzero(X: np.ndarray, axis=None):
return np.count_nonzero(X, axis=axis)


@count_nonzero.register
def _(X: da.Array, axis=None):
return sum(X > 0, axis=axis)


@count_nonzero.register
def _(X: sp.spmatrix, axis=None):
return X.getnnz(axis=axis)
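
A minimal sketch of how the helpers above dispatch, assuming scipy and dask are installed; the toy arrays are illustrative, not from the commit:

    import dask.array as da
    import numpy as np
    from scipy import sparse as sp

    from scanpy._compat import count_nonzero, sum

    dense = np.array([[0, 1], [2, 0]])
    lazy = da.from_array(dense, chunks=1)

    sum(dense, axis=0)                           # base case: np.sum -> array([2, 1])
    count_nonzero(sp.csr_matrix(dense), axis=0)  # spmatrix register: getnnz -> array([1, 1])
    sum(lazy, axis=0).compute()                  # dask register: da.reduction -> array([2, 1])
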
9 changes: 3 additions & 6 deletions scanpy/preprocessing/_distributed.py
@@ -38,12 +38,9 @@ def materialize_as_ndarray(
a: ArrayLike | tuple[ArrayLike | ZappyArray | DaskArray, ...],
) -> tuple[np.ndarray] | np.ndarray:
"""Compute distributed arrays and convert them to numpy ndarrays."""
if isinstance(a, DaskArray):
return a.compute()
if not isinstance(a, tuple):
return np.asarray(a)

if not any(isinstance(arr, DaskArray) for arr in a):
return tuple(np.asarray(arr) for arr in a)

import dask.array as da

return da.compute(*a, sync=True)
return tuple(materialize_as_ndarray(arr) for arr in a)
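
A hedged usage sketch of the simplified recursion, with made-up inputs; tuples are now materialized element by element:

    import dask.array as da
    import numpy as np

    from scanpy.preprocessing._distributed import materialize_as_ndarray

    lazy = da.ones((4, 3), chunks=2)
    arr, other = materialize_as_ndarray((lazy, [1, 2, 3]))
    assert isinstance(arr, np.ndarray) and isinstance(other, np.ndarray)

One trade-off of recursing per element: two dask arrays in the same tuple no longer share a single scheduler pass, which the old da.compute(*a, sync=True) branch provided.
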
26 changes: 12 additions & 14 deletions scanpy/preprocessing/_qc.py
@@ -9,6 +9,7 @@
from scipy.sparse import csr_matrix, issparse, isspmatrix_coo, isspmatrix_csr, spmatrix
from sklearn.utils.sparsefuncs import mean_variance_axis

from .._compat import count_nonzero, sum
from .._utils import _doc_params
from ._docs import (
doc_adata_basic,
@@ -103,15 +104,12 @@ def describe_obs(
if issparse(X):
X.eliminate_zeros()
obs_metrics = pd.DataFrame(index=adata.obs_names)
if issparse(X):
obs_metrics[f"n_{var_type}_by_{expr_type}"] = X.getnnz(axis=1)
else:
obs_metrics[f"n_{var_type}_by_{expr_type}"] = np.count_nonzero(X, axis=1)
obs_metrics[f"n_{var_type}_by_{expr_type}"] = count_nonzero(X, axis=1)
if log1p:
obs_metrics[f"log1p_n_{var_type}_by_{expr_type}"] = np.log1p(
obs_metrics[f"n_{var_type}_by_{expr_type}"]
)
obs_metrics[f"total_{expr_type}"] = np.ravel(X.sum(axis=1))
obs_metrics[f"total_{expr_type}"] = np.ravel(sum(X, axis=1))
if log1p:
obs_metrics[f"log1p_total_{expr_type}"] = np.log1p(
obs_metrics[f"total_{expr_type}"]
@@ -125,7 +123,7 @@
)
for qc_var in qc_vars:
obs_metrics[f"total_{expr_type}_{qc_var}"] = np.ravel(
X[:, adata.var[qc_var].values].sum(axis=1)
sum(X[:, adata.var[qc_var].values], axis=1)
)
if log1p:
obs_metrics[f"log1p_total_{expr_type}_{qc_var}"] = np.log1p(
@@ -204,7 +202,7 @@ def describe_var(
var_metrics["pct_dropout_by_{expr_type}"] = (
1 - var_metrics["n_cells_by_{expr_type}"] / X.shape[0]
) * 100
var_metrics["total_{expr_type}"] = np.ravel(X.sum(axis=0))
var_metrics["total_{expr_type}"] = np.ravel(sum(X, axis=0))
if log1p:
var_metrics["log1p_total_{expr_type}"] = np.log1p(
var_metrics["total_{expr_type}"]
@@ -358,7 +356,7 @@ def top_proportions(mtx: np.ndarray | spmatrix, n: int):


def top_proportions_dense(mtx, n):
sums = mtx.sum(axis=1)
sums = sum(mtx, axis=1)
partitioned = np.apply_along_axis(np.argpartition, 1, -mtx, n - 1)
partitioned = partitioned[:, :n]
values = np.zeros_like(partitioned, dtype=np.float64)
@@ -377,10 +375,10 @@ def top_proportions_sparse_csr(data, indptr, n):
vec = np.zeros(n, dtype=np.float64)
if end - start <= n:
vec[: end - start] = data[start:end]
total = vec.sum()
total = sum(vec)
else:
vec[:] = -(np.partition(-data[start:end], n - 1)[:n])
total = (data[start:end]).sum() # Is this not just vec.sum()?
total = sum(data[start:end]) # Is this not just vec.sum()?
vec[::-1].sort()
values[i, :] = vec.cumsum() / total
return values
@@ -417,15 +415,15 @@ def top_segment_proportions_dense(
) -> np.ndarray:
# Currently ns is considered to be 1 indexed
ns = np.sort(ns)
sums = mtx.sum(axis=1)
sums = sum(mtx, axis=1)
partitioned = np.apply_along_axis(np.partition, 1, mtx, mtx.shape[1] - ns)[:, ::-1][
:, : ns[-1]
]
values = np.zeros((mtx.shape[0], len(ns)))
acc = np.zeros(mtx.shape[0])
prev = 0
for j, n in enumerate(ns):
acc += partitioned[:, prev:n].sum(axis=1)
acc += sum(partitioned[:, prev:n], axis=1)
values[:, j] = acc
prev = n
return values / sums[:, None]
@@ -444,7 +442,7 @@ def top_segment_proportions_sparse_csr(data, indptr, ns):
partitioned = np.zeros((indptr.size - 1, maxidx), dtype=data.dtype)
for i in numba.prange(indptr.size - 1):
start, end = indptr[i], indptr[i + 1]
sums[i] = np.sum(data[start:end])
sums[i] = sum(data[start:end])
if end - start <= maxidx:
partitioned[i, : end - start] = data[start:end]
elif (end - start) > maxidx:
@@ -455,7 +453,7 @@
prev = 0
# can’t use enumerate due to https://github.com/numba/numba/issues/2625
for j in range(ns.size):
acc += partitioned[:, prev : ns[j]].sum(axis=1)
acc += sum(partitioned[:, prev : ns[j]], axis=1)
values[:, j] = acc
prev = ns[j]
return values / sums.reshape((indptr.size - 1, 1))
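
Sketch of what the QC code above now computes for a sparse X; the shapes and density are arbitrary. np.ravel is still needed because spmatrix.sum returns a 2-D np.matrix:

    import numpy as np
    from scipy import sparse as sp

    from scanpy._compat import count_nonzero, sum

    X = sp.random(50, 20, density=0.3, format="csr", random_state=0)

    n_genes = count_nonzero(X, axis=1)    # nnz per cell, as in describe_obs
    totals = np.ravel(sum(X, axis=1))     # np.matrix of shape (50, 1) -> 1-D array
    assert n_genes.shape == totals.shape == (50,)
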
6 changes: 3 additions & 3 deletions scanpy/preprocessing/_simple.py
@@ -17,7 +17,7 @@
from sklearn.utils import check_array, sparsefuncs

from .. import logging as logg
from .._compat import old_positionals
from .._compat import old_positionals, sum
from .._settings import settings as sett
from .._utils import (
AnyRandom,
@@ -278,7 +278,7 @@ def filter_genes(
X = data # proceed with processing the data matrix
min_number = min_counts if min_cells is None else min_cells
max_number = max_counts if max_cells is None else max_cells
number_per_gene = np.sum(
number_per_gene = sum(
X if min_cells is None and max_cells is None else X > 0, axis=0
)
if issparse(X):
@@ -288,7 +288,7 @@
if max_number is not None:
gene_subset = number_per_gene <= max_number

s = np.sum(~gene_subset)
s = sum(~gene_subset)
if s > 0:
msg = f"filtered out {s} genes that are detected "
if min_cells is not None or min_counts is not None:
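
Illustrative only: with the dispatched sum, filter_genes accepts a dask-backed AnnData, mirroring the updated test below; the toy counts matrix and chunking here are assumptions:

    import dask.array as da
    import numpy as np
    import scanpy as sc
    from anndata import AnnData

    counts = np.random.default_rng(0).poisson(0.5, size=(100, 50))
    adata = AnnData(X=da.from_array(counts, chunks=(50, 50)))
    sc.pp.filter_genes(adata, min_cells=2)  # number_per_gene comes from the dask sum path
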
7 changes: 5 additions & 2 deletions scanpy/tests/test_preprocessing_distributed.py
@@ -7,6 +7,7 @@
import zarr
from anndata import AnnData, read_zarr
from anndata.experimental import read_elem, sparse_dataset
from scipy import sparse as sp

from scanpy._compat import DaskArray, ZappyArray
from scanpy.datasets._utils import filter_oldformatwarning
@@ -77,7 +78,7 @@ def adata_dist(request: pytest.FixtureRequest) -> AnnData:
assert request.param == "dask"
import dask.array as da

a.X = da.from_zarr(input_file_X)
a.X = da.from_zarr(input_file_X, chunks=(100, 1000))
return a


@@ -150,7 +151,9 @@ def test_filter_genes(adata: AnnData, adata_dist: AnnData):
result = materialize_as_ndarray(adata_dist.X)
filter_genes(adata, min_cells=2)
assert result.shape == adata.shape
npt.assert_allclose(result, adata.X)
npt.assert_allclose(
result.toarray() if isinstance(result, sp.spmatrix) else result, adata.X
)


@filter_oldformatwarning
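
Assuming a development install with the dask and zarr test extras, the new code path can be exercised with pytest's -k selection (shown as a sketch, not a documented workflow):

    import pytest

    pytest.main(["scanpy/tests/test_preprocessing_distributed.py", "-k", "test_filter_genes"])
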
