diff --git a/.gitignore b/.gitignore index d97bb06..42fb919 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ uv.lock eval_output/ **/celltype*csv .DS_Store +.idea/* ## Custom State-Eval Ignores diff --git a/src/cell_eval/_evaluator.py b/src/cell_eval/_evaluator.py index a924414..0e86ce8 100644 --- a/src/cell_eval/_evaluator.py +++ b/src/cell_eval/_evaluator.py @@ -189,7 +189,7 @@ def _convert_to_normlog( Will skip if the input is not integer data. """ - if guess_is_lognorm(adata=adata): + if guess_is_lognorm(adata=adata, validate=not allow_discrete): logger.info( "Input is found to be log-normalized already - skipping transformation." ) diff --git a/src/cell_eval/utils.py b/src/cell_eval/utils.py index 7ceacfe..f3f5ad1 100644 --- a/src/cell_eval/utils.py +++ b/src/cell_eval/utils.py @@ -1,20 +1,40 @@ +import logging + import anndata as ad import numpy as np from scipy.sparse import csc_matrix, csr_matrix +logger = logging.getLogger(__name__) + def guess_is_lognorm( adata: ad.AnnData, epsilon: float = 1e-3, + max_threshold: float = 15.0, + validate: bool = True, ) -> bool: """Guess if the input is integer counts or log-normalized. - This is an _educated guess_ based on whether the fractional component of cell sums. - This _will not be able_ to distinguish between normalized input and log-normalized input. + This is an _educated guess_ based on whether there is a fractional component of values. + Checks that data with decimal values is in expected log1p range. 
+
+    Args:
+        adata: AnnData object to check
+        epsilon: Threshold for detecting fractional values (default 1e-3)
+        max_threshold: Maximum valid value for log1p normalized data (default 15.0)
+        validate: Whether to validate the data is in valid log1p range (default True)
 
     Returns:
-        bool: True if the input is lognorm, False otherwise
+        bool: True if the input is lognorm, False if integer counts
+
+    Raises:
+        ValueError: If data has decimal values but min < 0 (always checked), or if
+            max >= max_threshold when ``validate`` is True, indicating mixed or invalid scales
     """
+    if adata.X is None:
+        raise ValueError("adata.X is None")
+
+    # Check for fractional values
     if isinstance(adata.X, csr_matrix) or isinstance(adata.X, csc_matrix):
         frac, _ = np.modf(adata.X.data)
     elif adata.isview:
@@ -24,7 +44,43 @@ def guess_is_lognorm(
     else:
         frac, _ = np.modf(adata.X)  # type: ignore
 
-    return bool(np.any(frac > epsilon))
+    has_decimals = bool(np.any(frac > epsilon))
+
+    if not has_decimals:
+        # All integer values - assume raw counts
+        logger.info("Data appears to be integer counts (no decimal values detected)")
+        return False
+
+    # Data has decimals - perform validation if requested
+    # Validate it's in valid log1p range
+    if isinstance(adata.X, csr_matrix) or isinstance(adata.X, csc_matrix):
+        max_val = adata.X.max()
+        min_val = adata.X.min()
+    else:
+        max_val = float(np.max(adata.X))
+        min_val = float(np.min(adata.X))
+
+    # Validate range
+    if min_val < 0:
+        raise ValueError(
+            f"Invalid scale: min value {min_val:.2f} is negative. "
+            f"Both Natural or Log1p normalized data must have all values >= 0."
+        )
+
+    if validate and max_val >= max_threshold:
+        raise ValueError(
+            f"Invalid scale: max value {max_val:.2f} exceeds log1p threshold of {max_threshold}. "
+            f"Expected log1p normalized values in range [0, {max_threshold}), but found values suggesting "
+            f"raw counts or incorrect normalization. 
Values above {max_threshold} indicate mixed scales " + f"(some cells with raw counts, some with log1p values)." + ) + + # Valid log1p data + logger.info( + f"Data appears to be log1p normalized (decimals detected, range [{min_val:.2f}, {max_val:.2f}])" + ) + + return True def split_anndata_on_celltype( diff --git a/tests/test_eval.py b/tests/test_eval.py index 958c95c..968a701 100644 --- a/tests/test_eval.py +++ b/tests/test_eval.py @@ -91,6 +91,28 @@ def test_broken_adata_not_normlog_skip_check(): ) +def test_broken_adata_invalid_pred_scale(): + """Test that predicted data with invalid scale is rejected.""" + adata_real = build_random_anndata(normlog=True) + adata_pred = adata_real.copy() + + # Create invalid predicted data: mix of raw counts and log1p + adata_pred.X = np.random.uniform( + 0, + 5000, + size=adata_pred.X.shape, # type: ignore + ) + + with pytest.raises(ValueError, match="Invalid scale.*exceeds log1p threshold"): + MetricsEvaluator( + adata_pred=adata_pred, + adata_real=adata_real, + control_pert=CONTROL_VAR, + pert_col=PERT_COL, + outdir=OUTDIR, + ) + + def test_broken_adata_missing_pertcol_in_real(): adata_real = build_random_anndata() adata_pred = adata_real.copy() diff --git a/tests/test_utils.py b/tests/test_utils.py index d64b550..95cafa3 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,3 +1,6 @@ +import numpy as np +import pytest + from cell_eval.data import build_random_anndata from cell_eval.utils import guess_is_lognorm @@ -16,3 +19,103 @@ def test_is_lognorm_view(): def test_is_lognorm_false(): data = build_random_anndata(normlog=False) assert not guess_is_lognorm(data) + + +def test_guess_is_lognorm_valid_lognorm(): + """Test that valid log1p normalized data returns True.""" + data = build_random_anndata(normlog=True, random_state=42) + # Should return True without raising exception + assert ( + guess_is_lognorm( + data, + ) + is True + ) + + +def test_guess_is_lognorm_valid_lognorm_sparse(): + """Test that valid log1p 
normalized sparse data returns True."""
+    data = build_random_anndata(normlog=True, as_sparse=True, random_state=42)
+    # Should return True without raising exception
+    assert (
+        guess_is_lognorm(
+            data,
+        )
+        is True
+    )
+
+
+def test_guess_is_lognorm_integer_data():
+    """Test that integer data (raw counts) returns False."""
+    data = build_random_anndata(normlog=False, random_state=42)
+    # Should return False - integer data indicates raw counts
+    assert (
+        guess_is_lognorm(
+            data,
+        )
+        is False
+    )
+
+
+def test_guess_is_lognorm_edge_case_near_threshold():
+    """Test that values near but below threshold return True."""
+    data = build_random_anndata(normlog=True, random_state=42)
+    # Modify data to have values near but below the threshold (14.9 < 15.0)
+    data.X = np.random.uniform(
+        0,
+        14.9,
+        size=data.X.shape,  # type: ignore
+    )
+    # Should return True without raising exception
+    assert (
+        guess_is_lognorm(
+            data,
+        )
+        is True
+    )
+
+
+def test_guess_is_lognorm_exceeds_threshold():
+    """Test that data with max value >= 15.0 raises ValueError when validation is enabled."""
+    data = build_random_anndata(normlog=True, random_state=42)
+    # Modify data to exceed threshold (mix of valid and invalid)
+    data.X = np.random.uniform(
+        0,
+        15.1,
+        size=data.X.shape,  # type: ignore
+    )
+
+    with pytest.raises(ValueError, match="Invalid scale.*exceeds log1p threshold"):
+        guess_is_lognorm(
+            data,
+        )
+
+
+def test_guess_is_lognorm_negative_values():
+    """Test that data with negative values raises ValueError."""
+    data = build_random_anndata(normlog=True, random_state=42)
+    # Modify data to include negative values
+    data.X = np.random.uniform(
+        -1,
+        9,
+        size=data.X.shape,  # type: ignore
+    )
+
+    with pytest.raises(ValueError, match="Invalid scale.*is negative"):
+        guess_is_lognorm(
+            data,
+        )
+
+
+def test_guess_is_lognorm_mixed_scales():
+    """Test mixed scenario: some cells with raw counts, some with log1p."""
+    data = build_random_anndata(normlog=True, random_state=42)
+    n_cells = data.X.shape[0]  # type: ignore
+    half = n_cells // 2
+    data.X[:half] = np.random.uniform(0, 9, size=(half, data.X.shape[1]))  # type: ignore
+    data.X[half:] = np.random.uniform(100, 5000, size=(n_cells - half, data.X.shape[1]))  # type: ignore
+
+    with pytest.raises(ValueError, match="Invalid scale.*exceeds log1p threshold"):
+        guess_is_lognorm(
+            data,
+        )