From a5c226d28e6a4e64dc465d92518428f2898b8b1f Mon Sep 17 00:00:00 2001 From: Matteo Visconti di Oleggio Castello Date: Fri, 29 Mar 2024 09:22:32 -0700 Subject: [PATCH] FIX ColumnKernelizer for newer sklearn versions (#54) * FIX add kwargs to ColumnKernelizer._hstack for newer sklearn versions * FIX more fixes to deal with sklearn changes * FIX,TST fixes for ColumnTransformerNoStack --- himalaya/kernel_ridge/_kernelizer.py | 34 ++++++++++++++++++---------- himalaya/ridge/_column.py | 2 +- himalaya/ridge/tests/test_column.py | 2 +- 3 files changed, 24 insertions(+), 14 deletions(-) diff --git a/himalaya/kernel_ridge/_kernelizer.py b/himalaya/kernel_ridge/_kernelizer.py index 02b3d62..9ed75f2 100644 --- a/himalaya/kernel_ridge/_kernelizer.py +++ b/himalaya/kernel_ridge/_kernelizer.py @@ -1,16 +1,17 @@ -from sklearn.compose import ColumnTransformer -from sklearn.compose import make_column_selector # noqa +from packaging import version + +import sklearn from sklearn.base import BaseEstimator, TransformerMixin +from sklearn.compose import ( + ColumnTransformer, + make_column_selector, # noqa +) +from sklearn.pipeline import _name_estimators, make_pipeline from sklearn.utils.validation import check_is_fitted -from sklearn.pipeline import make_pipeline, _name_estimators - -from ..backend import get_backend -from ..backend import force_cpu_backend -from ..validation import check_array -from ..validation import _get_string_dtype -from ._kernels import pairwise_kernels -from ._kernels import PAIRWISE_KERNEL_FUNCTIONS +from ..backend import force_cpu_backend, get_backend +from ..validation import _get_string_dtype, check_array +from ._kernels import PAIRWISE_KERNEL_FUNCTIONS, pairwise_kernels class Kernelizer(TransformerMixin, BaseEstimator): @@ -336,7 +337,7 @@ def _iter(self, fitted=False, *args, **kwargs): yield (name, trans, column, weight) - def _hstack(self, Xs): + def _hstack(self, Xs, **kwargs): """Stack the kernels. In ColumnTransformer, this methods stacks Xs horizontally. @@ -363,8 +364,17 @@ def get_X_fit(self): """ check_is_fitted(self) + if version.parse(sklearn.__version__) < version.parse("1.4"): + iter_kwargs = {"replace_strings": True} + else: + iter_kwargs = { + "column_as_labels": False, + "skip_drop": True, + "skip_empty_columns": True + } + Xs = [] - for (_, trans, _, _) in self._iter(fitted=True, replace_strings=True): + for (_, trans, _, _) in self._iter(fitted=True, **iter_kwargs): if hasattr(trans, "get_X_fit"): X = trans.get_X_fit() else: diff --git a/himalaya/ridge/_column.py b/himalaya/ridge/_column.py index cac64fe..f907a3d 100644 --- a/himalaya/ridge/_column.py +++ b/himalaya/ridge/_column.py @@ -129,7 +129,7 @@ class ColumnTransformerNoStack(ColumnTransformer): (2, 4, 4) """ - def _hstack(self, Xs): + def _hstack(self, Xs, **kwargs): """Do *not* stack the feature spaces. In ColumnTransformer, this methods stacks Xs horizontally. diff --git a/himalaya/ridge/tests/test_column.py b/himalaya/ridge/tests/test_column.py index fc798a4..6e84442 100644 --- a/himalaya/ridge/tests/test_column.py +++ b/himalaya/ridge/tests/test_column.py @@ -40,7 +40,7 @@ def test_column_transformer_remainder(backend): backend = set_backend(backend) X = np.random.randn(10, 5) - ct = ColumnTransformerNoStack([("name", "passthrough", [])], + ct = ColumnTransformerNoStack([("name", "passthrough", slice(0, 0))], remainder="passthrough") Xt = ct.fit_transform(X) assert len(Xt) == 2