Skip to content

Commit

Permalink
(feat): normalization
Browse files Browse the repository at this point in the history
  • Loading branch information
ilan-gold committed Feb 27, 2024
1 parent cd2a4d5 commit e83dac5
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 7 deletions.
11 changes: 6 additions & 5 deletions scanpy/preprocessing/_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from sklearn.utils import sparsefuncs

from .. import logging as logg
from .._compat import DaskArray, old_positionals
from .._compat import DaskArray, old_positionals, sum
from .._utils import view_to_actual
from ..get import _get_obs_rep, _set_obs_rep

Expand All @@ -35,7 +35,7 @@ def _normalize_data(X, counts, after=None, copy: bool = False):
elif isinstance(counts, np.ndarray):
np.divide(X, counts[:, None], out=X)
else:
X = np.divide(X, counts[:, None]) # dask does not support kwarg "out"
X = X / counts[:, None]
return X


Expand Down Expand Up @@ -188,20 +188,20 @@ def normalize_total(
gene_subset = None
msg = "normalizing counts per cell"

counts_per_cell = X.sum(1)
counts_per_cell = sum(X, axis=1)
if exclude_highly_expressed:
counts_per_cell = np.ravel(counts_per_cell)

# at least one cell has more than max_fraction of counts per cell

gene_subset = (X > counts_per_cell[:, None] * max_fraction).sum(0)
gene_subset = sum((X > counts_per_cell[:, None] * max_fraction), axis=0)
gene_subset = np.asarray(np.ravel(gene_subset) == 0)

msg += (
". The following highly-expressed genes are not considered during "
f"normalization factor computation:\n{adata.var_names[~gene_subset].tolist()}"
)
counts_per_cell = X[:, gene_subset].sum(1)
counts_per_cell = sum(X[:, gene_subset], axis=1)

start = logg.info(msg)
counts_per_cell = np.ravel(counts_per_cell)
Expand Down Expand Up @@ -253,3 +253,4 @@ def normalize_total(
return adata
elif not inplace:
return dat
return None
10 changes: 8 additions & 2 deletions scanpy/tests/test_preprocessing_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,9 @@ def test_normalize_per_cell(
result = materialize_as_ndarray(adata_dist.X)
normalize_per_cell(adata)
assert result.shape == adata.shape
npt.assert_allclose(result, adata.X)
npt.assert_allclose(
result.toarray() if isinstance(result, sp.spmatrix) else result, adata.X
)


def test_normalize_total(adata: AnnData, adata_dist: AnnData):
Expand All @@ -111,7 +113,11 @@ def test_normalize_total(adata: AnnData, adata_dist: AnnData):
result = materialize_as_ndarray(adata_dist.X)
normalize_total(adata)
assert result.shape == adata.shape
npt.assert_allclose(result, adata.X)
npt.assert_allclose(
result.toarray() if isinstance(result, sp.spmatrix) else result,
adata.X,
atol=1e-6,
)


def test_filter_cells_array(adata: AnnData, adata_dist: AnnData):
Expand Down

0 comments on commit e83dac5

Please sign in to comment.