From 6c37dc5c186bbae1d2cc3f4ad2d9bca92cbbafe1 Mon Sep 17 00:00:00 2001 From: Danila Bredikhin Date: Fri, 8 Oct 2021 16:52:46 +0200 Subject: [PATCH] Improve SNF This improves mu.tl.snf interface as well as functionality by providing more arguments as well as fixing implementation bugs. Arguments are now adjusted to better match mu.pp.neighbors(). The method now records connectivities and distances. The way these sparse matrices are computed might change in future. The location of the method might change as well (mu.tl -> mu.pp). --- muon/_core/tools.py | 176 ++++++++++++++++++++++++++++++++++++++------ 1 file changed, 155 insertions(+), 21 deletions(-) diff --git a/muon/_core/tools.py b/muon/_core/tools.py index f44d928..bc868f3 100644 --- a/muon/_core/tools.py +++ b/muon/_core/tools.py @@ -8,7 +8,7 @@ import numpy as np import pandas as pd -from scipy.sparse import issparse +from scipy.sparse import issparse, csr_matrix import scanpy as sc import h5py from natsort import natsorted @@ -17,10 +17,14 @@ from scanpy._compat import Literal from scanpy import logging +from scanpy.tools._utils import _choose_representation +from scanpy.neighbors import _compute_connectivities_umap -from typing import Union, Optional, List, Iterable, Mapping, Sequence, Type, Any +from typing import Union, Optional, List, Iterable, Mapping, Sequence, Type, Any, Dict from types import MappingProxyType +from .preproc import _sparse_csr_fast_knn + try: from louvain.VertexPartition import MutableVertexPartition as LouvainMutableVertexPartition except ImportError: @@ -659,7 +663,16 @@ def mofa( # -def snf(mdata: MuData, key: str = "connectivities", k: int = 20, iterations: int = 20): +def snf( + mdata: MuData, + n_neighbors: int = 20, + neighbor_keys: Optional[Union[str, Dict[str, Optional[str]]]] = None, + key_added: Optional[str] = None, + n_iterations: int = 20, + sigma: float = 0.5, + eps: float = np.finfo(np.float64).eps, + copy: bool = False, +) -> Optional[MuData]: """ Similarity network 
fusion (SNF) @@ -672,33 +685,129 @@ def snf(mdata: MuData, key: str = "connectivities", k: int = 20, iterations: int ---------- mdata: MuData object - key: str (default: 'connectivities') - Key in .obsp to be used as SNF algorithm input. - Has to exist in all modalities. - k: int (default: 20) + n_neighbors: int (default: 20) Number of neighbours to be used in the K-nearest neighbours step - iterations: int (default: 20) + neighbor_keys: Keys in .uns where per-modality neighborhood information is stored. Defaults to ``"neighbors"``. + If set as a dictionary, only the modalities present in ``neighbor_keys`` will be used for multimodal nearest neighbor search. + If set as a string, has to exist in all modalities. + key_added: If not specified, the multimodal neighbors data is stored in ``.uns["neighbors"]``, distances and + connectivities are stored in ``.obsp["distances"]`` and ``.obsp["connectivities"]``, respectively. If specified, the + neighbors data is added to ``.uns[key_added]``, distances are stored in ``.obsp[key_added + "_distances"]`` and + connectivities in ``.obsp[key_added + "_connectivities"]``. + n_iterations: int (default: 20) Number of iterations for the diffusion process + sigma: float (default: 0.5) + Variance for the local model when calculating affinity matrices + eps: Small number to avoid numerical errors. + copy: Return a copy instead of writing to ``mdata``. 
""" + import scipy.stats as stats + + mdata = mdata.copy() if copy else mdata + + if neighbor_keys is None: + modalities = mdata.mod.keys() + neighbor_keys = {} + elif isinstance(neighbor_keys, str): + modalities = mdata.mod.keys() + neighbor_keys = {m: neighbor_keys for m in modalities} + else: + modalities = neighbor_keys.keys() + + mod_neighbors = np.empty((len(modalities),), dtype=np.uint16) + mod_reps = {} + reps = {} + mod_n_pcs = {} + neighbors_params = {} + for i, mod in enumerate(modalities): + nkey = neighbor_keys.get(mod, "neighbors") + + try: + nparams = mdata.mod[mod].uns[nkey] + except KeyError: + raise ValueError( + f'Did not find .uns["{nkey}"] for modality "{mod}". Run `sc.pp.neighbors` on all modalities first.' + ) + + use_rep = nparams["params"].get("use_rep", None) + n_pcs = nparams["params"].get("n_pcs", None) + mod_neighbors[i] = nparams["params"].get("n_neighbors", 0) + + neighbors_params[mod] = nparams + reps[mod] = _choose_representation(mdata.mod[mod], use_rep, n_pcs) + mod_reps[mod] = ( + use_rep if use_rep is not None else -1 + ) # otherwise this is not saved to h5mu + mod_n_pcs[mod] = n_pcs if n_pcs is not None else -1 + + def _affinity_matrix(dist, k, sigma): + """ + Compute the affinity matrix for a distance matrix + + Reference implementation can be found in the SNFtool R package: + https://github.com/cran/SNFtool/blob/master/R/affinityMatrix.R + + PARAMETERS + ---------- + mdata: + MuData object + k: int (default: 20) + Number of neighbours to be used in the K-nearest neighbours step + sigma: float (default: 0.5) + Variance for the local model when calculating affinity matrices + """ + dist = (dist + dist.T) / 2 + if issparse(dist): + dist.setdiag(0) + dist.eliminate_zeros() + else: + np.fill_diagonal(dist, 0) + + # FIXME: adopt for sparse matrices + if issparse(dist): + logging.warning( + f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Using dense distance matrix when computing affinity matrix..." 
+ ) + dist = dist.todense() + sorted_columns = np.apply_along_axis(np.sort, 1, dist) + + def finite_mean(x, *args, **kwargs): + return np.mean(x[~np.isinf(x)], *args, **kwargs) + + means = np.apply_along_axis(finite_mean, 1, sorted_columns[:, 1 : k + 1]) + eps + sig = np.add.outer(means, means) / 3 + dist / 3 + eps + densities = stats.norm(0, sigma * sig).pdf(dist) + + w = (densities + densities.T) / 2 + return w + wall = [] for mod in mdata.mod: - # TODO: check the key exists in every modality - wall.append(mdata.mod[mod].obsp[key]) + nkey = neighbor_keys.get(mod, "neighbors") + # the key has to exist in every modality + if nkey not in mdata.mod[mod].uns: + raise ValueError(f"The key '{nkey}' is missing from the .uns slot of modality '{mod}'") + mod_distances = mdata.mod[mod].obsp[neighbors_params[mod]["distances_key"]] + w = _affinity_matrix(mod_distances, k=n_neighbors, sigma=sigma) + wall.append(w) def _normalize(x): row_sum_mdiag = x.sum(axis=1) - x.diagonal() row_sum_mdiag[row_sum_mdiag == 0] = 1 - x = x / (2 * row_sum_mdiag) + x = x / (2 * row_sum_mdiag[:, None]) np.fill_diagonal(x, 0.5) x = (x + x.T) / 2 return x def _dominateset(x, k=20): def _zero(arr): + if k >= len(arr): + raise ValueError(f"'n_neighbors' seems to be too high.") + arr = arr.copy() arr[np.argsort(arr)[: (len(arr) - k)]] = 0 return arr - x = np.apply_along_axis(_zero, 0, wall[0]) + x = np.apply_along_axis(_zero, 0, x) return x / x.sum(axis=1) for i in range(len(wall)): @@ -706,24 +815,24 @@ def _zero(arr): new = [] for i in range(len(wall)): - new.append(_dominateset(wall[i], k)) + new.append(_dominateset(wall[i], n_neighbors)) nextW = [None] * len(wall) logging.info( - f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Starting {iterations} iterations..." + f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Starting {n_iterations} iterations..."
) - for ti in range(iterations): + for ti in range(n_iterations): for j in range(len(wall)): sumWJ = np.zeros(shape=(wall[j].shape[0], wall[j].shape[1])) for ki in range(len(wall)): if ki != j: sumWJ = sumWJ + wall[ki] - nextW[j] = new[j] * (sumWJ / (len(wall) - 1)) * new[j].T + nextW[j] = np.dot(np.dot(new[j], (sumWJ / (len(wall) - 1))), new[j].T) for j in range(len(wall)): wall[j] = _normalize(nextW[j]) logging.info( - f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Done: iteration {ti} of {iterations}." + f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Done: iteration {ti} of {n_iterations}." ) # Sum diffused matrices @@ -731,7 +840,34 @@ def _zero(arr): w = w / len(wall) w = _normalize(w) - mdata.obsp[key] = w + # Keep n neighbours + logging.info(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Calculating distances...") + neighbordistances = _sparse_csr_fast_knn(csr_matrix(0.5 - w), n_neighbors) + + # TODO: use _compute_connectivities_umap? + logging.info(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Calculating connectivities...") + connectivities = _sparse_csr_fast_knn(csr_matrix(w), n_neighbors) + + if key_added is None: + key_added = "neighbors" + conns_key = "connectivities" + dists_key = "distances" + else: + conns_key = key_added + "_connectivities" + dists_key = key_added + "_distances" + neighbors_dict = {"connectivities_key": conns_key, "distances_key": dists_key} + neighbors_dict["params"] = { + "n_neighbors": n_neighbors, + "eps": eps, + "use_rep": mod_reps, + "n_pcs": mod_n_pcs, + "method": "snf", + } + mdata.obsp[conns_key] = connectivities + mdata.obsp[dists_key] = neighbordistances + mdata.uns[key_added] = neighbors_dict + + return mdata if copy else None # @@ -1115,9 +1251,7 @@ def umap( try: neighbors = mdata.uns[neighbors_key] except KeyError: - raise ValueError( - f'Did not find .uns["{neighbors_key}"]. Run `muon.pp.weighted_neighbors` first.' - ) + raise ValueError(f'Did not find .uns["{neighbors_key}"]. 
Run `muon.pp.neighbors` first.') from scanpy.tools._utils import _choose_representation from copy import deepcopy