Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ uv.lock
eval_output/
**/celltype*csv
.DS_Store
.idea/*


## Custom State-Eval Ignores
Expand Down
2 changes: 1 addition & 1 deletion src/cell_eval/_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def _convert_to_normlog(
Will skip if the input is not integer data.
"""
if guess_is_lognorm(adata=adata):
if guess_is_lognorm(adata=adata, validate=not allow_discrete):
logger.info(
"Input is found to be log-normalized already - skipping transformation."
)
Expand Down
64 changes: 60 additions & 4 deletions src/cell_eval/utils.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,40 @@
import logging

import anndata as ad
import numpy as np
from scipy.sparse import csc_matrix, csr_matrix

logger = logging.getLogger(__name__)


def guess_is_lognorm(
adata: ad.AnnData,
epsilon: float = 1e-3,
max_threshold: float = 15.0,
validate: bool = True,
) -> bool:
"""Guess if the input is integer counts or log-normalized.

This is an _educated guess_ based on whether the fractional component of cell sums.
This _will not be able_ to distinguish between normalized input and log-normalized input.
This is an _educated guess_ based on whether there is a fractional component of values.
Checks that data with decimal values is in expected log1p range.

Args:
adata: AnnData object to check
epsilon: Threshold for detecting fractional values (default 1e-3)
max_threshold: Maximum valid value for log1p normalized data (default 15.0)
validate: Whether to validate the data is in valid log1p range (default True)

Returns:
bool: True if the input is lognorm, False otherwise
bool: True if the input is lognorm, False if integer counts

Raises:
ValueError: If data has decimal values but falls outside
valid log1p range (min < 0 or max >= max_threshold), indicating mixed or invalid scales
"""
if adata.X is None:
raise ValueError("adata.X is None")

# Check for fractional values
if isinstance(adata.X, csr_matrix) or isinstance(adata.X, csc_matrix):
frac, _ = np.modf(adata.X.data)
elif adata.isview:
Expand All @@ -24,7 +44,43 @@ def guess_is_lognorm(
else:
frac, _ = np.modf(adata.X) # type: ignore

return bool(np.any(frac > epsilon))
has_decimals = bool(np.any(frac > epsilon))

if not has_decimals:
# All integer values - assume raw counts
logger.info("Data appears to be integer counts (no decimal values detected)")
return False

# Data has decimals - perform validation if requested
# Validate it's in valid log1p range
if isinstance(adata.X, csr_matrix) or isinstance(adata.X, csc_matrix):
max_val = adata.X.max()
min_val = adata.X.min()
else:
max_val = float(np.max(adata.X))
min_val = float(np.min(adata.X))

# Validate range
if min_val < 0:
raise ValueError(
f"Invalid scale: min value {min_val:.2f} is negative. "
f"Both Natural or Log1p normalized data must have all values >= 0."
)

if validate and max_val >= max_threshold:
raise ValueError(
f"Invalid scale: max value {max_val:.2f} exceeds log1p threshold of {max_threshold}. "
f"Expected log1p normalized values in range [0, {max_threshold}), but found values suggesting "
f"raw counts or incorrect normalization. Values above {max_threshold} indicate mixed scales "
f"(some cells with raw counts, some with log1p values)."
)

# Valid log1p data
logger.info(
f"Data appears to be log1p normalized (decimals detected, range [{min_val:.2f}, {max_val:.2f}])"
)

return True


def split_anndata_on_celltype(
Expand Down
22 changes: 22 additions & 0 deletions tests/test_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,28 @@ def test_broken_adata_not_normlog_skip_check():
)


def test_broken_adata_invalid_pred_scale():
    """Predictions whose values fall outside the log1p range are rejected."""
    adata_real = build_random_anndata(normlog=True)
    adata_pred = adata_real.copy()

    # Overwrite predictions with large decimal values (up to 5000) that
    # cannot be log1p-normalized data.
    shape = adata_pred.X.shape  # type: ignore
    adata_pred.X = np.random.uniform(0, 5000, size=shape)

    # The evaluator should refuse to construct on out-of-scale predictions.
    with pytest.raises(ValueError, match="Invalid scale.*exceeds log1p threshold"):
        MetricsEvaluator(
            adata_pred=adata_pred,
            adata_real=adata_real,
            control_pert=CONTROL_VAR,
            pert_col=PERT_COL,
            outdir=OUTDIR,
        )


def test_broken_adata_missing_pertcol_in_real():
adata_real = build_random_anndata()
adata_pred = adata_real.copy()
Expand Down
103 changes: 103 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import numpy as np
import pytest

from cell_eval.data import build_random_anndata
from cell_eval.utils import guess_is_lognorm

Expand All @@ -16,3 +19,103 @@ def test_is_lognorm_view():
def test_is_lognorm_false():
    """Integer-count data must not be classified as log-normalized."""
    adata = build_random_anndata(normlog=False)
    assert not guess_is_lognorm(adata)


def test_guess_is_lognorm_valid_lognorm():
    """Valid log1p-normalized dense data is detected as lognorm."""
    adata = build_random_anndata(normlog=True, random_state=42)
    # Must return True and must not raise: values lie in the accepted range.
    result = guess_is_lognorm(adata)
    assert result is True


def test_guess_is_lognorm_valid_lognorm_sparse():
    """Valid log1p-normalized sparse data is detected as lognorm."""
    adata = build_random_anndata(normlog=True, as_sparse=True, random_state=42)
    # Must return True and must not raise: sparse path reads .data directly.
    result = guess_is_lognorm(adata)
    assert result is True


def test_guess_is_lognorm_integer_data():
    """Integer data (raw counts) yields False without raising."""
    adata = build_random_anndata(normlog=False, random_state=42)
    # Pure integers → no fractional component → treated as raw counts.
    result = guess_is_lognorm(adata)
    assert result is False


def test_guess_is_lognorm_edge_case_near_threshold():
    """Decimal values just below the 15.0 threshold are still accepted as log1p."""
    adata = build_random_anndata(normlog=True, random_state=42)
    # Replace with values bounded by 14.9 — close to, but strictly under,
    # the default max_threshold of 15.0.
    shape = adata.X.shape  # type: ignore
    adata.X = np.random.uniform(0, 14.9, size=shape)
    # Must return True and must not raise.
    result = guess_is_lognorm(adata)
    assert result is True


def test_guess_is_lognorm_exceeds_threshold():
    """Decimal data with a max value at/above the 15.0 threshold raises ValueError.

    Note: np.random here uses the unseeded global RNG (random_state=42 only
    seeds the builder), so sampling uniform(0, 15.1) alone does not guarantee
    any value >= 15.0. One supra-threshold value is forced explicitly to make
    the expected failure deterministic.
    """
    data = build_random_anndata(normlog=True, random_state=42)
    # Fill with decimals whose range slightly exceeds the threshold.
    data.X = np.random.uniform(
        0,
        15.1,
        size=data.X.shape,  # type: ignore
    )
    # Guarantee at least one value above max_threshold=15.0.
    data.X[0, 0] = 15.1  # type: ignore

    with pytest.raises(ValueError, match="Invalid scale.*exceeds log1p threshold"):
        guess_is_lognorm(
            data,
        )


def test_guess_is_lognorm_negative_values():
    """Decimal data containing negative values raises ValueError.

    Note: np.random here uses the unseeded global RNG (random_state=42 only
    seeds the builder), so sampling uniform(-1, 9) alone does not guarantee a
    negative draw. One negative value is forced explicitly to make the
    expected failure deterministic.
    """
    data = build_random_anndata(normlog=True, random_state=42)
    # Fill with decimals spanning a slightly negative lower bound.
    data.X = np.random.uniform(
        -1,
        9,
        size=data.X.shape,  # type: ignore
    )
    # Guarantee at least one negative value.
    data.X[0, 0] = -0.5  # type: ignore

    with pytest.raises(ValueError, match="Invalid scale.*is negative"):
        guess_is_lognorm(
            data,
        )


def test_guess_is_lognorm_mixed_scales():
    """Mixed scales — half raw-count magnitudes, half log1p-like — raise ValueError."""
    adata = build_random_anndata(normlog=True, random_state=42)
    n_cells, n_genes = adata.X.shape  # type: ignore
    half = n_cells // 2
    # First half: plausible log1p values in [0, 9).
    adata.X[:half] = np.random.uniform(0, 9, size=(half, n_genes))  # type: ignore
    # Second half: raw-count magnitudes in [100, 5000), guaranteed over threshold.
    adata.X[half:] = np.random.uniform(100, 5000, size=(n_cells - half, n_genes))  # type: ignore

    with pytest.raises(ValueError, match="Invalid scale.*exceeds log1p threshold"):
        guess_is_lognorm(adata)