From 64ce295d0c169a4a62dc5e860da22826842ca009 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Thu, 7 Sep 2023 18:21:01 +0200 Subject: [PATCH] FIX: Make pd import local in DataFrameTransformer (#1017) DataFrameTransformer requires pandas but we don't want to have a dependency on pandas. Therefore, the pandas import should be local. Back when I wrote this, I thought it would be sufficient to import pandas at a class level, but that is incorrect. Now, pandas is important inside the method bodies. For most skorch users, this should make no difference. But it will allow skorch users who do use something from helpers.py to avoid having to install pandas. --- skorch/helper.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/skorch/helper.py b/skorch/helper.py index 90f68050a..cf4d60366 100644 --- a/skorch/helper.py +++ b/skorch/helper.py @@ -370,8 +370,6 @@ class DataFrameTransformer(BaseEstimator, TransformerMixin): contains 1 column. """ - import pandas as pd - def __init__( self, treat_int_as_categorical=False, @@ -399,6 +397,8 @@ def _check_dtypes(self, df): If a wrong dtype is found. """ + import pandas as pd + if 'X' in df: raise ValueError( "DataFrame contains a column named 'X', which clashes " @@ -408,7 +408,7 @@ def _check_dtypes(self, df): wrong_dtypes = [] for col, dtype in zip(df, df.dtypes): - if isinstance(dtype, self.pd.api.types.CategoricalDtype): + if isinstance(dtype, pd.api.types.CategoricalDtype): continue if np.issubdtype(dtype, np.integer): continue @@ -447,6 +447,8 @@ def transform(self, df): respective column names as keys. """ + import pandas as pd + self._check_dtypes(df) X_dict = {} @@ -455,7 +457,7 @@ def transform(self, df): for col, dtype in zip(df, df.dtypes): X_col = df[col] - if isinstance(dtype, self.pd.api.types.CategoricalDtype): + if isinstance(dtype, pd.api.types.CategoricalDtype): x = X_col.cat.codes.values if self.int_dtype is not None: x = x.astype(self.int_dtype)