Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dataset similarity #122

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
271 changes: 270 additions & 1 deletion molpipeline/estimators/nearest_neighbor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,23 @@

from typing import Any, Callable, Literal, Sequence, Union

from joblib import Parallel, delayed
from scipy.sparse import csr_matrix
from sklearn.base import BaseEstimator

from molpipeline.utils.multi_proc import check_available_cores

try:
from typing import Self
except ImportError:
from typing_extensions import Self


import numpy as np
import numpy.typing as npt
from scipy import sparse
from sklearn.neighbors import NearestNeighbors

from molpipeline.utils.kernel import tanimoto_similarity_sparse
from molpipeline.utils.value_checks import get_length

__all__ = ["NamedNearestNeighbors"]
Expand Down Expand Up @@ -202,3 +208,266 @@ def fit_predict(
"""
self.fit(X, y)
return self.predict(X, return_distance=return_distance, n_neighbors=n_neighbors)


class TanimotoKNN(BaseEstimator): # pylint: disable=too-few-public-methods
"""k-nearest neighbors (KNN) between data sets using Tanimoto similarity.

This class uses the Tanimoto similarity to find the k-nearest neighbors of a query set in a target set.
The full similarity matrix is computed and reduced to the k-nearest neighbors. A dot-product based
algorithm is used, which is faster than using the RDKit native Tanimoto function.

For handling larger datasets, the computation can be batched to reduce memory usage. In addition,
the batches can be processed in parallel using joblib.
"""

target_indices_mapping_: npt.NDArray[np.int64] | None

def __init__(
self,
*,
k: int | None,
batch_size: int = 1000,
n_jobs: int = 1,
):
"""Initialize TanimotoKNN.

Parameters
----------
k: int | None
Number of nearest neighbors to find. If None, all neighbors are returned.
batch_size: int, optional (default=1000)
Size of the batches for parallel processing.
n_jobs: int, optional (default=1)
Number of parallel jobs to run for neighbors search.
"""
self.target_fingerprints: csr_matrix | None = None
self.k = k
self.batch_size = batch_size
self.n_jobs = check_available_cores(n_jobs)
self.knn_reduce_function: (
Callable[
[npt.NDArray[np.float64]],
tuple[npt.NDArray[np.int64], npt.NDArray[np.float64]],
]
| None
) = None

def fit(
self,
X: sparse.csr_matrix, # pylint: disable=invalid-name
y: Sequence[Any] | None = None, # pylint: disable=invalid-name
) -> Self:
"""Fit the estimator using X as target fingerprint data set.

Parameters
----------
X : sparse.csr_matrix
The target fingerprint data set. By calling `predict`, searches are performed
against this target data set.
y : Sequence[Any]
Target values. Here values are used as returned nearest neighbors.
Must have the same length as X.
Will be stored as the learned_names_ attribute as npt.NDArray[Any].

Returns
-------
Self
The instance itself.

Raises
------
ValueError
If the input arrays have different lengths or do not have a shape nor len attribute.
"""
if y is None:
y = list(range(X.shape[0]))
if X.shape[0] != get_length(y):
raise ValueError("X and y must have the same length.")

if self.k is None:
# set k to the number of target fingerprints if k is None
self.k = X.shape[0]

# determine the recude function dependent on the value of k
if self.k == 1:
self.knn_reduce_function = self._reduce_k_equals_1
elif self.k < X.shape[0]:
self.knn_reduce_function = self._reduce_k_greater_1_less_n
else:
self.knn_reduce_function = self._reduce_k_equals_n

self.target_indices_mapping_ = np.array(y)
self.target_fingerprints = X
return self

@staticmethod
def _reduce_k_equals_1(
similarity_matrix: npt.NDArray[np.float64],
) -> tuple[npt.NDArray[np.int64], npt.NDArray[np.float64]]:
"""Reduce similarity matrix to k=1 nearest neighbors.

Uses argmax to find the index of the nearest neighbor in the target fingerprints.
This function has therefore O(n) time complexity.

Parameters
----------
similarity_matrix: npt.NDArray[np.float64]
Similarity matrix of Tanimoto scores between query and target fingerprints.

Returns
-------
npt.NDArray[np.int64]
Indices of the query's nearest neighbors in the target fingerprints.
"""
topk_indices = np.argmax(similarity_matrix, axis=1)
topk_similarities = np.take_along_axis(
similarity_matrix, topk_indices.reshape(-1, 1), axis=1
).squeeze()
return topk_indices, topk_similarities

def _reduce_k_greater_1_less_n(
self,
similarity_matrix: npt.NDArray[np.float64],
) -> tuple[npt.NDArray[np.int64], npt.NDArray[np.float64]]:
"""Reduce similarity matrix to k>1 and k<n nearest neighbors.

Uses argpartition to find the k-nearest neighbors in the target fingerprints, which uses a linear
partial sort algorithm. The top k hits must be sorted afterward to get the k-nearest neighbors
in descending order. This function has therefore O(n + k log k) time complexity.

The indices are sorted descending by similarity.

Parameters
----------
similarity_matrix: npt.NDArray[np.float64]
Similarity matrix of Tanimoto scores between query and target fingerprints.

Returns
-------
tuple[npt.NDArray[np.int64], npt.NDArray[np.float64]]
Indices of the query's k-nearest neighbors in the target fingerprints and
the corresponding similarities.
"""
# Get the indices of the k-nearest neighbors. argpartition returns them unsorted.
topk_indices = np.argpartition(similarity_matrix, kth=-self.k, axis=1)[
:, -self.k :
]
topk_similarities = np.take_along_axis(similarity_matrix, topk_indices, axis=1)
# sort the topk_indices descending by similarity
topk_indices_sorted = np.take_along_axis(
topk_indices,
np.fliplr(topk_similarities.argsort(axis=1, kind="stable")),
axis=1,
)
topk_similarities_sorted = np.take_along_axis(
similarity_matrix, topk_indices_sorted, axis=1
)
return topk_indices_sorted, topk_similarities_sorted

@staticmethod
def _reduce_k_equals_n(
similarity_matrix: npt.NDArray[np.float64],
) -> tuple[npt.NDArray[np.int64], npt.NDArray[np.float64]]:
"""Reduce similarity matrix to k=n nearest neighbors.

Parameters
----------
similarity_matrix: npt.NDArray[np.float64]
Similarity matrix of Tanimoto scores between query and target fingerprints.

Returns
-------
tuple[npt.NDArray[np.int64], npt.NDArray[np.float64]]
Indices of the query's k-nearest neighbors in the target fingerprints and
the corresponding similarities.
"""
indices = np.fliplr(similarity_matrix.argsort(axis=1, kind="stable"))
similarities = np.take_along_axis(similarity_matrix, indices, axis=1)
return indices, similarities

def _process_batch(
self, query_batch: sparse.csr_matrix
) -> tuple[npt.NDArray[np.int64], npt.NDArray[np.float64]]:
"""Process a batch of query fingerprints.

Parameters
----------
query_batch: sparse.csr_matrix
Batch of query fingerprints.

Returns
-------
tuple
Indices of the k-nearest neighbors in the target fingerprints and the corresponding similarities.
"""
# compute full similarity matrix for the query batch
similarity_mat_chunk = tanimoto_similarity_sparse(
query_batch, self.target_fingerprints
)

if self.knn_reduce_function is None:
raise AssertionError(
"The knn_reduce_function has not been set. This should happen in the fit function."
)
# reduce the similarity matrix to the k nearest neighbors
return self.knn_reduce_function(similarity_mat_chunk)

def predict(
self,
query_fingerprints: sparse.csr_matrix,
) -> tuple[npt.NDArray[np.int64], npt.NDArray[np.float64]]:
"""Predict the k-nearest neighbors of the query fingerprints.

Parameters
----------
query_fingerprints: sparse.csr_matrix
Query fingerprints.

Returns
-------
tuple[npt.NDArray[np.int64], npt.NDArray[np.float64]]
Indices of the k-nearest neighbors in the target fingerprints and the corresponding similarities.
"""
if self.target_fingerprints is None:
raise ValueError("The model has not been fitted yet.")
if self.k is None:
raise AssertionError(
"The number of neighbors k has not been set. This should happen in the fit function."
)
if query_fingerprints.shape[1] != self.target_fingerprints.shape[1]:
raise ValueError(
"The number of features in the query fingerprints does not match the number of features in the target fingerprints."
)
if self.n_jobs > 1:
# parallel execution
with Parallel(n_jobs=self.n_jobs) as parallel:
# the parallelization is not optimal: the self.target_fingerprints (and query_fingerprints) are copied to each child process worker
# -> joblib does some behind the scenes mmapping but copying the full matrices is probably a memory bottleneck.
# If Python removes the GIL this here would be a good use case for threading with zero copies.
res = parallel(
delayed(self._process_batch)(
query_fingerprints[i : i + self.batch_size, :]
)
for i in range(0, query_fingerprints.shape[0], self.batch_size)
)
result_indices_tmp, result_similarities_tmp = zip(*res)
result_indices = np.concatenate(result_indices_tmp)
result_similarities = np.concatenate(result_similarities_tmp)
else:
# single process execution
result_shape = (
(query_fingerprints.shape[0], self.k)
if self.k > 1
else (query_fingerprints.shape[0],)
)
result_indices = np.full(result_shape, -1, dtype=np.int64)
result_similarities = np.full(result_shape, np.nan, dtype=np.float64)
for i in range(0, query_fingerprints.shape[0], self.batch_size):
query_batch = query_fingerprints[i : i + self.batch_size, :]
(
result_indices[i : i + self.batch_size],
result_similarities[i : i + self.batch_size],
) = self._process_batch(query_batch)

return result_indices, result_similarities
8 changes: 6 additions & 2 deletions molpipeline/utils/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,12 @@ def tanimoto_similarity_sparse(
Matrix of similarity values between instances of A (rows/first dim) , and instances of B (columns/second dim).
"""
intersection = matrix_a.dot(matrix_b.transpose()).toarray()
norm_1 = np.array(matrix_a.multiply(matrix_a).sum(axis=1))
norm_2 = np.array(matrix_b.multiply(matrix_b).sum(axis=1))
norm_1 = np.array(matrix_a.sum(axis=1))
if matrix_a is matrix_b:
# avoid calculating the same norm twice
norm_2 = norm_1
else:
norm_2 = np.array(matrix_b.sum(axis=1))
union = norm_1 + norm_2.T - intersection
# avoid division by zero https://stackoverflow.com/a/37977222
return np.divide(
Expand Down
2 changes: 1 addition & 1 deletion molpipeline/utils/multi_proc.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,6 @@ def check_available_cores(n_requested_cores: int) -> int:
)
return n_available_cores
if n_requested_cores < 0:
return n_available_cores
return n_available_cores + 1 + n_requested_cores

return n_requested_cores
Loading
Loading