diff --git a/src/cellrank/estimators/terminal_states/_gpcca.py b/src/cellrank/estimators/terminal_states/_gpcca.py index 1d4affa74..2b525a6d2 100644 --- a/src/cellrank/estimators/terminal_states/_gpcca.py +++ b/src/cellrank/estimators/terminal_states/_gpcca.py @@ -563,7 +563,7 @@ def tsi( """ tsi_precomputed = (self._tsi is not None) and (self._tsi[:, "number_of_macrostates"].X.max() >= n_macrostates) if terminal_states is not None: - tsi_precomputed = tsi_precomputed and (self._tsi.uns["terminal_states"] == set(terminal_states)) + tsi_precomputed = tsi_precomputed and (set(self._tsi.uns["terminal_states"]) == set(terminal_states)) if cluster_key is not None: tsi_precomputed = tsi_precomputed and (self._tsi.uns["cluster_key"] == cluster_key) @@ -590,9 +590,7 @@ def tsi( tsi_df["optimal_identification"].append(min(n_states, max_terminal_states)) - tsi_df = AnnData( - pd.DataFrame(tsi_df), uns={"terminal_states": set(terminal_states), "cluster_key": cluster_key} - ) + tsi_df = AnnData(pd.DataFrame(tsi_df), uns={"terminal_states": terminal_states, "cluster_key": cluster_key}) self._tsi = tsi_df tsi_df = self._tsi.to_df()