|
206 | 206 | " max_samples = None\n",
|
207 | 207 | " n_calibration_folds = 4 # 4 uses all cores on the containers\n",
|
208 | 208 | "\n",
|
| 209 | + " from interpret.develop import set_option\n", |
209 | 210 | " from interpret.glassbox import ExplainableBoostingClassifier, ExplainableBoostingRegressor\n",
|
210 | 211 | " from xgboost import XGBClassifier, XGBRegressor, XGBRFClassifier, XGBRFRegressor\n",
|
211 | 212 | " from lightgbm import LGBMClassifier, LGBMRegressor\n",
|
|
348 | 349 | " # Specify method\n",
|
349 | 350 | " if trial.task.problem in [\"binary\", \"multiclass\"]:\n",
|
350 | 351 | " if trial.method == \"ebm\" or trial.method == \"ebm_opt\" and trial.task.name in ebm_classification_non_opt:\n",
|
| 352 | + " for param, val in ebm_params.copy().items():\n", |
| 353 | + " try:\n", |
| 354 | + " set_option(param, val)\n", |
| 355 | + " del ebm_params[param]\n", |
| 356 | + " except:\n", |
| 357 | + " pass\n", |
351 | 358 | " est = ExplainableBoostingClassifier(**ebm_params)\n",
|
352 | 359 | " elif trial.method == \"ebm_opt\":\n",
|
| 360 | + " for param, val in ebm_params.copy().items():\n", |
| 361 | + " try:\n", |
| 362 | + " set_option(param, val)\n", |
| 363 | + " del ebm_params[param]\n", |
| 364 | + " except:\n", |
| 365 | + " pass\n", |
353 | 366 | " # TODO: change these optimization parameters\n",
|
354 | 367 | " param_grid = {\n",
|
355 | 368 | " 'smoothing_rounds': optuna.distributions.IntDistribution(1, 4000, log=True),\n",
|
|
513 | 526 | " raise Exception(f\"Unrecognized classification method name {trial.method}\")\n",
|
514 | 527 | " elif trial.task.problem == \"regression\":\n",
|
515 | 528 | " if trial.method == \"ebm\":\n",
|
| 529 | + " for param, val in ebm_params.copy().items():\n", |
| 530 | + " try:\n", |
| 531 | + " set_option(param, val)\n", |
| 532 | + " del ebm_params[param]\n", |
| 533 | + " except:\n", |
| 534 | + " pass\n", |
516 | 535 | " est = ExplainableBoostingRegressor(**ebm_params)\n",
|
517 | 536 | " elif trial.method == \"ebm_opt\":\n",
|
| 537 | + " for param, val in ebm_params.copy().items():\n", |
| 538 | + " try:\n", |
| 539 | + " set_option(param, val)\n", |
| 540 | + " del ebm_params[param]\n", |
| 541 | + " except:\n", |
| 542 | + " pass\n", |
518 | 543 | " if trial.task.name in {\"Allstate_Claims_Severity\"}:\n",
|
519 | 544 | " # TODO: tweak\n",
|
520 | 545 | " max_samples = 5000 # crashes or fit time too long without subsampling\n",
|
|
0 commit comments