Skip to content

Commit

Permalink
add test for ChEMPROP Model, which fails
Browse files Browse the repository at this point in the history
  • Loading branch information
c-w-feldmann committed Feb 12, 2025
1 parent 2fc34e6 commit 5ee8c70
Showing 1 changed file with 47 additions and 14 deletions.
61 changes: 47 additions & 14 deletions test_extras/test_chemprop/test_chemprop_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import pandas as pd
from lightning import pytorch as pl
from sklearn.base import clone
from sklearn.calibration import CalibratedClassifierCV

from molpipeline.any2mol import SmilesToMol
from molpipeline.error_handling import ErrorFilter, FilterReinserter
Expand Down Expand Up @@ -317,28 +318,34 @@ def test_prediction(self) -> None:
class TestClassificationPipeline(unittest.TestCase):
"""Test the Chemprop model pipeline for classification."""

def test_prediction(self) -> None:
"""Test the prediction of the classification model."""

molecule_net_bbbp_df = pd.read_csv(
def setUp(self) -> None:
"""Set up repeated variables."""
self.molecule_net_bbbp_df = pd.read_csv(
TEST_DATA_DIR / "molecule_net_bbbp.tsv.gz", sep="\t", nrows=100
)

def test_prediction(self) -> None:
"""Test the prediction of the classification model."""
classification_model = get_classification_pipeline()
classification_model.fit(
molecule_net_bbbp_df["smiles"].tolist(),
molecule_net_bbbp_df["p_np"].to_numpy(),
self.molecule_net_bbbp_df["smiles"].tolist(),
self.molecule_net_bbbp_df["p_np"].to_numpy(),
)
pred = classification_model.predict(
self.molecule_net_bbbp_df["smiles"].tolist()
)
pred = classification_model.predict(molecule_net_bbbp_df["smiles"].tolist())
proba = classification_model.predict_proba(
molecule_net_bbbp_df["smiles"].tolist()
self.molecule_net_bbbp_df["smiles"].tolist()
)
self.assertEqual(len(pred), len(molecule_net_bbbp_df))
self.assertEqual(len(pred), len(self.molecule_net_bbbp_df))
self.assertEqual(proba.shape[1], 2)
self.assertEqual(proba.shape[0], len(molecule_net_bbbp_df))
self.assertEqual(proba.shape[0], len(self.molecule_net_bbbp_df))

model_copy = joblib_dump_load(classification_model)
pred_copy = model_copy.predict(molecule_net_bbbp_df["smiles"].tolist())
proba_copy = model_copy.predict_proba(molecule_net_bbbp_df["smiles"].tolist())
pred_copy = model_copy.predict(self.molecule_net_bbbp_df["smiles"].tolist())
proba_copy = model_copy.predict_proba(
self.molecule_net_bbbp_df["smiles"].tolist()
)

nan_indices = np.isnan(pred)
self.assertListEqual(nan_indices.tolist(), np.isnan(pred_copy).tolist())
Expand All @@ -349,14 +356,40 @@ def test_prediction(self) -> None:

# Test single prediction, this was causing an error before
single_mol_pred = classification_model.predict(
[molecule_net_bbbp_df["smiles"].iloc[0]]
[self.molecule_net_bbbp_df["smiles"].iloc[0]]
)
self.assertEqual(single_mol_pred.shape, (1,))
single_mol_proba = classification_model.predict_proba(
[molecule_net_bbbp_df["smiles"].iloc[0]]
[self.molecule_net_bbbp_df["smiles"].iloc[0]]
)
self.assertEqual(single_mol_proba.shape, (1, 2))

def test_calibrated_classifier(self) -> None:
"""Test if the pipeline can be used with a CalibratedClassifierCV."""
calibrated_pipeline = CalibratedClassifierCV(
get_classification_pipeline(), cv=2, ensemble=True, method="isotonic"
)
calibrated_pipeline.fit(
self.molecule_net_bbbp_df["smiles"].tolist(),
self.molecule_net_bbbp_df["p_np"].to_numpy(),
)
predicted_value_array = calibrated_pipeline.predict(
self.molecule_net_bbbp_df["smiles"].tolist()
)
predicted_proba_array = calibrated_pipeline.predict_proba(
self.molecule_net_bbbp_df["smiles"].tolist()
)
self.assertIsInstance(predicted_value_array, np.ndarray)
self.assertIsInstance(predicted_proba_array, np.ndarray)
self.assertEqual(
predicted_value_array.shape,
(len(self.molecule_net_bbbp_df["smiles"].tolist()),),
)
self.assertEqual(
predicted_proba_array.shape,
(len(self.molecule_net_bbbp_df["smiles"].tolist()), 2),
)


class TestMulticlassClassificationPipeline(unittest.TestCase):
"""Test the Chemprop model pipeline for multiclass classification."""
Expand Down

0 comments on commit 5ee8c70

Please sign in to comment.