Skip to content

Commit 14b79a9

Browse files
committed
add ability to specify development options in the EBM benchmarking notebook
1 parent 8e22d6b commit 14b79a9

File tree

1 file changed

+25
-0
lines changed

1 file changed

+25
-0
lines changed

docs/benchmarks/ebm-benchmark.ipynb

+25
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,7 @@
206206
" max_samples = None\n",
207207
" n_calibration_folds = 4 # 4 uses all cores on the containers\n",
208208
"\n",
209+
" from interpret.develop import set_option\n",
209210
" from interpret.glassbox import ExplainableBoostingClassifier, ExplainableBoostingRegressor\n",
210211
" from xgboost import XGBClassifier, XGBRegressor, XGBRFClassifier, XGBRFRegressor\n",
211212
" from lightgbm import LGBMClassifier, LGBMRegressor\n",
@@ -348,8 +349,20 @@
348349
" # Specify method\n",
349350
" if trial.task.problem in [\"binary\", \"multiclass\"]:\n",
350351
" if trial.method == \"ebm\" or trial.method == \"ebm_opt\" and trial.task.name in ebm_classification_non_opt:\n",
352+
" for param, val in ebm_params.copy().items():\n",
353+
" try:\n",
354+
" set_option(param, val)\n",
355+
" del ebm_params[param]\n",
356+
" except:\n",
357+
" pass\n",
351358
" est = ExplainableBoostingClassifier(**ebm_params)\n",
352359
" elif trial.method == \"ebm_opt\":\n",
360+
" for param, val in ebm_params.copy().items():\n",
361+
" try:\n",
362+
" set_option(param, val)\n",
363+
" del ebm_params[param]\n",
364+
" except:\n",
365+
" pass\n",
353366
" # TODO: change these optimization parameters\n",
354367
" param_grid = {\n",
355368
" 'smoothing_rounds': optuna.distributions.IntDistribution(1, 4000, log=True),\n",
@@ -513,8 +526,20 @@
513526
" raise Exception(f\"Unrecognized classification method name {trial.method}\")\n",
514527
" elif trial.task.problem == \"regression\":\n",
515528
" if trial.method == \"ebm\":\n",
529+
" for param, val in ebm_params.copy().items():\n",
530+
" try:\n",
531+
" set_option(param, val)\n",
532+
" del ebm_params[param]\n",
533+
" except:\n",
534+
" pass\n",
516535
" est = ExplainableBoostingRegressor(**ebm_params)\n",
517536
" elif trial.method == \"ebm_opt\":\n",
537+
" for param, val in ebm_params.copy().items():\n",
538+
" try:\n",
539+
" set_option(param, val)\n",
540+
" del ebm_params[param]\n",
541+
" except:\n",
542+
" pass\n",
518543
" if trial.task.name in {\"Allstate_Claims_Severity\"}:\n",
519544
" # TODO: tweak\n",
520545
" max_samples = 5000 # crashes or fit time too long without subsampling\n",

0 commit comments

Comments
 (0)