Skip to content

Commit

Permalink
improved nan handling
Browse files Browse the repository at this point in the history
  • Loading branch information
deadsoul44 committed Nov 21, 2024
1 parent 8c669bc commit 492afa7
Showing 1 changed file with 13 additions and 11 deletions.
24 changes: 13 additions & 11 deletions python-package/python/perpetual/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,12 +118,13 @@ def convert_input_frame(
if categorical_features_:
for i in categorical_features_:
categories, inversed = np.unique(X_[:, i].astype(str), return_inverse=True)
if categories[-1] == "nan":
categories = list(categories)
else:
categories = list(categories)
categories.append("nan")
inversed = inversed.astype(np.float32)

categories = list(categories)
if "nan" in categories:
categories.remove("nan")
categories.insert(0, "nan")

inversed = inversed + 1.0

if len(categories) > max_cat:
cat_to_num.append(i)
Expand All @@ -133,7 +134,7 @@ def convert_input_frame(

feature_name = features_[i]
cat_mapping[feature_name] = categories
ind_nan = len(categories) - 1
ind_nan = len(categories)
inversed[inversed == ind_nan] = np.nan
X_[:, i] = inversed

Expand Down Expand Up @@ -178,10 +179,11 @@ def transform_input_frame(X, cat_mapping) -> Tuple[List[str], np.ndarray, int, i
if cat_mapping:
for feature_name, categories in cat_mapping.items():
feature_index = features_.index(feature_name)
x_enc = np.searchsorted(
categories, X_[:, feature_index].astype(str)
).astype(np.float32)
ind_nan = len(categories) - 1
cats = categories.copy()
cats.remove("nan")
x_enc = np.searchsorted(cats, X_[:, feature_index].astype(str))
x_enc = x_enc + 1.0
ind_nan = len(categories)
x_enc[x_enc == ind_nan] = np.nan
X_[:, feature_index] = x_enc

Expand Down

0 comments on commit 492afa7

Please sign in to comment.