First part of Christian's comments
JochenSiegWork committed Feb 13, 2025
1 parent 4a3700f commit 753c9ab
Showing 3 changed files with 135 additions and 50 deletions.
101 changes: 79 additions & 22 deletions molpipeline/estimators/nearest_neighbor.py
@@ -2,10 +2,13 @@

from __future__ import annotations

import multiprocessing
from typing import Any, Callable, Literal, Sequence, Union

from joblib import Parallel, delayed
from scipy.sparse import csr_matrix
from sklearn.base import BaseEstimator

from molpipeline.utils.multi_proc import check_available_cores

try:
from typing import Self
@@ -207,8 +210,8 @@ def fit_predict(
return self.predict(X, return_distance=return_distance, n_neighbors=n_neighbors)


class NearestNeighborsRetrieverTanimoto: # pylint: disable=too-few-public-methods
"""k-nearest neighbors between data sets using Tanimoto similarity.
class TanimotoKNN(BaseEstimator): # pylint: disable=too-few-public-methods
"""k-nearest neighbors (KNN) between data sets using Tanimoto similarity.
This class uses the Tanimoto similarity to find the k-nearest neighbors of a query set in a target set.
The full similarity matrix is computed and reduced to the k-nearest neighbors. A dot-product based
@@ -218,42 +221,85 @@ class NearestNeighborsRetrieverTanimoto: # pylint: disable=too-few-public-methods
the batches can be processed in parallel using joblib.
"""

target_indices_mapping_: npt.NDArray[np.int64] | None

def __init__(
self,
target_fingerprints: sparse.csr_matrix,
k: int | None = None,
*,
k: int | None,
batch_size: int = 1000,
n_jobs: int = 1,
):
"""Initialize NearestNeighborsRetrieverTanimoto.
"""Initialize TanimotoKNN.
Parameters
----------
target_fingerprints: sparse.csr_matrix
Fingerprints of target molecules. Must be a binary sparse matrix.
k: int, optional (default=None)
k: int | None
Number of nearest neighbors to find. If None, all neighbors are returned.
batch_size: int, optional (default=1000)
Size of the batches for parallel processing.
n_jobs: int, optional (default=1)
Number of parallel jobs to run for neighbors search.
"""
self.target_fingerprints = target_fingerprints
if k is None:
self.k = self.target_fingerprints.shape[0]
else:
self.k = k
self.target_fingerprints: csr_matrix | None = None
self.k = k
self.batch_size = batch_size
if n_jobs == -1:
self.n_jobs = multiprocessing.cpu_count()
else:
self.n_jobs = n_jobs
self.n_jobs = check_available_cores(n_jobs)
self.knn_reduce_function: (
Callable[
[npt.NDArray[np.float64]],
tuple[npt.NDArray[np.int64], npt.NDArray[np.float64]],
]
| None
) = None

def fit(
self,
X: sparse.csr_matrix, # pylint: disable=invalid-name
y: Sequence[Any] | None = None, # pylint: disable=invalid-name
) -> Self:
"""Fit the estimator using X as target fingerprint data set.
Parameters
----------
X : sparse.csr_matrix
The target fingerprint data set. By calling `predict`, searches are performed
against this target data set.
y : Sequence[Any]
Target values, which are used as the identifiers returned for the nearest neighbors.
Must have the same length as X.
Will be stored as the target_indices_mapping_ attribute as npt.NDArray[Any].
Returns
-------
Self
The instance itself.
Raises
------
ValueError
If the input arrays have different lengths or have neither a shape nor a len attribute.
"""
if y is None:
y = list(range(X.shape[0]))
if X.shape[0] != get_length(y):
raise ValueError("X and y must have the same length.")

if self.k is None:
# set k to the number of target fingerprints if k is None
self.k = X.shape[0]

# determine the reduce function depending on the value of k
if self.k == 1:
self.knn_reduce_function = self._reduce_k_equals_1
elif self.k < self.target_fingerprints.shape[0]:
elif self.k < X.shape[0]:
self.knn_reduce_function = self._reduce_k_greater_1_less_n
else:
self.knn_reduce_function = self._reduct_to_indices_k_equals_n
self.knn_reduce_function = self._reduce_k_equals_n

self.target_indices_mapping_ = np.array(y)
self.target_fingerprints = X
return self

@staticmethod
def _reduce_k_equals_1(
@@ -320,7 +366,7 @@ def _reduce_k_greater_1_less_n(
return topk_indices_sorted, topk_similarities_sorted

@staticmethod
def _reduct_to_indices_k_equals_n(
def _reduce_k_equals_n(
similarity_matrix: npt.NDArray[np.float64],
) -> tuple[npt.NDArray[np.int64], npt.NDArray[np.float64]]:
"""Reduce similarity matrix to k=n nearest neighbors.
@@ -360,11 +406,16 @@ def _process_batch(
query_batch, self.target_fingerprints
)

if self.knn_reduce_function is None:
raise AssertionError(
"The knn_reduce_function has not been set. This should happen in the fit function."
)
# reduce the similarity matrix to the k nearest neighbors
return self.knn_reduce_function(similarity_mat_chunk)

def predict(
self, query_fingerprints: sparse.csr_matrix
self,
query_fingerprints: sparse.csr_matrix,
) -> tuple[npt.NDArray[np.int64], npt.NDArray[np.float64]]:
"""Predict the k-nearest neighbors of the query fingerprints.
@@ -378,6 +429,12 @@ def predict(
tuple[npt.NDArray[np.int64], npt.NDArray[np.float64]]
Indices of the k-nearest neighbors in the target fingerprints and the corresponding similarities.
"""
if self.target_fingerprints is None:
raise ValueError("The model has not been fitted yet.")
if self.k is None:
raise AssertionError(
"The number of neighbors k has not been set. This should happen in the fit function."
)
if query_fingerprints.shape[1] != self.target_fingerprints.shape[1]:
raise ValueError(
"The number of features in the query fingerprints does not match the number of features in the target fingerprints."
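For orientation, a hedged usage sketch of the renamed estimator's fit/predict interface as it appears in this diff; target_fps and query_fps stand in for binary CSR fingerprint matrices and are not defined in the commit itself:

from molpipeline.estimators.nearest_neighbor import TanimotoKNN

knn = TanimotoKNN(k=2, batch_size=1000, n_jobs=2)
knn.fit(target_fps)                             # target fingerprints to search against
indices, similarities = knn.predict(query_fps)  # k nearest targets per query row
# indices[i] holds the rows of the target set most similar to query i,
# similarities[i] the corresponding Tanimoto similarities.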
2 changes: 1 addition & 1 deletion molpipeline/utils/multi_proc.py
@@ -35,6 +35,6 @@ def check_available_cores(n_requested_cores: int) -> int:
)
return n_available_cores
if n_requested_cores < 0:
return n_available_cores
return n_available_cores + 1 + n_requested_cores

return n_requested_cores
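With this change, negative requests follow the usual joblib-style convention: the returned core count is n_available_cores + 1 + n_requested_cores. A small sketch of the resulting behaviour, assuming 8 cores are available:

from molpipeline.utils.multi_proc import check_available_cores

check_available_cores(-1)   # -> 8: use every available core (8 + 1 - 1)
check_available_cores(-2)   # -> 7: all but one core (8 + 1 - 2)
check_available_cores(4)    # -> 4: positive requests within the limit pass through unchanged
check_available_cores(16)   # -> 8: requests above the limit are capped at the available cores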
82 changes: 55 additions & 27 deletions tests/test_estimators/test_nearest_neighbors.py
@@ -9,7 +9,7 @@
from molpipeline import ErrorFilter, FilterReinserter, Pipeline, PostPredictionWrapper
from molpipeline.any2mol import SmilesToMol
from molpipeline.estimators import NamedNearestNeighbors, TanimotoToTraining
from molpipeline.estimators.nearest_neighbor import NearestNeighborsRetrieverTanimoto
from molpipeline.estimators.nearest_neighbor import TanimotoKNN
from molpipeline.mol2any import MolToMorganFP
from molpipeline.utils.kernel import tanimoto_distance_sparse

@@ -222,8 +222,8 @@ def test_fit_and_predict_invalid_with_distance(self) -> None:
)


class TestNearestNeighborsRetrieverTanimoto(TestCase):
"""Test nearest neighbors retriever with tanimoto."""
class TestTanimotoKNN(TestCase):
"""Test TanimotoKNN estimator."""

example_fingerprints: csr_matrix

Expand All @@ -243,16 +243,16 @@ def test_k_equals_1(self) -> None:
target_fps = self.example_fingerprints
query_fps = self.example_fingerprints

retriever = NearestNeighborsRetrieverTanimoto(target_fps, k=1)
indices, similarities = retriever.predict(query_fps)
knn = TanimotoKNN(k=1)
knn.fit(target_fps)
indices, similarities = knn.predict(query_fps)
self.assertTrue(np.array_equal(indices, np.array([0, 1, 2, 3])))
self.assertTrue(np.allclose(similarities, np.array([1, 1, 1, 1])))

# test parallel
retriever = NearestNeighborsRetrieverTanimoto(
target_fps, k=1, n_jobs=2, batch_size=2
)
indices, similarities = retriever.predict(query_fps)
knn = TanimotoKNN(k=1, n_jobs=2, batch_size=2)
knn.fit(target_fps)
indices, similarities = knn.predict(query_fps)
self.assertTrue(np.array_equal(indices, np.array([0, 1, 2, 3])))
self.assertTrue(np.allclose(similarities, np.array([1, 1, 1, 1])))

@@ -261,18 +261,18 @@ def test_k_greater_1_less_n(self) -> None:
target_fps = self.example_fingerprints
query_fps = self.example_fingerprints

retriever = NearestNeighborsRetrieverTanimoto(target_fps, k=2)
indices, similarities = retriever.predict(query_fps)
knn = TanimotoKNN(k=2)
knn.fit(target_fps)
indices, similarities = knn.predict(query_fps)
self.assertTrue(
np.array_equal(indices, np.array([[0, 1], [1, 0], [2, 3], [3, 2]]))
)
self.assertTrue(np.allclose(similarities, TWO_NN_SIMILARITIES))

# test parallel
retriever = NearestNeighborsRetrieverTanimoto(
target_fps, k=2, n_jobs=2, batch_size=2
)
indices, similarities = retriever.predict(query_fps)
knn = TanimotoKNN(k=2, n_jobs=2, batch_size=2)
knn.fit(target_fps)
indices, similarities = knn.predict(query_fps)
self.assertTrue(
np.array_equal(indices, np.array([[0, 1], [1, 0], [2, 3], [3, 2]]))
)
@@ -283,8 +283,9 @@ def test_k_equals_n(self) -> None:
target_fps = self.example_fingerprints
query_fps = self.example_fingerprints

retriever = NearestNeighborsRetrieverTanimoto(target_fps, k=target_fps.shape[0])
indices, similarities = retriever.predict(query_fps)
knn = TanimotoKNN(k=target_fps.shape[0])
knn.fit(target_fps)
indices, similarities = knn.predict(query_fps)
self.assertTrue(
np.array_equal(
indices,
@@ -294,10 +295,9 @@
self.assertTrue(np.allclose(similarities, FOUR_NN_SIMILARITIES))

# test parallel
retriever = NearestNeighborsRetrieverTanimoto(
target_fps, k=target_fps.shape[0], n_jobs=2, batch_size=2
)
indices, similarities = retriever.predict(query_fps)
knn = TanimotoKNN(k=target_fps.shape[0], n_jobs=2, batch_size=2)
knn.fit(target_fps)
indices, similarities = knn.predict(query_fps)
self.assertTrue(
np.array_equal(
indices,
@@ -306,9 +306,37 @@
)
self.assertTrue(np.allclose(similarities, FOUR_NN_SIMILARITIES))

# [
# [1.0, 3 / 14, 0.0, 0.0],
# [1.0, 3 / 14, 0.038461538461538464, 0.0],
# [1.0, 4 / 9, 0.0, 0.0],
# [1.0, 4 / 9, 0.038461538461538464, 0.0],
# ]
def test_pipeline(self) -> None:
"""Test TanimotoKNN in a pipeline."""
# test normal pipeline
pipeline = Pipeline(
[
("mol", SmilesToMol()),
("fingerprint", MolToMorganFP()),
("knn", TanimotoKNN(k=1)),
]
)
pipeline.fit(TEST_SMILES)
indices, similarities = pipeline.predict(TEST_SMILES)
self.assertTrue(np.array_equal(indices, np.array([0, 1, 2, 3])))
self.assertTrue(np.allclose(similarities, np.array([1, 1, 1, 1])))

# test pipeline with failing smiles
test_smiles = [
"c1ccccc1",
"c1cc(-C(=O)O)ccc1",
"I am a failing smiles :)",
"CCCCCCN",
"CCCCCCO",
]
pipeline = Pipeline(
[
("mol", SmilesToMol()),
("error_filter", ErrorFilter(filter_everything=True)),
("fingerprint", MolToMorganFP()),
("knn", TanimotoKNN(k=1)),
]
)
pipeline.fit(test_smiles)
indices, similarities = pipeline.predict(test_smiles)
# TODO: assert right result
