From 492afa73fb2d08f9ce269f8097266ab24e23d1a7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?mutlu=20=C5=9Fim=C5=9Fek?= Date: Thu, 21 Nov 2024 10:27:05 +0300 Subject: [PATCH] improved nan handling --- python-package/python/perpetual/utils.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/python-package/python/perpetual/utils.py b/python-package/python/perpetual/utils.py index 400fd7b..df51ff5 100644 --- a/python-package/python/perpetual/utils.py +++ b/python-package/python/perpetual/utils.py @@ -118,12 +118,13 @@ def convert_input_frame( if categorical_features_: for i in categorical_features_: categories, inversed = np.unique(X_[:, i].astype(str), return_inverse=True) - if categories[-1] == "nan": - categories = list(categories) - else: - categories = list(categories) - categories.append("nan") - inversed = inversed.astype(np.float32) + + categories = list(categories) + if "nan" in categories: + categories.remove("nan") + categories.insert(0, "nan") + + inversed = inversed + 1.0 if len(categories) > max_cat: cat_to_num.append(i) @@ -133,7 +134,7 @@ def convert_input_frame( feature_name = features_[i] cat_mapping[feature_name] = categories - ind_nan = len(categories) - 1 + ind_nan = len(categories) inversed[inversed == ind_nan] = np.nan X_[:, i] = inversed @@ -178,10 +179,11 @@ def transform_input_frame(X, cat_mapping) -> Tuple[List[str], np.ndarray, int, i if cat_mapping: for feature_name, categories in cat_mapping.items(): feature_index = features_.index(feature_name) - x_enc = np.searchsorted( - categories, X_[:, feature_index].astype(str) - ).astype(np.float32) - ind_nan = len(categories) - 1 + cats = categories.copy() + cats.remove("nan") + x_enc = np.searchsorted(cats, X_[:, feature_index].astype(str)) + x_enc = x_enc + 1.0 + ind_nan = len(categories) x_enc[x_enc == ind_nan] = np.nan X_[:, feature_index] = x_enc