Skip to content

Commit

Permalink
fix some metadata routing
Browse files Browse the repository at this point in the history
  • Loading branch information
c-w-feldmann committed Feb 11, 2025
1 parent 3ed46ce commit af22078
Showing 1 changed file with 137 additions and 5 deletions.
142 changes: 137 additions & 5 deletions molpipeline/pipeline/_skl_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

from copy import deepcopy
from typing import Any, Iterable, List, Literal, Optional, Tuple, TypeVar, Union

try:
Expand All @@ -18,10 +19,13 @@
from sklearn.pipeline import Pipeline as _Pipeline
from sklearn.pipeline import _final_estimator_has, _fit_transform_one
from sklearn.utils import Bunch
from sklearn.utils._tags import get_tags, Tags # pylint: disable=protected-access
from sklearn.utils.metadata_routing import (
_routing_enabled, # pylint: disable=protected-access
)
from sklearn.utils.metadata_routing import (
MetadataRouter,
MethodMapping,
process_routing,
)
from sklearn.utils.metaestimators import available_if
Expand Down Expand Up @@ -482,7 +486,9 @@ def fit(self, X: Any, y: Any = None, **fit_params: Any) -> Self:
"All input rows were filtered out! Model is not fitted!"
)
else:
fit_params_last_step = routed_params[self.steps[-1][0]]
fit_params_last_step = routed_params[
self._non_post_processing_steps()[-1][0]
]
self._final_estimator.fit(Xt, yt, **fit_params_last_step["fit"])

return self
Expand Down Expand Up @@ -552,7 +558,9 @@ def fit_transform(self, X: Any, y: Any = None, **params: Any) -> Any:
elif is_empty(iter_input):
logger.warning("All input rows were filtered out! Model is not fitted!")
else:
last_step_params = routed_params[self.steps[-1][0]]
last_step_params = routed_params[
self._non_post_processing_steps()[-1][0]
]
if hasattr(last_step, "fit_transform"):
iter_input = last_step.fit_transform(
iter_input, iter_label, **last_step_params["fit_transform"]
Expand Down Expand Up @@ -615,7 +623,8 @@ def predict(self, X: Any, **params: Any) -> Any:
elif hasattr(self._final_estimator, "predict"):
if _routing_enabled():
iter_input = self._final_estimator.predict(
iter_input, **routed_params[self._final_estimator].predict
iter_input,
**routed_params[self._non_post_processing_steps()[-1][0]].predict,
)
else:
iter_input = self._final_estimator.predict(iter_input, **params)
Expand Down Expand Up @@ -665,7 +674,7 @@ def fit_predict(self, X: Any, y: Any = None, **params: Any) -> Any:
X, y, routed_params
) # pylint: disable=invalid-name

params_last_step = routed_params[self.steps[-1][0]]
params_last_step = routed_params[self._non_post_processing_steps()[-1][0]]
with print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)):
if self._final_estimator == "passthrough":
y_pred = iter_input
Expand Down Expand Up @@ -724,7 +733,10 @@ def predict_proba(self, X: Any, **params: Any) -> Any:
elif hasattr(self._final_estimator, "predict_proba"):
if _routing_enabled():
iter_input = self._final_estimator.predict_proba(
iter_input, **routed_params[self.steps[-1][0]].predict_proba
iter_input,
**routed_params[
self._non_post_processing_steps()[-1][0]
].predict_proba,
)
else:
iter_input = self._final_estimator.predict_proba(iter_input, **params)
Expand Down Expand Up @@ -854,3 +866,123 @@ def classes_(self) -> list[Any] | npt.NDArray[Any]:
if hasattr(last_step, "classes_"):
return last_step.classes_
raise ValueError("Last step has no classes_ attribute.")

def __sklearn_tags__(self) -> Tags:
    """Assemble the sklearn tags of this pipeline from its steps.

    The input tags (``pairwise`` and ``sparse``) are derived from the entries
    of ``self.steps``, while the estimator-specific tags are copied from
    ``self._final_estimator``.

    Returns
    -------
    Tags
        The tags of the parent class, updated with whatever step information
        could be inspected. If ``self.steps`` is empty the parent tags are
        returned unchanged.

    """
    pipeline_tags = super().__sklearn_tags__()

    if not self.steps:
        return pipeline_tags

    # Raised when `steps` is not a list of (name, estimator) tuples and
    # `fit` is not called yet to validate the steps.
    unvalidated_steps_errors = (ValueError, AttributeError, TypeError)

    try:
        first_estimator = self.steps[0][1]
        if first_estimator is not None and first_estimator != "passthrough":
            first_tags = get_tags(first_estimator)
            pipeline_tags.input_tags.pairwise = first_tags.input_tags.pairwise
        # WARNING: the sparse tag can be incorrect.
        # Some Pipelines accepting sparse data are wrongly tagged sparse=False.
        # For example Pipeline([PCA(), estimator]) accepts sparse data
        # even if the estimator doesn't as PCA outputs a dense array.
        pipeline_tags.input_tags.sparse = all(
            get_tags(estimator).input_tags.sparse
            for _, estimator in self.steps
            if estimator != "passthrough"
        )
    except unvalidated_steps_errors:
        # Best effort: keep the tags collected so far.
        pass

    try:
        # Only the _final_estimator is changed from the original implementation
        final_estimator = self._final_estimator
        if final_estimator is not None and final_estimator != "passthrough":
            final_tags = get_tags(final_estimator)
            pipeline_tags.estimator_type = final_tags.estimator_type
            pipeline_tags.target_tags.multi_output = (
                final_tags.target_tags.multi_output
            )
            # Copy the tag groups so the final estimator's tags are not shared.
            for tag_group in ("classifier_tags", "regressor_tags", "transformer_tags"):
                setattr(pipeline_tags, tag_group, deepcopy(getattr(final_tags, tag_group)))
    except unvalidated_steps_errors:
        # Best effort: keep the tags collected so far.
        pass

    return pipeline_tags

def get_metadata_routing(self) -> MetadataRouter:
    """Get metadata routing of this object.

    Please check :ref:`User Guide <metadata_routing>` on how the routing
    mechanism works.

    Every step yielded by ``self._iter(with_final=False,
    filter_passthrough=True)`` is registered as an intermediate step. The
    last entry of ``self._non_post_processing_steps()`` is registered as the
    final estimator, unless it is ``None`` or ``"passthrough"``.

    Returns
    -------
    MetadataRouter
        A :class:`~sklearn.utils.metadata_routing.MetadataRouter` encapsulating
        routing information.

    """
    router = MetadataRouter(owner=self.__class__.__name__)

    # Pipeline methods that fit the intermediate steps.
    fitting_callers = ("fit", "fit_transform", "fit_predict")
    # (caller, callee) routes shared by every intermediate step. Each pair is
    # listed exactly once; registering a pair twice adds nothing.
    intermediate_routes = (
        ("predict", "transform"),
        ("predict_proba", "transform"),
        ("decision_function", "transform"),
        ("predict_log_proba", "transform"),
        ("transform", "transform"),
        ("inverse_transform", "inverse_transform"),
        ("score", "transform"),
    )

    # first we add all steps except the last one
    for _, name, trans in self._iter(with_final=False, filter_passthrough=True):
        method_mapping = MethodMapping()
        # fit, fit_predict, and fit_transform call fit_transform if it
        # exists, or else fit and transform
        if hasattr(trans, "fit_transform"):
            for caller in fitting_callers:
                method_mapping.add(caller=caller, callee="fit_transform")
        else:
            for caller in fitting_callers:
                method_mapping.add(caller=caller, callee="fit")
                method_mapping.add(caller=caller, callee="transform")

        for caller, callee in intermediate_routes:
            method_mapping.add(caller=caller, callee=callee)

        router.add(method_mapping=method_mapping, **{name: trans})

    final_name, final_est = self._non_post_processing_steps()[-1]
    if final_est is None or final_est == "passthrough":
        return router

    # then we add the last step
    method_mapping = MethodMapping()
    if hasattr(final_est, "fit_transform"):
        method_mapping.add(caller="fit_transform", callee="fit_transform")
    else:
        # NOTE(review): this routes `fit` metadata to the final step's
        # `transform`; confirm whether the intended caller is `fit_transform`.
        # `fit -> fit` is registered unconditionally below, so it is not
        # repeated here.
        method_mapping.add(caller="fit", callee="transform")
    (
        method_mapping.add(caller="fit", callee="fit")
        .add(caller="predict", callee="predict")
        .add(caller="fit_predict", callee="fit_predict")
        .add(caller="predict_proba", callee="predict_proba")
        .add(caller="decision_function", callee="decision_function")
        .add(caller="predict_log_proba", callee="predict_log_proba")
        .add(caller="transform", callee="transform")
        .add(caller="inverse_transform", callee="inverse_transform")
        .add(caller="score", callee="score")
    )

    router.add(method_mapping=method_mapping, **{final_name: final_est})
    return router

0 comments on commit af22078

Please sign in to comment.