Auto encoding for categorical data during inference. #11088
@david-cortes May I ask how the xgboost/R-package/R/xgb.DMatrix.R (line 572 in f4f3bd4)
@trivialfis In that function, the DMatrix is set in the line right before. Regarding the feature: since the idea is to have this feature in different interfaces, how would it work behind the scenes? It would be ideal if the categorical encodings could be saved in the booster and used in plots/trees-to-tables/JSONs/etc. (#9927). Even better if it were a standardized C-level attribute, so that the encodings could survive transfers from one interface to another. I see some potential difficulties, though:
It needs to be kept until the next
We will store the levels in the booster as you suggested. Things will be handled in C++; we might allow users to optionally disable the encoder for performance reasons (searching through the levels is not cheap during inference, especially with strings).
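To illustrate why keeping the levels in the booster helps the plots/trees-to-tables use case: a categorical split only records integer category codes, so a dump can print level names only if the training-time levels are available. The node layout below (`split_feature`, `categories`) and the function name are hypothetical, not XGBoost's JSON model schema; this is a sketch of the idea only.

```python
from typing import Dict, List


def describe_categorical_split(node: dict, feature_levels: Dict[str, List[str]]) -> str:
    """Render a categorical split using level names instead of integer codes."""
    # Hypothetical node layout: feature name plus the category codes sent to one branch.
    levels = feature_levels[node["split_feature"]]
    names = [levels[code] for code in node["categories"]]
    return node["split_feature"] + " in {" + ", ".join(names) + "}"


# Levels as they would be stored alongside the model at training time (assumed format).
levels = {"pet": ["cat", "dog", "fish"]}
print(describe_categorical_split({"split_feature": "pet", "categories": [0, 2]}, levels))
# prints: pet in {cat, fish}
```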
Currently, I'm returning the categories in the Arrow columnar format with the help of
We accept only strings and some other primitive types, such as integers. Still working on the typing part.
We are working on automatic re-encoding for categorical features during inference. This teaches the booster to handle data encoded differently from the training dataset and removes the need for a scikit-learn pipeline to encode the data when using DataFrame inputs.
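A minimal sketch of what re-encoding means, assuming each categorical feature is stored as a list of levels plus integer codes (as with pandas categoricals, where -1 marks a missing value). The names `build_recoder` and `recode` are hypothetical; this is not XGBoost's C++ implementation, and mapping unseen categories to missing (NaN) is an assumption of the sketch.

```python
import math
from typing import List, Optional, Sequence


def build_recoder(train_levels: Sequence[str], infer_levels: Sequence[str]) -> List[Optional[int]]:
    """For each inference-time code, return the training-time code (None if unseen)."""
    train_index = {level: code for code, level in enumerate(train_levels)}
    return [train_index.get(level) for level in infer_levels]


def recode(codes: Sequence[int], mapping: List[Optional[int]]) -> List[float]:
    """Translate inference codes to training codes; missing or unseen values become NaN."""
    out = []
    for c in codes:
        if c < 0:  # pandas-style missing code
            out.append(math.nan)
            continue
        t = mapping[c]
        out.append(math.nan if t is None else float(t))
    return out


train_levels = ["cat", "dog", "fish"]          # encoding seen at training time
infer_levels = ["dog", "fish", "cat", "bird"]  # same column, encoded differently
mapping = build_recoder(train_levels, infer_levels)  # [1, 2, 0, None]
print(recode([0, 2, 3, -1, 1], mapping))  # [1.0, 0.0, nan, nan, 2.0]
```

The lookup table is built once from the two level lists, so string comparisons scale with the number of categories rather than the number of rows; the per-row step is just an integer index.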
PySpark, Scala/Spark: removed the Spark variants, since Spark DataFrames don't carry a categorical encoding. Use the StringIndexer instead.
Notes:
Looking into the Arrow CPU implementation: its compute module dispatches on whether a null (validity) bitmap is present. If it is, the kernel finds runs of consecutive valid values and then iterates over each run. This avoids evaluating a validity predicate for every element. The runs are located with compiler builtins that count leading (or trailing) zero bits in the bitmap words.
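A rough Python sketch of the run-based iteration described above, not Arrow's code. Arrow validity bitmaps are LSB-first, so with this bit order finding where a run starts or ends amounts to counting zero bits from the low end; here Python's `int.bit_length()` stands in for the hardware builtin, and the whole bitmap is treated as one big integer rather than processed in 64-bit words. The names `valid_runs` and `masked_sum` are hypothetical.

```python
from typing import Iterator, Sequence, Tuple


def valid_runs(bitmap: bytes, length: int) -> Iterator[Tuple[int, int]]:
    """Yield (start, run_length) for each run of consecutive valid slots."""
    # LSB-first bit order: bit i of the integer is the validity of slot i.
    bits = int.from_bytes(bitmap, "little") & ((1 << length) - 1)
    pos = 0
    while bits:
        # Skip nulls: count trailing zero bits (position of the lowest set bit).
        zeros = (bits & -bits).bit_length() - 1
        bits >>= zeros
        pos += zeros
        # Run length = trailing ones = trailing zeros of the complement.
        inv = ~bits
        run = (inv & -inv).bit_length() - 1
        yield pos, run
        bits >>= run
        pos += run


def masked_sum(values: Sequence[float], bitmap: bytes) -> float:
    """Sum valid values without testing validity per element."""
    total = 0
    for start, run in valid_runs(bitmap, len(values)):
        total += sum(values[start:start + run])  # dense inner loop over one run
    return total


# Slots 0-2, 4-5 and 9 are valid; slots 3, 6, 7, 8 are null.
bitmap = bytes([0b00110111, 0b00000010])
print(list(valid_runs(bitmap, 10)))        # [(0, 3), (4, 2), (9, 1)]
print(masked_sum(list(range(1, 11)), bitmap))  # 27
```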
Tracking PRs:
ProxyDMatrix creation keeps data until next iteration #11092