Skip to content

Commit

Permalink
feature importance type in saved model file (#3220)
Browse files Browse the repository at this point in the history
* feature importance type in saved model file

* fix nullptr

* fixed formatting

* fix python/R

* Update src/c_api.cpp

* Apply suggestions from code review

Co-authored-by: James Lamb <[email protected]>

* fix c_api test

* fix swig

* minor docs improvements and added defines for importance types

Co-authored-by: StrikerRUS <[email protected]>
Co-authored-by: James Lamb <[email protected]>
  • Loading branch information
3 people authored Jul 15, 2020
1 parent 7b8b515 commit 87d4648
Show file tree
Hide file tree
Showing 16 changed files with 132 additions and 51 deletions.
9 changes: 6 additions & 3 deletions R-package/R/lgb.Booster.R
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ Booster <- R6::R6Class(
},

# Save model
save_model = function(filename, num_iteration = NULL) {
save_model = function(filename, num_iteration = NULL, feature_importance_type = 0L) {

# Check if number of iteration is non existent
if (is.null(num_iteration)) {
Expand All @@ -437,6 +437,7 @@ Booster <- R6::R6Class(
, ret = NULL
, private$handle
, as.integer(num_iteration)
, as.integer(feature_importance_type)
, lgb.c_str(filename)
)

Expand All @@ -445,7 +446,7 @@ Booster <- R6::R6Class(
},

# Save model to string
save_model_to_string = function(num_iteration = NULL) {
save_model_to_string = function(num_iteration = NULL, feature_importance_type = 0L) {

# Check if number of iteration is non existent
if (is.null(num_iteration)) {
Expand All @@ -457,12 +458,13 @@ Booster <- R6::R6Class(
"LGBM_BoosterSaveModelToString_R"
, private$handle
, as.integer(num_iteration)
, as.integer(feature_importance_type)
))

},

# Dump model in memory
dump_model = function(num_iteration = NULL) {
dump_model = function(num_iteration = NULL, feature_importance_type = 0L) {

# Check if number of iteration is non existent
if (is.null(num_iteration)) {
Expand All @@ -474,6 +476,7 @@ Booster <- R6::R6Class(
"LGBM_BoosterDumpModel_R"
, private$handle
, as.integer(num_iteration)
, as.integer(feature_importance_type)
)

},
Expand Down
15 changes: 9 additions & 6 deletions R-package/src/lightgbm_R.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -632,37 +632,40 @@ LGBM_SE LGBM_BoosterPredictForMat_R(LGBM_SE handle,

// R-facing wrapper around LGBM_BoosterSaveModel.
// Unwraps the R arguments (booster handle via R_GET_PTR, integers via
// R_AS_INT, file name via R_CHAR_PTR) and forwards them to the C API.
// The start-iteration argument of the C API is always passed as 0.
// `call_state` is not referenced directly in this body; presumably it is
// consumed by the R_API_BEGIN/R_API_END macros -- TODO confirm.
LGBM_SE LGBM_BoosterSaveModel_R(LGBM_SE handle,
LGBM_SE num_iteration,
LGBM_SE feature_importance_type,
LGBM_SE filename,
LGBM_SE call_state) {
R_API_BEGIN();
// NOTE(review): two calls are listed back to back. The first omits
// feature_importance_type, the second passes it between num_iteration and
// filename. This looks like the before/after pair of a diff -- confirm that
// only the second form exists in the real file.
CHECK_CALL(LGBM_BoosterSaveModel(R_GET_PTR(handle), 0, R_AS_INT(num_iteration), R_CHAR_PTR(filename)));
CHECK_CALL(LGBM_BoosterSaveModel(R_GET_PTR(handle), 0, R_AS_INT(num_iteration), R_AS_INT(feature_importance_type), R_CHAR_PTR(filename)));
R_API_END();
}

// R-facing wrapper around LGBM_BoosterSaveModelToString.
// Allocates a temporary char buffer of `buffer_len` elements, asks the C API
// to fill it (start iteration fixed at 0), and hands the buffer together with
// the length reported by the C API to EncodeChar along with `out_str`,
// `buffer_len` and `actual_len`.
// Presumably EncodeChar copies the text into `out_str` and stores the needed
// length in `actual_len` so the caller can retry with a larger buffer --
// TODO confirm against EncodeChar's definition.
LGBM_SE LGBM_BoosterSaveModelToString_R(LGBM_SE handle,
LGBM_SE num_iteration,
LGBM_SE feature_importance_type,
LGBM_SE buffer_len,
LGBM_SE actual_len,
LGBM_SE out_str,
LGBM_SE call_state) {
R_API_BEGIN();
// Length reported back by the C API through its out_len parameter.
int64_t out_len = 0;
std::vector<char> inner_char_buf(R_AS_INT(buffer_len));
// NOTE(review): two calls are listed back to back. The first omits
// feature_importance_type, the second passes it right after num_iteration.
// This looks like the before/after pair of a diff -- confirm that only the
// second form exists in the real file.
CHECK_CALL(LGBM_BoosterSaveModelToString(R_GET_PTR(handle), 0, R_AS_INT(num_iteration), R_AS_INT(buffer_len), &out_len, inner_char_buf.data()));
CHECK_CALL(LGBM_BoosterSaveModelToString(R_GET_PTR(handle), 0, R_AS_INT(num_iteration), R_AS_INT(feature_importance_type), R_AS_INT(buffer_len), &out_len, inner_char_buf.data()));
EncodeChar(out_str, inner_char_buf.data(), buffer_len, actual_len, static_cast<size_t>(out_len));
R_API_END();
}

// R-facing wrapper around LGBM_BoosterDumpModel.
// Allocates a temporary char buffer of `buffer_len` elements, asks the C API
// to write the dumped model into it (start iteration fixed at 0), and hands
// the buffer together with the length reported by the C API to EncodeChar
// along with `out_str`, `buffer_len` and `actual_len`.
// Presumably EncodeChar copies the text into `out_str` and stores the needed
// length in `actual_len` -- TODO confirm against EncodeChar's definition.
LGBM_SE LGBM_BoosterDumpModel_R(LGBM_SE handle,
LGBM_SE num_iteration,
LGBM_SE feature_importance_type,
LGBM_SE buffer_len,
LGBM_SE actual_len,
LGBM_SE out_str,
LGBM_SE call_state) {
R_API_BEGIN();
// Length reported back by the C API through its out_len parameter.
int64_t out_len = 0;
std::vector<char> inner_char_buf(R_AS_INT(buffer_len));
// NOTE(review): two calls are listed back to back. The first omits
// feature_importance_type, the second passes it right after num_iteration.
// This looks like the before/after pair of a diff -- confirm that only the
// second form exists in the real file.
CHECK_CALL(LGBM_BoosterDumpModel(R_GET_PTR(handle), 0, R_AS_INT(num_iteration), R_AS_INT(buffer_len), &out_len, inner_char_buf.data()));
CHECK_CALL(LGBM_BoosterDumpModel(R_GET_PTR(handle), 0, R_AS_INT(num_iteration), R_AS_INT(feature_importance_type), R_AS_INT(buffer_len), &out_len, inner_char_buf.data()));
EncodeChar(out_str, inner_char_buf.data(), buffer_len, actual_len, static_cast<size_t>(out_len));
R_API_END();
}
Expand Down Expand Up @@ -707,9 +710,9 @@ static const R_CallMethodDef CallEntries[] = {
{"LGBM_BoosterCalcNumPredict_R" , (DL_FUNC) &LGBM_BoosterCalcNumPredict_R , 8},
{"LGBM_BoosterPredictForCSC_R" , (DL_FUNC) &LGBM_BoosterPredictForCSC_R , 14},
{"LGBM_BoosterPredictForMat_R" , (DL_FUNC) &LGBM_BoosterPredictForMat_R , 11},
{"LGBM_BoosterSaveModel_R" , (DL_FUNC) &LGBM_BoosterSaveModel_R , 4},
{"LGBM_BoosterSaveModelToString_R" , (DL_FUNC) &LGBM_BoosterSaveModelToString_R , 6},
{"LGBM_BoosterDumpModel_R" , (DL_FUNC) &LGBM_BoosterDumpModel_R , 6},
{"LGBM_BoosterSaveModel_R" , (DL_FUNC) &LGBM_BoosterSaveModel_R , 5},
{"LGBM_BoosterSaveModelToString_R" , (DL_FUNC) &LGBM_BoosterSaveModelToString_R , 7},
{"LGBM_BoosterDumpModel_R" , (DL_FUNC) &LGBM_BoosterDumpModel_R , 7},
{NULL, NULL, 0}
};

Expand Down
3 changes: 3 additions & 0 deletions R-package/src/lightgbm_R.h
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,7 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterPredictForMat_R(
LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterSaveModel_R(
LGBM_SE handle,
LGBM_SE num_iteration,
LGBM_SE feature_importance_type,
LGBM_SE filename,
LGBM_SE call_state
);
Expand All @@ -604,6 +605,7 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterSaveModel_R(
LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterSaveModelToString_R(
LGBM_SE handle,
LGBM_SE num_iteration,
LGBM_SE feature_importance_type,
LGBM_SE buffer_len,
LGBM_SE actual_len,
LGBM_SE out_str,
Expand All @@ -620,6 +622,7 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterSaveModelToString_R(
LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterDumpModel_R(
LGBM_SE handle,
LGBM_SE num_iteration,
LGBM_SE feature_importance_type,
LGBM_SE buffer_len,
LGBM_SE actual_len,
LGBM_SE out_str,
Expand Down
8 changes: 8 additions & 0 deletions docs/Parameters.rst
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,14 @@ Learning Control Parameters

- **Note**: can be used only in CLI version

- ``saved_feature_importance_type`` :raw-html:`<a id="saved_feature_importance_type" title="Permalink to this parameter" href="#saved_feature_importance_type">&#x1F517;&#xFE0E;</a>`, default = ``0``, type = int

- the feature importance type in the saved model file

- ``0``: count-based feature importance (numbers of splits are counted); ``1``: gain-based feature importance (values of gain are counted)

- **Note**: can be used only in CLI version

- ``snapshot_freq`` :raw-html:`<a id="snapshot_freq" title="Permalink to this parameter" href="#snapshot_freq">&#x1F517;&#xFE0E;</a>`, default = ``-1``, type = int, aliases: ``save_period``

- frequency of saving model file snapshot
Expand Down
10 changes: 6 additions & 4 deletions include/LightGBM/boosting.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,10 @@ class LIGHTGBM_EXPORT Boosting {
* \brief Dump model to json format string
* \param start_iteration Start index of the iteration that should be dumped
* \param num_iteration Number of iterations to dump, -1 means dump all
* \param feature_importance_type Type of feature importance, 0: split, 1: gain
* \return Json format string of model
*/
virtual std::string DumpModel(int start_iteration, int num_iteration) const = 0;
virtual std::string DumpModel(int start_iteration, int num_iteration, int feature_importance_type) const = 0;

/*!
* \brief Translate model to if-else statement
Expand All @@ -199,19 +200,20 @@ class LIGHTGBM_EXPORT Boosting {
* \brief Save model to file
* \param start_iteration Start index of the iteration that should be saved
* \param num_iterations Number of iterations to save, -1 means save all
* \param is_finish Is training finished or not
* \param feature_importance_type Type of feature importance, 0: split, 1: gain
* \param filename Name of the file to save the model to
* \return true if succeeded
*/
virtual bool SaveModelToFile(int start_iteration, int num_iterations, const char* filename) const = 0;
virtual bool SaveModelToFile(int start_iteration, int num_iterations, int feature_importance_type, const char* filename) const = 0;

/*!
* \brief Save model to string
* \param start_iteration Start index of the iteration that should be saved
* \param num_iterations Number of iterations to save, -1 means save all
* \param feature_importance_type Type of feature importance, 0: split, 1: gain
* \return Non-empty string if succeeded
*/
virtual std::string SaveModelToString(int start_iteration, int num_iterations) const = 0;
virtual std::string SaveModelToString(int start_iteration, int num_iterations, int feature_importance_type) const = 0;

/*!
* \brief Restore from a serialized string
Expand Down
13 changes: 11 additions & 2 deletions include/LightGBM/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ typedef void* BoosterHandle; /*!< \brief Handle of booster. */
#define C_API_MATRIX_TYPE_CSR (0) /*!< \brief CSR sparse matrix type. */
#define C_API_MATRIX_TYPE_CSC (1) /*!< \brief CSC sparse matrix type. */

#define C_API_FEATURE_IMPORTANCE_SPLIT (0) /*!< \brief Split type of feature importance. */
#define C_API_FEATURE_IMPORTANCE_GAIN (1) /*!< \brief Gain type of feature importance. */

/*!
* \brief Get string message of the last error.
* \return Error information
Expand Down Expand Up @@ -996,19 +999,22 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMats(BoosterHandle handle,
* \param handle Handle of booster
* \param start_iteration Start index of the iteration that should be saved
* \param num_iteration Index of the iteration that should be saved, <= 0 means save all
* \param feature_importance_type Type of feature importance, can be ``C_API_FEATURE_IMPORTANCE_SPLIT`` or ``C_API_FEATURE_IMPORTANCE_GAIN``
* \param filename The name of the file
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModel(BoosterHandle handle,
int start_iteration,
int num_iteration,
int feature_importance_type,
const char* filename);

/*!
* \brief Save model to string.
* \param handle Handle of booster
* \param start_iteration Start index of the iteration that should be saved
* \param num_iteration Index of the iteration that should be saved, <= 0 means save all
* \param feature_importance_type Type of feature importance, can be ``C_API_FEATURE_IMPORTANCE_SPLIT`` or ``C_API_FEATURE_IMPORTANCE_GAIN``
* \param buffer_len String buffer length, if ``buffer_len < out_len``, you should re-allocate buffer
* \param[out] out_len Actual output length
* \param[out] out_str String of model, should pre-allocate memory
Expand All @@ -1017,6 +1023,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModel(BoosterHandle handle,
LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModelToString(BoosterHandle handle,
int start_iteration,
int num_iteration,
int feature_importance_type,
int64_t buffer_len,
int64_t* out_len,
char* out_str);
Expand All @@ -1026,6 +1033,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModelToString(BoosterHandle handle,
* \param handle Handle of booster
* \param start_iteration Start index of the iteration that should be dumped
* \param num_iteration Index of the iteration that should be dumped, <= 0 means dump all
* \param feature_importance_type Type of feature importance, can be ``C_API_FEATURE_IMPORTANCE_SPLIT`` or ``C_API_FEATURE_IMPORTANCE_GAIN``
* \param buffer_len String buffer length, if ``buffer_len < out_len``, you should re-allocate buffer
* \param[out] out_len Actual output length
* \param[out] out_str JSON format string of model, should pre-allocate memory
Expand All @@ -1034,6 +1042,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModelToString(BoosterHandle handle,
LIGHTGBM_C_EXPORT int LGBM_BoosterDumpModel(BoosterHandle handle,
int start_iteration,
int num_iteration,
int feature_importance_type,
int64_t buffer_len,
int64_t* out_len,
char* out_str);
Expand Down Expand Up @@ -1069,8 +1078,8 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterSetLeafValue(BoosterHandle handle,
* \param handle Handle of booster
* \param num_iteration Number of iterations for which feature importance is calculated, <= 0 means use all
* \param importance_type Method of importance calculation:
* - 0 for split, result contains numbers of times the feature is used in a model;
* - 1 for gain, result contains total gains of splits which use the feature
* - ``C_API_FEATURE_IMPORTANCE_SPLIT``: result contains numbers of times the feature is used in a model;
* - ``C_API_FEATURE_IMPORTANCE_GAIN``: result contains total gains of splits which use the feature
* \param[out] out_results Result array with feature importance
* \return 0 when succeed, -1 when failure happens
*/
Expand Down
5 changes: 5 additions & 0 deletions include/LightGBM/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,11 @@ struct Config {
// desc = **Note**: can be used only in CLI version
std::string output_model = "LightGBM_model.txt";

// desc = the feature importance type in the saved model file
// desc = ``0``: count-based feature importance (numbers of splits are counted); ``1``: gain-based feature importance (values of gain are counted)
// desc = **Note**: can be used only in CLI version
int saved_feature_importance_type = 0;

// [no-save]
// alias = save_period
// desc = frequency of saving model file snapshot
Expand Down
Loading

0 comments on commit 87d4648

Please sign in to comment.