Skip to content

Commit

Permalink
set cv=2 and add error filtering to reproduce error
Browse files Browse the repository at this point in the history
  • Loading branch information
c-w-feldmann committed Feb 11, 2025
1 parent 7499cf0 commit 3387c7a
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier

from molpipeline import ErrorFilter, Pipeline
from molpipeline import ErrorFilter, FilterReinserter, Pipeline, PostPredictionWrapper
from molpipeline.any2mol import AutoToMol, SmilesToMol
from molpipeline.mol2any import MolToMorganFP, MolToRDKitPhysChem, MolToSmiles
from molpipeline.mol2mol import (
Expand Down Expand Up @@ -384,16 +384,24 @@ def test_calibrated_classifier(self) -> None:
smi2mol = SmilesToMol()
mol2morgan = MolToMorganFP(radius=FP_RADIUS, n_bits=FP_SIZE)
d_tree = DecisionTreeClassifier()
error_filter = ErrorFilter(filter_everything=True)
s_pipeline = Pipeline(
[
("smi2mol", smi2mol),
("morgan", mol2morgan),
("error_filter", error_filter),
("decision_tree", d_tree),
(
"error_replacer",
PostPredictionWrapper(
FilterReinserter.from_error_filter(error_filter, None)
),
),
]
)
calibrated_pipeline = CalibratedClassifierCV(s_pipeline)
calibrated_pipeline = CalibratedClassifierCV(s_pipeline, cv=2)
calibrated_pipeline.fit(TEST_SMILES, CONTAINS_OX)
predicted_value_array = s_pipeline.predict(TEST_SMILES)
predicted_value_array = calibrated_pipeline.predict(TEST_SMILES)
for pred_val, true_val in zip(predicted_value_array, CONTAINS_OX):
self.assertEqual(pred_val, true_val)

Expand Down

0 comments on commit 3387c7a

Please sign in to comment.