Skip to content

Commit

Permalink
FIX ColumnKernelizer for newer sklearn versions (#54)
Browse files Browse the repository at this point in the history
* FIX add kwargs to ColumnKernelizer._hstack for newer sklearn versions

* FIX more fixes to deal with sklearn changes

* FIX,TST fixes for ColumnTransformerNoStack
  • Loading branch information
mvdoc authored Mar 29, 2024
1 parent a7bb851 commit a5c226d
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 14 deletions.
34 changes: 22 additions & 12 deletions himalaya/kernel_ridge/_kernelizer.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
from sklearn.compose import ColumnTransformer
from sklearn.compose import make_column_selector # noqa
from packaging import version

import sklearn
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.compose import (
ColumnTransformer,
make_column_selector, # noqa
)
from sklearn.pipeline import _name_estimators, make_pipeline
from sklearn.utils.validation import check_is_fitted
from sklearn.pipeline import make_pipeline, _name_estimators

from ..backend import get_backend
from ..backend import force_cpu_backend
from ..validation import check_array
from ..validation import _get_string_dtype

from ._kernels import pairwise_kernels
from ._kernels import PAIRWISE_KERNEL_FUNCTIONS
from ..backend import force_cpu_backend, get_backend
from ..validation import _get_string_dtype, check_array
from ._kernels import PAIRWISE_KERNEL_FUNCTIONS, pairwise_kernels


class Kernelizer(TransformerMixin, BaseEstimator):
Expand Down Expand Up @@ -336,7 +337,7 @@ def _iter(self, fitted=False, *args, **kwargs):

yield (name, trans, column, weight)

def _hstack(self, Xs):
def _hstack(self, Xs, **kwargs):
"""Stack the kernels.
In ColumnTransformer, this methods stacks Xs horizontally.
Expand All @@ -363,8 +364,17 @@ def get_X_fit(self):
"""
check_is_fitted(self)

if version.parse(sklearn.__version__) < version.parse("1.4"):
iter_kwargs = {"replace_strings": True}
else:
iter_kwargs = {
"column_as_labels": False,
"skip_drop": True,
"skip_empty_columns": True
}

Xs = []
for (_, trans, _, _) in self._iter(fitted=True, replace_strings=True):
for (_, trans, _, _) in self._iter(fitted=True, **iter_kwargs):
if hasattr(trans, "get_X_fit"):
X = trans.get_X_fit()
else:
Expand Down
2 changes: 1 addition & 1 deletion himalaya/ridge/_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ class ColumnTransformerNoStack(ColumnTransformer):
(2, 4, 4)
"""

def _hstack(self, Xs):
def _hstack(self, Xs, **kwargs):
"""Do *not* stack the feature spaces.
In ColumnTransformer, this methods stacks Xs horizontally.
Expand Down
2 changes: 1 addition & 1 deletion himalaya/ridge/tests/test_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def test_column_transformer_remainder(backend):
backend = set_backend(backend)
X = np.random.randn(10, 5)

ct = ColumnTransformerNoStack([("name", "passthrough", [])],
ct = ColumnTransformerNoStack([("name", "passthrough", slice(0, 0))],
remainder="passthrough")
Xt = ct.fit_transform(X)
assert len(Xt) == 2
Expand Down

0 comments on commit a5c226d

Please sign in to comment.