Skip to content

Commit

Permalink
estimators: add NearestNeighborsRetrieverTanimoto
Browse files Browse the repository at this point in the history
    - Add new class that performs k-nearest neighbor searches using
      Tanimoto similarity. The implementation uses sparse dot product
      making the algorithm 2-3x faster than RDKit's BulkTanimotoSimilarity
    - Add notebook illustrating NearestNeighborsRetrieverTanimoto for
      dataset similarity analysis, like train/test set comaparison.
  • Loading branch information
JochenSiegWork committed Feb 10, 2025
1 parent 0b079fa commit e9b1423
Show file tree
Hide file tree
Showing 4 changed files with 1,066 additions and 3 deletions.
208 changes: 207 additions & 1 deletion molpipeline/estimators/nearest_neighbor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,22 @@

from __future__ import annotations

import multiprocessing
from typing import Any, Callable, Literal, Sequence, Union

from joblib import Parallel, delayed

try:
from typing import Self
except ImportError:
from typing_extensions import Self


import numpy as np
import numpy.typing as npt
from scipy import sparse
from sklearn.neighbors import NearestNeighbors

from molpipeline.utils.kernel import tanimoto_similarity_sparse
from molpipeline.utils.value_checks import get_length

__all__ = ["NamedNearestNeighbors"]
Expand Down Expand Up @@ -202,3 +205,206 @@ def fit_predict(
"""
self.fit(X, y)
return self.predict(X, return_distance=return_distance, n_neighbors=n_neighbors)


class NearestNeighborsRetrieverTanimoto: # pylint: disable=too-few-public-methods
"""k-nearest neighbors between data sets using Tanimoto similarity.
This class uses the Tanimoto similarity to find the k-nearest neighbors of a query set in a target set.
The full similarity matrix is computed and reduced to the k-nearest neighbors. A dot-product based
algorithm is used, which is faster than using the RDKit native Tanimoto function.
For handling larger datasets, the computation can be batched to reduce memory usage. In addition,
the batches can be processed in parallel using joblib.
"""

def __init__(
self,
target_fingerprints: sparse.csr_matrix,
k: int | None = None,
batch_size: int = 1000,
n_jobs: int = 1,
):
"""Initialize NearestNeighborsRetrieverTanimoto.
Parameters
----------
target_fingerprints: sparse.csr_matrix
Fingerprints of target molecules. Must be a binary sparse matrix.
"""
self.target_fingerprints = target_fingerprints
if k is None:
self.k = self.target_fingerprints.shape[0]
else:
self.k = k
self.batch_size = batch_size
if n_jobs == -1:
self.n_jobs = multiprocessing.cpu_count()
else:
self.n_jobs = n_jobs
if self.k == 1:
self.knn_reduce_function = self._reduce_k_equals_1
elif self.k < self.target_fingerprints.shape[0]:
self.knn_reduce_function = self._reduce_k_greater_1_less_n
else:
self.knn_reduce_function = self._reduct_to_indices_k_equals_n

@staticmethod
def _reduce_k_equals_1(
similarity_matrix: npt.NDArray[np.float64],
) -> tuple[npt.NDArray[np.int64], npt.NDArray[np.float64]]:
"""Reduce similarity matrix to k=1 nearest neighbors.
Uses argmax to find the index of the nearest neighbor in the target fingerprints.
This function has therefore O(n) time complexity.
Parameters
----------
similarity_matrix: npt.NDArray[np.float64]
Similarity matrix of Tanimoto scores between query and target fingerprints.
Returns
-------
npt.NDArray[np.int64]
Indices of the query's nearest neighbors in the target fingerprints.
"""
topk_indices = np.argmax(similarity_matrix, axis=1)
topk_similarities = np.take_along_axis(
similarity_matrix, topk_indices.reshape(-1, 1), axis=1
).squeeze()
return topk_indices, topk_similarities

def _reduce_k_greater_1_less_n(
self,
similarity_matrix: npt.NDArray[np.float64],
) -> tuple[npt.NDArray[np.int64], npt.NDArray[np.float64]]:
"""Reduce similarity matrix to k>1 and k<n nearest neighbors.
Uses argpartition to find the k-nearest neighbors in the target fingerprints, which uses a linear
partial sort algorithm. The top k hits must be sorted afterward to get the k-nearest neighbors
in descending order. This function has therefore O(n + k log k) time complexity.
The indices are sorted descending by similarity.
Parameters
----------
similarity_matrix: npt.NDArray[np.float64]
Similarity matrix of Tanimoto scores between query and target fingerprints.
Returns
-------
tuple[npt.NDArray[np.int64], npt.NDArray[np.float64]]
Indices of the query's k-nearest neighbors in the target fingerprints and
the corresponding similarities.
"""
# Get the indices of the k-nearest neighbors. argpartition returns them unsorted.
topk_indices = np.argpartition(similarity_matrix, kth=-self.k, axis=1)[
:, -self.k :
]
topk_similarities = np.take_along_axis(similarity_matrix, topk_indices, axis=1)
# sort the topk_indices descending by similarity
topk_indices_sorted = np.take_along_axis(
topk_indices,
np.fliplr(topk_similarities.argsort(axis=1, kind="stable")),
axis=1,
)
topk_similarities_sorted = np.take_along_axis(
similarity_matrix, topk_indices_sorted, axis=1
)
return topk_indices_sorted, topk_similarities_sorted

@staticmethod
def _reduct_to_indices_k_equals_n(
similarity_matrix: npt.NDArray[np.float64],
) -> tuple[npt.NDArray[np.int64], npt.NDArray[np.float64]]:
"""Reduce similarity matrix to k=n nearest neighbors.
Parameters
----------
similarity_matrix: npt.NDArray[np.float64]
Similarity matrix of Tanimoto scores between query and target fingerprints.
Returns
-------
tuple[npt.NDArray[np.int64], npt.NDArray[np.float64]]
Indices of the query's k-nearest neighbors in the target fingerprints and
the corresponding similarities.
"""
indices = np.fliplr(similarity_matrix.argsort(axis=1, kind="stable"))
similarities = np.take_along_axis(similarity_matrix, indices, axis=1)
return indices, similarities

def _process_batch(
self, query_batch: sparse.csr_matrix
) -> tuple[npt.NDArray[np.int64], npt.NDArray[np.float64]]:
"""Process a batch of query fingerprints.
Parameters
----------
query_batch: sparse.csr_matrix
Batch of query fingerprints.
Returns
-------
tuple
Indices of the k-nearest neighbors in the target fingerprints and the corresponding similarities.
"""
# compute full similarity matrix for the query batch
similarity_mat_chunk = tanimoto_similarity_sparse(
query_batch, self.target_fingerprints
)

# reduce the similarity matrix to the k nearest neighbors
return self.knn_reduce_function(similarity_mat_chunk)

def predict(
self, query_fingerprints: sparse.csr_matrix
) -> tuple[npt.NDArray[np.int64], npt.NDArray[np.float64]]:
"""Predict the k-nearest neighbors of the query fingerprints.
Parameters
----------
query_fingerprints: sparse.csr_matrix
Query fingerprints.
Returns
-------
tuple[npt.NDArray[np.int64], npt.NDArray[np.float64]]
Indices of the k-nearest neighbors in the target fingerprints and the corresponding similarities.
"""
if query_fingerprints.shape[1] != self.target_fingerprints.shape[1]:
raise ValueError(
"The number of features in the query fingerprints does not match the number of features in the target fingerprints."
)
if self.n_jobs > 1:
# parallel execution
with Parallel(n_jobs=self.n_jobs) as parallel:
# the parallelization is not optimal: the self.target_fingerprints (and query_fingerprints) are copied to each child process worker
# -> joblib does some behind the scenes mmapping but copying the full matrices is probably a memory bottleneck.
# If Python removes the GIL this here would be a good use case for threading with zero copies.
res = parallel(
delayed(self._process_batch)(
query_fingerprints[i : i + self.batch_size, :]
)
for i in range(0, query_fingerprints.shape[0], self.batch_size)
)
result_indices_tmp, result_similarities_tmp = zip(*res)
result_indices = np.concatenate(result_indices_tmp)
result_similarities = np.concatenate(result_similarities_tmp)
else:
# single process execution
result_shape = (
(query_fingerprints.shape[0], self.k)
if self.k > 1
else (query_fingerprints.shape[0],)
)
result_indices = np.full(result_shape, -1, dtype=np.int64)
result_similarities = np.full(result_shape, np.nan, dtype=np.float64)
for i in range(0, query_fingerprints.shape[0], self.batch_size):
query_batch = query_fingerprints[i : i + self.batch_size, :]
(
result_indices[i : i + self.batch_size],
result_similarities[i : i + self.batch_size],
) = self._process_batch(query_batch)

return result_indices, result_similarities
8 changes: 6 additions & 2 deletions molpipeline/utils/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,12 @@ def tanimoto_similarity_sparse(
Matrix of similarity values between instances of A (rows/first dim) , and instances of B (columns/second dim).
"""
intersection = matrix_a.dot(matrix_b.transpose()).toarray()
norm_1 = np.array(matrix_a.multiply(matrix_a).sum(axis=1))
norm_2 = np.array(matrix_b.multiply(matrix_b).sum(axis=1))
norm_1 = np.array(matrix_a.sum(axis=1))
if matrix_a is matrix_b:
# avoid calculating the same norm twice
norm_2 = norm_1
else:
norm_2 = np.array(matrix_b.sum(axis=1))
union = norm_1 + norm_2.T - intersection
# avoid division by zero https://stackoverflow.com/a/37977222
return np.divide(
Expand Down
Loading

0 comments on commit e9b1423

Please sign in to comment.