diff --git a/src/cellrank/kernels/_base_kernel.py b/src/cellrank/kernels/_base_kernel.py
index a521f14ae..792a44132 100644
--- a/src/cellrank/kernels/_base_kernel.py
+++ b/src/cellrank/kernels/_base_kernel.py
@@ -463,6 +463,132 @@ def _reuse_cache(self, expected_params: Dict[str, Any], *, time: Optional[Any] =
         self._params = expected_params
         # fmt: on
 
+    def _get_boundary(self, source: str, target: str, cluster_key: str, graph_key: str = "distances") -> List[int]:
+        """Identify source observations at boundary to target cluster.
+
+        Parameters
+        ----------
+        source
+            Name of source cluster.
+        target
+            Name of target cluster.
+        cluster_key
+            Key in :attr:`~anndata.AnnData.obs` to obtain cluster annotations.
+        graph_key
+            Name of graph representation to use from :attr:`~anndata.AnnData.obsp`.
+
+        Returns
+        -------
+        List of observation IDs at boundary to target cluster.
+        """
+        source_obs_mask = self.adata.obs[cluster_key].isin([source] if isinstance(source, str) else source)
+        target_obs_mask = self.adata.obs[cluster_key].isin([target] if isinstance(target, str) else target)
+
+        source_ids = np.where(source_obs_mask)[0]
+        boundary_ids = []
+
+        graph = self.adata.obsp[graph_key]
+        for source_id in source_ids:
+            obs_mask = graph[source_id, :].toarray().squeeze().astype(bool)
+
+            if (obs_mask & target_obs_mask).any():
+                boundary_ids.append(source_id)
+
+        return boundary_ids
+
+    def _get_empirical_velocity_field(
+        self, boundary_ids: List[int], target_obs_mask, rep: str, graph_key: str = "distances"
+    ) -> np.ndarray:
+        """Compute an empirical estimate of velocity field between two clusters.
+
+        Parameters
+        ----------
+        boundary_ids
+            List of observation IDs at boundary to target cluster.
+        target_obs_mask
+            Boolean indicator identifying relevant observations from target.
+        rep
+            Key in :attr:`~anndata.AnnData.obsm` to use as data representation.
+        graph_key
+            Name of graph representation to use from :attr:`~anndata.AnnData.obsp`.
+
+        Returns
+        -------
+        Empirical velocity estimate with one row per entry of ``boundary_ids``.
+        """
+        obs_ids = np.arange(0, self.adata.n_obs)
+        graph = self.adata.obsp[graph_key]
+        features = self.adata.obsm[rep]
+        empirical_velo = np.empty(shape=(len(boundary_ids), features.shape[1]))
+
+        for idx, boundary_id in enumerate(boundary_ids):
+            row = graph[boundary_id, :].toarray().squeeze()
+            obs_mask = row.astype(bool) & target_obs_mask
+            neighbors = obs_ids[obs_mask]
+            weights = row[obs_mask]
+
+            empirical_velo[idx, :] = np.sum(
+                weights.reshape(-1, 1) * (features[neighbors, :] - features[boundary_id, :]), axis=0
+            )
+
+        return empirical_velo
+
+    def _get_vector_field_estimate(self, rep: str) -> np.ndarray:
+        """Compute estimate of vector field under one step of the transition matrix.
+
+        Parameters
+        ----------
+        rep
+            Key in :attr:`~anndata.AnnData.obsm` to use as data representation.
+
+        Returns
+        -------
+        Vector field estimate based on kernel dynamics.
+        """
+        extrapolated_gex = self.transition_matrix @ self.adata.obsm[rep]
+        return extrapolated_gex - self.adata.obsm[rep]
+
+    # TODO: Add definition/reference to paper
+    def cbc(self, source: str, target: str, cluster_key: str, rep: str, graph_key: str = "distances") -> np.ndarray:
+        """Compute cross-boundary correctness score between source and target cluster.
+
+        Parameters
+        ----------
+        source
+            Name of the source cluster.
+        target
+            Name of the target cluster.
+        cluster_key
+            Key in :attr:`~anndata.AnnData.obs` to obtain cluster annotations.
+        rep
+            Key in :attr:`~anndata.AnnData.obsm` to use as data representation.
+        graph_key
+            Name of graph representation to use from :attr:`~anndata.AnnData.obsp`.
+
+        Returns
+        -------
+        Cross-boundary correctness score for each boundary observation with a NaN-free empirical velocity.
+        """
+
+        def _pearsonr(x: np.ndarray, y: np.ndarray) -> np.ndarray:
+            x_centered = x - np.mean(x, axis=1, keepdims=True)
+            y_centered = y - np.mean(y, axis=1, keepdims=True)
+            denom = np.linalg.norm(x_centered, axis=1) * np.linalg.norm(y_centered, axis=1)
+
+            return np.sum(x_centered * y_centered, axis=1) / denom
+
+        target_obs_mask = self.adata.obs[cluster_key].isin([target] if isinstance(target, str) else target)
+        boundary_ids = self._get_boundary(source=source, target=target, cluster_key=cluster_key, graph_key=graph_key)
+        empirical_velo = self._get_empirical_velocity_field(
+            boundary_ids=boundary_ids, target_obs_mask=target_obs_mask, rep=rep, graph_key=graph_key
+        )
+        estimated_velo = self._get_vector_field_estimate(rep=rep)[boundary_ids, :]
+
+        # Filter both fields with the same mask so that rows stay paired per observation.
+        valid_mask = ~np.isnan(empirical_velo).any(axis=1)
+
+        return _pearsonr(x=estimated_velo[valid_mask, :], y=empirical_velo[valid_mask, :])
+
 
 @d.dedent
 class Kernel(KernelExpression, abc.ABC):
diff --git a/tests/_ground_truth_adatas/adata_50.h5ad b/tests/_ground_truth_adatas/adata_50.h5ad
index 60ab43550..43426fdb4 100644
Binary files a/tests/_ground_truth_adatas/adata_50.h5ad and b/tests/_ground_truth_adatas/adata_50.h5ad differ
diff --git a/tests/test_kernels.py b/tests/test_kernels.py
index d6b476c8f..ee0d510e3 100644
--- a/tests/test_kernels.py
+++ b/tests/test_kernels.py
@@ -583,6 +583,30 @@ def test_connectivities_key_kernel(self, adata: AnnData):
         assert T_cr is not adata.obsp[key]
         np.testing.assert_array_equal(T_cr.A, adata.obsp[key])
 
+    @pytest.mark.parametrize("cluster_pair", [("Granule immature", "Granule mature"), ("nIPC", "Neuroblast")])
+    @pytest.mark.parametrize("graph_key", ["distances", "connectivities"])
+    def test_cbc(self, adata: AnnData, cluster_pair: Tuple[str, str], graph_key: str):
+        cluster_key = "clusters"
+        rep = "X_pca"
+        source, target = cluster_pair
+
+        vk = cr.kernels.VelocityKernel(adata)
+        vk.compute_transition_matrix()
+
+        ck = cr.kernels.ConnectivityKernel(adata)
+        ck.compute_transition_matrix()
+        combined_kernel = 0.8 * vk + 0.2 * ck
+
+        cbc_vk = vk.cbc(source=source, target=target, cluster_key=cluster_key, rep=rep, graph_key=graph_key)
+        np.testing.assert_almost_equal(cbc_vk, adata.uns["cbc"][f"{source}-{target}-{graph_key}-vk"])
+
+        cbc_combined_kernel = combined_kernel.cbc(
+            source=source, target=target, cluster_key=cluster_key, rep=rep, graph_key=graph_key
+        )
+        np.testing.assert_almost_equal(
+            cbc_combined_kernel, adata.uns["cbc"][f"{source}-{target}-{graph_key}-0.8vk+0.2ck"]
+        )
+
 
 class TestVelocityKernelReadData:
     @pytest.mark.parametrize("attr", ["layers", "obsm"])