Skip to content

Commit

Permalink
Improve SNF
Browse files Browse the repository at this point in the history
This improves mu.tl.snf interface as well as functionality
by providing more arguments as well as fixing implementation bugs.

Arguments are now adjusted to better match mu.pp.neighbors().

The method now records connectivities and distances.
The way these sparse matrices are computed might change in future.
The location of the method might change as well (mu.tl -> mu.pp).
  • Loading branch information
gtca committed Oct 8, 2021
1 parent e40ca58 commit 6c37dc5
Showing 1 changed file with 155 additions and 21 deletions.
176 changes: 155 additions & 21 deletions muon/_core/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import numpy as np
import pandas as pd
from scipy.sparse import issparse
from scipy.sparse import issparse, csr_matrix
import scanpy as sc
import h5py
from natsort import natsorted
Expand All @@ -17,10 +17,14 @@

from scanpy._compat import Literal
from scanpy import logging
from scanpy.tools._utils import _choose_representation
from scanpy.neighbors import _compute_connectivities_umap

from typing import Union, Optional, List, Iterable, Mapping, Sequence, Type, Any
from typing import Union, Optional, List, Iterable, Mapping, Sequence, Type, Any, Dict
from types import MappingProxyType

from .preproc import _sparse_csr_fast_knn

try:
from louvain.VertexPartition import MutableVertexPartition as LouvainMutableVertexPartition
except ImportError:
Expand Down Expand Up @@ -659,7 +663,16 @@ def mofa(
#


def snf(mdata: MuData, key: str = "connectivities", k: int = 20, iterations: int = 20):
def snf(
mdata: MuData,
n_neighbors: int = 20,
neighbor_keys: Optional[Union[str, Dict[str, Optional[str]]]] = None,
key_added: Optional[str] = None,
n_iterations: int = 20,
sigma: float = 0.5,
eps: float = np.finfo(np.float64).eps,
copy: bool = False,
) -> Optional[MuData]:
"""
Similarity network fusion (SNF)
Expand All @@ -672,66 +685,189 @@ def snf(mdata: MuData, key: str = "connectivities", k: int = 20, iterations: int
----------
mdata:
MuData object
key: str (default: 'connectivities')
Key in .obsp to be used as SNF algorithm input.
Has to exist in all modalities.
k: int (default: 20)
n_neighbors: int (default: 20)
Number of neighbours to be used in the K-nearest neighbours step
iterations: int (default: 20)
neighbor_keys: Keys in .uns where per-modality neighborhood information is stored. Defaults to ``"neighbors"``.
If set as a dictionary, only the modalities present in ``neighbor_keys`` will be used for multimodal nearest neighbor search.
If set as a string, has to exist in all modalities.
key_added: If not specified, the multimodal neighbors data is stored in ``.uns["neighbors"]``, distances and
connectivities are stored in ``.obsp["distances"]`` and ``.obsp["connectivities"]``, respectively. If specified, the
neighbors data is added to ``.uns[key_added]``, distances are stored in ``.obsp[key_added + "_distances"]`` and
connectivities in ``.obsp[key_added + "_connectivities"]``.
n_iterations: int (default: 20)
Number of iterations for the diffusion process
sigma: float (default: 0.5)
Variance for the local model when calculating affinity matrices
eps: Small number to avoid numerical errors.
copy: Return a copy instead of writing to ``mdata``.
"""
import scipy.stats as stats

mdata = mdata.copy() if copy else mdata

if neighbor_keys is None:
modalities = mdata.mod.keys()
neighbor_keys = {}
elif isinstance(neighbor_keys, str):
modalities = mdata.mod.keys()
neighbor_keys = {m: neighbor_keys for m in modalities}
else:
modalities = neighbor_keys.keys()

mod_neighbors = np.empty((len(modalities),), dtype=np.uint16)
mod_reps = {}
reps = {}
mod_n_pcs = {}
neighbors_params = {}
for i, mod in enumerate(modalities):
nkey = neighbor_keys.get(mod, "neighbors")

try:
nparams = mdata.mod[mod].uns[nkey]
except KeyError:
raise ValueError(
f'Did not find .uns["{nkey}"] for modality "{mod}". Run `sc.pp.neighbors` on all modalities first.'
)

use_rep = nparams["params"].get("use_rep", None)
n_pcs = nparams["params"].get("n_pcs", None)
mod_neighbors[i] = nparams["params"].get("n_neighbors", 0)

neighbors_params[mod] = nparams
reps[mod] = _choose_representation(mdata.mod[mod], use_rep, n_pcs)
mod_reps[mod] = (
use_rep if use_rep is not None else -1
) # otherwise this is not saved to h5mu
mod_n_pcs[mod] = n_pcs if n_pcs is not None else -1

def _affinity_matrix(dist, k, sigma):
"""
Compute the affinity matrix for a distance matrix
Reference implementation can be found in the SNFtool R package:
https://github.com/cran/SNFtool/blob/master/R/affinityMatrix.R
PARAMETERS
----------
mdata:
MuData object
k: int (default: 20)
Number of neighbours to be used in the K-nearest neighbours step
sigma: float (default: 0.5)
Variance for the local model when calculating affinity matrices
"""
dist = (dist + dist.T) / 2
if issparse(dist):
dist.setdiag(0)
dist.eliminate_zeros()
else:
np.fill_diagonal(dist, 0)

# FIXME: adopt for sparse matrices
if issparse(dist):
logging.warning(
f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Using dense distance matrix when computing affinity matrix..."
)
dist = dist.todense()
sorted_columns = np.apply_along_axis(np.sort, 1, dist)

def finite_mean(x, *args, **kwargs):
return np.mean(x[~np.isinf(x)], *args, **kwargs)

means = np.apply_along_axis(finite_mean, 1, sorted_columns[:, 1 : k + 1]) + eps
sig = np.add.outer(means, means) / 3 + dist / 3 + eps
densities = stats.norm(0, sigma * sig).pdf(dist)

w = (densities + densities.T) / 2
return w

wall = []
for mod in mdata.mod:
# TODO: check the key exists in every modality
wall.append(mdata.mod[mod].obsp[key])
nkey = neighbor_keys.get(mod, "neighbors")
# the key has to exists in every modality
if nkey not in mdata.mod[mod].uns:
raise ValueError(f"The key '{nkey}' is missing from the .uns slot of modality '{mod}'")
mod_distances = mdata.mod[mod].obsp[neighbors_params[mod]["distances_key"]]
w = _affinity_matrix(mod_distances, k=n_neighbors, sigma=sigma)
wall.append(w)

def _normalize(x):
row_sum_mdiag = x.sum(axis=1) - x.diagonal()
row_sum_mdiag[row_sum_mdiag == 0] = 1
x = x / (2 * row_sum_mdiag)
x = x / (2 * row_sum_mdiag[:, None])
np.fill_diagonal(x, 0.5)
x = (x + x.T) / 2
return x

def _dominateset(x, k=20):
def _zero(arr):
if k >= len(arr):
raise ValueError(f"'n_neighbors' seems to be too high.")
arr = arr.copy()
arr[np.argsort(arr)[: (len(arr) - k)]] = 0
return arr

x = np.apply_along_axis(_zero, 0, wall[0])
x = np.apply_along_axis(_zero, 0, x)
return x / x.sum(axis=1)

for i in range(len(wall)):
wall[i] = _normalize(wall[i])

new = []
for i in range(len(wall)):
new.append(_dominateset(wall[i], k))
new.append(_dominateset(wall[i], n_neighbors))

nextW = [None] * len(wall)

logging.info(
f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Starting {iterations} iterations..."
f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Starting {n_iterations} iterations..."
)
for ti in range(iterations):
for ti in range(n_iterations):
for j in range(len(wall)):
sumWJ = np.zeros(shape=(wall[j].shape[0], wall[j].shape[1]))
for ki in range(len(wall)):
if ki != j:
sumWJ = sumWJ + wall[ki]
nextW[j] = new[j] * (sumWJ / (len(wall) - 1)) * new[j].T
nextW[j] = np.dot(np.dot(new[j], (sumWJ / (len(wall) - 1))), new[j].T)
for j in range(len(wall)):
wall[j] = _normalize(nextW[j])
logging.info(
f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Done: iteration {ti} of {iterations}."
f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Done: iteration {ti} of {n_iterations}."
)

# Sum diffused matrices
w = np.sum(wall, axis=0)
w = w / len(wall)
w = _normalize(w)

mdata.obsp[key] = w
# Keep n neighbours
logging.info(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Calculating distances...")
neighbordistances = _sparse_csr_fast_knn(csr_matrix(0.5 - w), n_neighbors)

# TODO: use _compute_connectivities_umap?
logging.info(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Calculating connectivities...")
connectivities = _sparse_csr_fast_knn(csr_matrix(w), n_neighbors)

if key_added is None:
key_added = "neighbors"
conns_key = "connectivities"
dists_key = "distances"
else:
conns_key = key_added + "_connectivities"
dists_key = key_added + "_distances"
neighbors_dict = {"connectivities_key": conns_key, "distances_key": dists_key}
neighbors_dict["params"] = {
"n_neighbors": n_neighbors,
"eps": eps,
"use_rep": mod_reps,
"n_pcs": mod_n_pcs,
"method": "snf",
}
mdata.obsp[conns_key] = connectivities
mdata.obsp[dists_key] = neighbordistances
mdata.uns[key_added] = neighbors_dict

return mdata if copy else None


#
Expand Down Expand Up @@ -1115,9 +1251,7 @@ def umap(
try:
neighbors = mdata.uns[neighbors_key]
except KeyError:
raise ValueError(
f'Did not find .uns["{neighbors_key}"]. Run `muon.pp.weighted_neighbors` first.'
)
raise ValueError(f'Did not find .uns["{neighbors_key}"]. Run `muon.pp.neighbors` first.')

from scanpy.tools._utils import _choose_representation
from copy import deepcopy
Expand Down

0 comments on commit 6c37dc5

Please sign in to comment.