diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 659986ca..c3a40235 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -45,7 +45,7 @@ repos: hooks: - id: end-of-file-fixer exclude: (.txt|^docs/JOSS1|^docs/JOSS2|^examples/data/) - stages: [commit, merge-commit, push, prepare-commit-msg, commit-msg, post-checkout, post-commit, post-merge, post-rewrite] + stages: [pre-commit, pre-merge-commit, pre-push, prepare-commit-msg, commit-msg, post-checkout, post-commit, post-merge, post-rewrite] - id: trailing-whitespace - stages: [commit, merge-commit, push, prepare-commit-msg, commit-msg, post-checkout, post-commit, post-merge, post-rewrite] + stages: [pre-commit, pre-merge-commit, pre-push, prepare-commit-msg, commit-msg, post-checkout, post-commit, post-merge, post-rewrite] exclude: (.txt|^docs/JOSS1|^docs/JOSS2|^examples/data/) diff --git a/pyproject.toml b/pyproject.toml index cd3603a7..18f6b6fe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ classifiers = [ readme = "README.rst" dependencies = [ "jax>=0.4,<0.5", - "scikit-learn>=1.1, !=1.5.0", + "scikit-learn>=1.1, !=1.5.0, !=1.6.0", "derivative>=0.6.2", "typing_extensions", ] diff --git a/pysindy/utils/base.py b/pysindy/utils/base.py index 95cbf40a..a2ddc840 100644 --- a/pysindy/utils/base.py +++ b/pysindy/utils/base.py @@ -1,4 +1,5 @@ import warnings +from functools import wraps from typing import Callable from typing import Sequence from typing import Union @@ -152,17 +153,19 @@ def reorder_constraints(arr, n_features, output_order="feature"): return arr.reshape(starting_shape).transpose([0, 2, 1]).reshape((n_constraints, -1)) -def _validate_prox_and_reg_inputs(func, regularization): +def _validate_prox_and_reg_inputs(func): + """Add guard code to ensure weight and argument have compatible shape/type + + Decorates prox and regularization functions. + """ + + @wraps(func) def wrapper(x, regularization_weight): - if regularization[:8] == "weighted": - if not isinstance(regularization_weight, np.ndarray): - raise ValueError( - f"'regularization_weight' must be an array of shape {x.shape}." - ) + if isinstance(regularization_weight, np.ndarray): weight_shape = regularization_weight.shape if weight_shape != x.shape: raise ValueError( - f"Invalid shape for 'regularization_weight':" + f"Invalid shape for 'regularization_weight': " f"{weight_shape}. Must be the same shape as x: {x.shape}." ) elif not isinstance(regularization_weight, (int, float)): @@ -190,36 +193,66 @@ def get_prox( and returns an array of the same shape """ - def prox_l0( - x: NDArray[np.float64], - regularization_weight: Union[float, NDArray[np.float64]], - ): - threshold = np.sqrt(2 * regularization_weight) - return x * (np.abs(x) > threshold) + prox = { + "l0": _prox_l0, + "weighted_l0": _prox_l0, + "l1": _prox_l1, + "weighted_l1": _prox_l1, + "l2": _prox_l2, + "weighted_l2": _prox_l2, + } + regularization = regularization.lower() + return prox[regularization] - def prox_l1( - x: NDArray[np.float64], - regularization_weight: Union[float, NDArray[np.float64]], - ): - return np.sign(x) * np.maximum(np.abs(x) - regularization_weight, 0) +@_validate_prox_and_reg_inputs +def _prox_l0( + x: NDArray[np.float64], + regularization_weight: Union[float, NDArray[np.float64]], +): + threshold = np.sqrt(2 * regularization_weight) + return x * (np.abs(x) > threshold) - def prox_l2( - x: NDArray[np.float64], - regularization_weight: Union[float, NDArray[np.float64]], - ): - return x / (1 + 2 * regularization_weight) - prox = { - "l0": prox_l0, - "weighted_l0": prox_l0, - "l1": prox_l1, - "weighted_l1": prox_l1, - "l2": prox_l2, - "weighted_l2": prox_l2, - } - regularization = regularization.lower() - return _validate_prox_and_reg_inputs(prox[regularization], regularization) +@_validate_prox_and_reg_inputs +def _prox_l1( + x: NDArray[np.float64], + regularization_weight: Union[float, NDArray[np.float64]], +): + + return np.sign(x) * np.maximum(np.abs(x) - regularization_weight, 0) + + +@_validate_prox_and_reg_inputs +def _prox_l2( + x: NDArray[np.float64], + regularization_weight: Union[float, NDArray[np.float64]], +): + return x / (1 + 2 * regularization_weight) + + +@_validate_prox_and_reg_inputs +def _regularization_l0( + x: NDArray[np.float64], + regularization_weight: Union[float, NDArray[np.float64]], +): + return np.sum(regularization_weight * (x != 0)) + + +@_validate_prox_and_reg_inputs +def _regularization_l1( + x: NDArray[np.float64], + regularization_weight: Union[float, NDArray[np.float64]], +): + return np.sum(regularization_weight * np.abs(x)) + + +@_validate_prox_and_reg_inputs +def _regularization_l2( + x: NDArray[np.float64], + regularization_weight: Union[float, NDArray[np.float64]], +): + return np.sum(regularization_weight * x**2) def get_regularization( @@ -238,39 +271,16 @@ def get_regularization( and returns a float """ - def regularization_l0( - x: NDArray[np.float64], - regularization_weight: Union[float, NDArray[np.float64]], - ): - - return np.sum(regularization_weight * (x != 0)) - - def regularization_l1( - x: NDArray[np.float64], - regularization_weight: Union[float, NDArray[np.float64]], - ): - - return np.sum(regularization_weight * np.abs(x)) - - def regularization_l2( - x: NDArray[np.float64], - regularization_weight: Union[float, NDArray[np.float64]], - ): - - return np.sum(regularization_weight * x**2) - regularization_fn = { - "l0": regularization_l0, - "weighted_l0": regularization_l0, - "l1": regularization_l1, - "weighted_l1": regularization_l1, - "l2": regularization_l2, - "weighted_l2": regularization_l2, + "l0": _regularization_l0, + "weighted_l0": _regularization_l0, + "l1": _regularization_l1, + "weighted_l1": _regularization_l1, + "l2": _regularization_l2, + "weighted_l2": _regularization_l2, } regularization = regularization.lower() - return _validate_prox_and_reg_inputs( - regularization_fn[regularization], regularization - ) + return regularization_fn[regularization] def capped_simplex_projection(trimming_array, trimming_fraction): diff --git a/test/test_optimizers/debug.py b/test/test_optimizers/debug.py new file mode 100644 index 00000000..01901a2d --- /dev/null +++ b/test/test_optimizers/debug.py @@ -0,0 +1,20 @@ +import pickle +from functools import wraps + + +def foo(func): + @wraps(func) + def wrapper(*args, **kwargs): + print(f"Called {func}") + return func(*args, **kwargs) + + return wrapper + + +@foo +def bar(a, b): + print(a + b) + + +bars = pickle.dumps(bar) +barl = pickle.loads(bars) diff --git a/test/test_optimizers/test_optimizers.py b/test/test_optimizers/test_optimizers.py index 3ba9bf28..fb2102f6 100644 --- a/test/test_optimizers/test_optimizers.py +++ b/test/test_optimizers/test_optimizers.py @@ -1183,12 +1183,14 @@ def test_remove_and_decrement(): ( (MIOSR, {"target_sparsity": 7}), (SBR, {"num_warmup": 10, "num_samples": 10}), + (SR3, {}), + (TrappingSR3, {"_n_tgts": 3, "_include_bias": True}), ), ) def test_pickle(data_lorenz, opt_cls, opt_args): x, t = data_lorenz y = PolynomialLibrary(degree=2).fit_transform(x) - opt = opt_cls(**opt_args).fit(x, y) + opt = opt_cls(**opt_args).fit(y, x) expected = opt.coef_ new_opt = pickle.loads(pickle.dumps(opt)) result = new_opt.coef_ diff --git a/test/utils/test_utils.py b/test/utils/test_utils.py index bfd5054d..a70e44d4 100644 --- a/test/utils/test_utils.py +++ b/test/utils/test_utils.py @@ -101,17 +101,9 @@ def test_get_prox_and_regularization_bad_shape(regularization, lam): prox(data, lam) -@pytest.mark.parametrize( - "regularization", ["weighted_l0", "weighted_l1", "weighted_l2"] -) -@pytest.mark.parametrize( - "lam", - [ - np.array([[1, 2]]), - 1, - ], -) -def test_get_weighted_prox_and_regularization_bad_shape(regularization, lam): +@pytest.mark.parametrize("regularization", ["l0", "l1", "l2"]) +def test_get_weighted_prox_and_regularization_bad_shape(regularization): + lam = np.array([[1, 2]]) data = np.array([[-2, 5]]).T reg = get_regularization(regularization) with pytest.raises(ValueError):