From 11d814db7426f3ab833bab0ccaa6a63303462746 Mon Sep 17 00:00:00 2001 From: Daniel Date: Sun, 31 Dec 2023 10:51:54 +0800 Subject: [PATCH] 23.06 robust-ish fix --- cu_cat/_table_vectorizer.py | 52 +++++++++++++++++++++---------------- 1 file changed, 29 insertions(+), 23 deletions(-) diff --git a/cu_cat/_table_vectorizer.py b/cu_cat/_table_vectorizer.py index b05c799f9..2c42ba56e 100644 --- a/cu_cat/_table_vectorizer.py +++ b/cu_cat/_table_vectorizer.py @@ -499,6 +499,7 @@ def _auto_cast(self, X: pd.DataFrame) -> pd.DataFrame: for i in obj_col: X[i]=X[i].replace('nan',np.nan).fillna('0o0o0') X[i]=X[i].str.rjust(4,'0') + X[i]=X[i].str.replace('.', 'dot', regex=False) for col in X.columns: # Convert pandas' NaN value (pd.NA) to numpy NaN value (np.nan) # because the former tends to raise all kind of issues when dealing @@ -799,21 +800,24 @@ def get_feature_names_out(self, input_features=None) -> List[str]: typing.List[str] Feature names. """ - # if 'cudf' not in self.Xt_ and not deps.cudf: - # if parse_version(sklearn_version) < parse_version("1.0"): - # ct_feature_names = super().get_feature_names() - # else: - # ct_feature_names = super().get_feature_names_out() - # else: - # if parse_version(sklearn_version) > parse_version("1.0"): - try: - ct_feature_names = super().get_feature_names_out() - except: - pass - try: - ct_feature_names = super().get_feature_names() - except: - pass + if 'cudf' not in self.Xt_ and not deps.cudf: + if parse_version(sklearn_version) > parse_version("1.0"): + ct_feature_names = super().get_feature_names() + else: + ct_feature_names = super().get_feature_names_out() + else: + if parse_version(sklearn_version) < parse_version("1.0"): + ct_feature_names = super().get_feature_names_out() + else: + ct_feature_names = super().get_feature_names() + # try: + # ct_feature_names = super().get_feature_names_out() + # except: + # pass + # try: + # ct_feature_names = super().get_feature_names() + # except: + # pass all_trans_feature_names = [] for 
name, trans, cols, _ in self._iter(fitted=True): @@ -824,14 +828,16 @@ def get_feature_names_out(self, input_features=None) -> List[str]: cols = self.columns_.to_list() all_trans_feature_names.extend(cols) continue - try: - trans_feature_names = super().get_feature_names_out() - except: - pass - try: - trans_feature_names = super().get_feature_names() - except: - pass + if 'cudf' not in self.Xt_ and not deps.cudf: + if parse_version(sklearn_version) > parse_version("1.0"): + trans_feature_names = super().get_feature_names() + else: + trans_feature_names = super().get_feature_names_out() + else: + if parse_version(sklearn_version) < parse_version("1.0"): + trans_feature_names = super().get_feature_names_out() + else: + trans_feature_names = super().get_feature_names() all_trans_feature_names.extend(trans_feature_names) if len(ct_feature_names) != len(all_trans_feature_names):