From 87f1fe0606a9af6b0111de50f642e560cde7c519 Mon Sep 17 00:00:00 2001 From: Philipp Weiler Date: Sat, 2 Mar 2024 12:13:01 +0000 Subject: [PATCH] Update `tsi` Save `terminal_states` in `uns` slot of `tsi` attribute as list. --- src/cellrank/estimators/terminal_states/_gpcca.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/cellrank/estimators/terminal_states/_gpcca.py b/src/cellrank/estimators/terminal_states/_gpcca.py index 1d4affa74..2b525a6d2 100644 --- a/src/cellrank/estimators/terminal_states/_gpcca.py +++ b/src/cellrank/estimators/terminal_states/_gpcca.py @@ -563,7 +563,7 @@ def tsi( """ tsi_precomputed = (self._tsi is not None) and (self._tsi[:, "number_of_macrostates"].X.max() >= n_macrostates) if terminal_states is not None: - tsi_precomputed = tsi_precomputed and (self._tsi.uns["terminal_states"] == set(terminal_states)) + tsi_precomputed = tsi_precomputed and (set(self._tsi.uns["terminal_states"]) == set(terminal_states)) if cluster_key is not None: tsi_precomputed = tsi_precomputed and (self._tsi.uns["cluster_key"] == cluster_key) @@ -590,9 +590,7 @@ def tsi( tsi_df["optimal_identification"].append(min(n_states, max_terminal_states)) - tsi_df = AnnData( - pd.DataFrame(tsi_df), uns={"terminal_states": set(terminal_states), "cluster_key": cluster_key} - ) + tsi_df = AnnData(pd.DataFrame(tsi_df), uns={"terminal_states": terminal_states, "cluster_key": cluster_key}) self._tsi = tsi_df tsi_df = self._tsi.to_df()