diff --git a/src/cellflow/preprocessing/_wknn.py b/src/cellflow/preprocessing/_wknn.py index 222a9dcf..4cede888 100644 --- a/src/cellflow/preprocessing/_wknn.py +++ b/src/cellflow/preprocessing/_wknn.py @@ -134,15 +134,16 @@ def transfer_labels( ) wknn = ref_adata.uns[wknn_key] + labels_onehot = pd.get_dummies(ref_adata.obs[label_key].astype(str)) + labels_mat = sparse.csr_matrix((labels_onehot > 0).astype(int)) - scores = pd.DataFrame( - wknn @ pd.get_dummies(ref_adata.obs[label_key]), - columns=pd.get_dummies(ref_adata.obs[label_key]).columns, - index=query_adata.obs_names, - ) + scores = wknn @ labels_mat + label_indices = np.array(scores.argmax(1)).flatten() + max_scores = scores.max(1).toarray().flatten() - query_adata.obs[f"{label_key}_transfer"] = scores.idxmax(1) - query_adata.obs[f"{label_key}_transfer_score"] = scores.max(1) + query_adata.obs[f"{label_key}_transfer"] = labels_onehot.columns[label_indices] + query_adata.obs[f"{label_key}_transfer"] = query_adata.obs[f"{label_key}_transfer"].astype("category") + query_adata.obs[f"{label_key}_transfer_score"] = max_scores if copy: return query_adata