Skip to content

Commit

Permalink
GridSearch n_jobs support
Browse files Browse the repository at this point in the history
  • Loading branch information
jbogaardt committed Aug 29, 2021
1 parent 81d0c81 commit 3702fb6
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 7 deletions.
2 changes: 1 addition & 1 deletion chainladder/core/display.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def heatmap(self, cmap="coolwarm", low=0, high=0, axis=0, subset=None):
rank_size = data.rank(axis=axis).max(axis=axis)
gmap = (raw_rank-1).div(rank_size-1, axis=not axis)*(shape_size-1) + 1
gmap = gmap.replace(np.nan, (shape_size+1)/2)
if float(pd.__version__[:3]) >= 1.3:
if pd.__version__ >= '1.3':
default_output = (
data.style.format(fmt_str).background_gradient(
cmap=cmap,
Expand Down
27 changes: 21 additions & 6 deletions chainladder/workflow/gridsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
from sklearn.model_selection import ParameterGrid
from sklearn.base import BaseEstimator
from sklearn.pipeline import Pipeline as PipelineSL
from sklearn.base import clone
from chainladder.core.io import EstimatorIO
import copy
from joblib import Parallel, delayed
import pandas as pd
import json

Expand Down Expand Up @@ -41,6 +42,12 @@ class GridSearch(BaseEstimator):
FitFailedWarning is raised. This parameter does not affect the refit
step, which will always raise the error. Default is 'raise' but from
version 0.22 it will change to np.nan.
n_jobs : int, default=None
The number of jobs to use for the computation. This will only provide
speedup for n_targets > 1 and sufficient large problems.
``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
``-1`` means using all processors. See :term:`Glossary <n_jobs>`
for more details.
Attributes
----------
Expand All @@ -49,12 +56,14 @@ class GridSearch(BaseEstimator):
score as the last column
"""

def __init__(self, estimator, param_grid, scoring, verbose=0, error_score="raise"):
def __init__(self, estimator, param_grid, scoring, verbose=0,
error_score="raise", n_jobs=None):
self.estimator = estimator
self.param_grid = param_grid
self.scoring = scoring
self.verbose = verbose
self.error_score = error_score
self.n_jobs = n_jobs

def fit(self, X, y=None, **fit_params):
"""Fit the model with X.
Expand All @@ -77,17 +86,23 @@ def fit(self, X, y=None, **fit_params):
else:
scoring = self.scoring
grid = list(ParameterGrid(self.param_grid))
results_ = []
for num, item in enumerate(grid):
est = copy.deepcopy(self.estimator).set_params(**item)

def _fit_single_estimator(estimator, fit_params, X, y, scoring, item):
est = clone(estimator).set_params(**item)
model = est.fit(X, y, **fit_params)
for score in scoring.keys():
item[score] = scoring[score](model)
results_.append(item)
return item

results_ = Parallel(n_jobs=self.n_jobs)(delayed(_fit_single_estimator)(
self.estimator, fit_params, X, y, scoring, item)
for item in grid)
self.results_ = pd.DataFrame(results_)
return self




class Pipeline(PipelineSL, EstimatorIO):
"""This is a near direct of copy the scikit-learn Pipeline class.
Expand Down

0 comments on commit 3702fb6

Please sign in to comment.