Skip to content

Commit

Permalink
Update _get_empiricial_velocity_field
Browse files Browse the repository at this point in the history
* Refactor definition of `empirical_velo`.
  • Loading branch information
WeilerP committed Mar 3, 2024
1 parent 3a148ec commit 71bacc3
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/cellrank/kernels/_base_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,16 +517,16 @@ def _get_empirical_velocity_field(
obs_ids = np.arange(0, self.adata.n_obs)
graph = self.adata.obsp[graph_key]
features = self.adata.obsm[rep]
empirical_velo = []
empirical_velo = np.empty(shape=(len(boundary_ids), features.shape[1]))

for boundary_id in boundary_ids:
for idx, boundary_id in enumerate(boundary_ids):
row = graph[boundary_id, :].toarray().squeeze()
obs_mask = row.astype(bool) & target_obs_mask
neighbors = obs_ids[obs_mask]
weights = row[obs_mask]

empirical_velo.append(
np.sum(weights.reshape(-1, 1) * (features[neighbors, :] - features[boundary_id, :]), axis=0)
empirical_velo[idx, :] = np.sum(
weights.reshape(-1, 1) * (features[neighbors, :] - features[boundary_id, :]), axis=0
)

empirical_velo = np.array(empirical_velo)
Expand Down

0 comments on commit 71bacc3

Please sign in to comment.