diff --git a/src/eddymotion/model/gpr.py b/src/eddymotion/model/gpr.py index 6798d4bd..572456b3 100644 --- a/src/eddymotion/model/gpr.py +++ b/src/eddymotion/model/gpr.py @@ -28,20 +28,25 @@ from typing import Callable, Mapping, Sequence import numpy as np +from ConfigSpace import Configuration from scipy import optimize from scipy.optimize._minimize import Bounds +from sklearn.base import clone from sklearn.gaussian_process import GaussianProcessRegressor from sklearn.gaussian_process.kernels import ( Hyperparameter, Kernel, ) from sklearn.metrics.pairwise import cosine_similarity +from sklearn.model_selection import RepeatedKFold, cross_val_score from sklearn.utils._param_validation import Interval, StrOptions BOUNDS_A: tuple[float, float] = (0.1, 2.35) """The limits for the parameter *a* (angular distance in rad).""" BOUNDS_LAMBDA: tuple[float, float] = (1e-3, 1000) """The limits for the parameter λ (signal scaling factor).""" +BOUNDS_ALPHA: tuple[float, float] = (1e-3, 500) +"""The limits for the parameter σ² (noise adjustment, alpha in Scikit-learn's GP regressor).""" THETA_EPSILON: float = 1e-5 """Minimum nonzero angle.""" LBFGS_CONFIGURABLE_OPTIONS = {"disp", "maxiter", "ftol", "gtol"} @@ -49,6 +54,7 @@ CONFIGURABLE_OPTIONS: Mapping[str, set] = { "Nelder-Mead": {"disp", "maxiter", "adaptive", "fatol"}, "CG": {"disp", "maxiter", "gtol"}, + "cross-validation": {"scoring", "n_folds", "n_evaluations"}, } """ A mapping from optimizer names to the option set they allow. @@ -161,6 +167,9 @@ class EddyMotionGPR(GaussianProcessRegressor): "normalize_y": ["boolean"], "n_targets": [Interval(Integral, 1, None, closed="left"), None], "random_state": ["random_state"], + "n_folds": [Interval(Integral, 3, None, closed="left")], + "n_evaluations": [Interval(Integral, 3, None, closed="left")], + "n_trials": [Interval(Integral, 3, None, closed="left")], } def __init__( @@ -182,6 +191,10 @@ def __init__( gtol: float | None = None, adaptive: bool | int | None = None, fatol: float | None = None, + scoring: str = "neg_root_mean_squared_error", + n_folds: int | None = 10, + n_evaluations: int | None = 40, + n_trials: int | None = 200, ): super().__init__( kernel, @@ -202,6 +215,10 @@ def __init__( self.gtol = gtol self.adaptive = adaptive self.fatol = fatol + self.scoring = scoring + self.n_folds = n_folds + self.n_evaluations = n_evaluations + self.n_trials = n_trials def _constrained_optimization( self, @@ -210,6 +227,40 @@ def _constrained_optimization( bounds: Sequence[tuple[float, float]] | Bounds, ) -> tuple[float, float]: options = {} + + if self.optimizer == "cross-validation": + from ConfigSpace import ConfigurationSpace, Float + from smac import HyperparameterOptimizationFacade, Scenario + + cs = ConfigurationSpace() + beta_a = Float( + "kernel__beta_a", + tuple(self.kernel.a_bounds), + default=self.kernel_.beta_a, + log=True, + ) + beta_l = Float( + "kernel__beta_l", + tuple(self.kernel.l_bounds), + default=self.kernel_.beta_l, + log=True, + ) + cs.add([beta_a, beta_l]) + + # Scenario object specifying the optimization environment + scenario = Scenario(cs, n_trials=self.n_trials) + + # Use SMAC to find the best configuration/hyperparameters + smac = HyperparameterOptimizationFacade( + scenario, + self.cross_validation, + ) + incumbent = smac.optimize() + return ( + np.log([incumbent["kernel__beta_a"], incumbent["kernel__beta_l"]]), + 0, + ) + if self.optimizer == "fmin_l_bfgs_b": from sklearn.utils.optimize import _check_optimize_result @@ -252,6 +303,27 @@ def _constrained_optimization( raise ValueError(f"Unknown optimizer {self.optimizer}.") + def cross_validation( + self, + config: Configuration, + seed: int | None = None, + ) -> float: + rkf = RepeatedKFold( + n_splits=self.n_folds, + n_repeats=max(self.n_evaluations // self.n_folds, 1), + ) + gpr = clone(self) + gpr.set_params(**dict(config)) + gpr.optimizer = None + scores = cross_val_score( + gpr, + self.X_train_, + self.y_train_, + scoring=self.scoring, + cv=rkf, + ) + return np.mean(scores) + class ExponentialKriging(Kernel): """A scikit-learn's kernel for DWI signals."""