diff --git a/src/nichepca/workflows/_nichepca.py b/src/nichepca/workflows/_nichepca.py
index a9d92e7..9678d16 100644
--- a/src/nichepca/workflows/_nichepca.py
+++ b/src/nichepca/workflows/_nichepca.py
@@ -11,7 +11,7 @@
     resolve_graph_constructor,
 )
 from nichepca.nhood_embedding import aggregate
-from nichepca.utils import check_for_raw_counts, normalize_per_sample
+from nichepca.utils import check_for_raw_counts, normalize_per_sample, to_numpy
 
 if TYPE_CHECKING:
     from anndata import AnnData
@@ -74,19 +74,22 @@ def nichepca(
     -------
     None
     """
+    # make sure pipeline is an iterable
+    if isinstance(pipeline, str):
+        pipeline = [pipeline]
+
     # we always need to use agg
     assert "agg" in pipeline, "aggregation must be part of the pipeline"
+
     # assert that the pca is behind norm and log1p
-    if "pca" in pipeline:
-        pca_after_norm = np.argmax(np.array(pipeline) == "pca") > np.argmax(
-            np.array(pipeline) == "norm"
-        )
-        pca_after_log1p = np.argmax(np.array(pipeline) == "pca") > np.argmax(
-            np.array(pipeline) == "log1p"
-        )
+    if "pca" in pipeline and ("norm" in pipeline or "log1p" in pipeline):
+        pca_index = np.argmax(np.array(pipeline) == "pca")
+        norm_index = np.argmax(np.array(pipeline) == "norm")
+        log1p_index = np.argmax(np.array(pipeline) == "log1p")
+        # argmax returns 0 if not found
         assert (
-            pca_after_norm and pca_after_log1p
-        ), "pca must be executed after norm and log1p"
+            norm_index <= pca_index and log1p_index <= pca_index
+        ), "PCA must be executed after both norm and log1p."
 
     # perform sanity check in case we are normalizing the data
     if "norm" or "log1p" in pipeline and obs_key is None and obsm_key is None:
@@ -170,14 +173,17 @@ def nichepca(
     # extract the results and remove old keys
     if "X_pca_harmony" in ad_tmp.obsm:
         X_npca = ad_tmp.obsm["X_pca_harmony"]
-    else:
+    elif "X_pca" in ad_tmp.obsm:
         X_npca = ad_tmp.obsm["X_pca"]
+    else:
+        X_npca = to_numpy(ad_tmp.X)
 
     # store the results
     adata.obsm["X_npca"] = X_npca
-    adata.uns["npca"] = ad_tmp.uns["pca"]
-    adata.uns["npca"]["PCs"] = pd.DataFrame(
-        data=ad_tmp.varm["PCs"],
-        index=ad_tmp.var_names,
-        columns=[f"PC{i}" for i in range(n_comps)],
-    )
+    if "pca" in pipeline:
+        adata.uns["npca"] = ad_tmp.uns["pca"]
+        adata.uns["npca"]["PCs"] = pd.DataFrame(
+            data=ad_tmp.varm["PCs"],
+            index=ad_tmp.var_names,
+            columns=[f"PC{i}" for i in range(n_comps)],
+        )
diff --git a/tests/test_workflows.py b/tests/test_workflows.py
index 645153b..86ce3d7 100644
--- a/tests/test_workflows.py
+++ b/tests/test_workflows.py
@@ -54,6 +54,20 @@ def test_nichepca_single():
 
     assert np.all(adata_1.obsm["X_npca"] == adata_2.obsm["X_pca"])
 
+    # test without pca
+    pipeline = "agg"
+
+    adata = generate_dummy_adata()
+    npc.wf.nichepca(adata, knn=5, pipeline=pipeline)
+    X_npca_0 = adata.obsm["X_npca"]
+
+    adata = generate_dummy_adata()
+    npc.gc.knn_graph(adata, knn=5)
+    npc.ne.aggregate(adata)
+    X_npca_1 = npc.utils.to_numpy(adata.X)
+
+    assert np.all(X_npca_0 == X_npca_1)
+
 
 def test_nichepca_multi_sample():
     adata_1 = generate_dummy_adata()
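
For context, a minimal usage sketch of the behaviour this patch enables: passing a bare string as pipeline and running an aggregation-only workflow. Only the entry point npc.wf.nichepca with the knn and pipeline arguments shown in the test above is taken from the source; the toy AnnData construction (random counts plus 2D coordinates in obsm["spatial"]) is an assumption and may need to be adapted to the real data requirements.

import numpy as np
import anndata as ad
import nichepca as npc

# hypothetical toy data: random counts and 2D spatial coordinates
rng = np.random.default_rng(0)
adata = ad.AnnData(X=rng.poisson(1.0, size=(200, 50)).astype(np.float32))
adata.obsm["spatial"] = rng.uniform(size=(200, 2))

# a bare string is now wrapped into a one-element pipeline; without "pca",
# adata.obsm["X_npca"] holds the aggregated expression matrix and no
# adata.uns["npca"] entry is written
npc.wf.nichepca(adata, knn=5, pipeline="agg")
X_agg = adata.obsm["X_npca"]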