Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pickle SR3 and subordinate classes #586

Merged
merged 4 commits into from
Jan 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ repos:
hooks:
- id: end-of-file-fixer
exclude: (.txt|^docs/JOSS1|^docs/JOSS2|^examples/data/)
stages: [commit, merge-commit, push, prepare-commit-msg, commit-msg, post-checkout, post-commit, post-merge, post-rewrite]
stages: [pre-commit, pre-merge-commit, pre-push, prepare-commit-msg, commit-msg, post-checkout, post-commit, post-merge, post-rewrite]
- id: trailing-whitespace
stages: [commit, merge-commit, push, prepare-commit-msg, commit-msg, post-checkout, post-commit, post-merge, post-rewrite]
stages: [pre-commit, pre-merge-commit, pre-push, prepare-commit-msg, commit-msg, post-checkout, post-commit, post-merge, post-rewrite]
exclude: (.txt|^docs/JOSS1|^docs/JOSS2|^examples/data/)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ classifiers = [
readme = "README.rst"
dependencies = [
"jax>=0.4,<0.5",
"scikit-learn>=1.1, !=1.5.0",
"scikit-learn>=1.1, !=1.5.0, !=1.6.0",
"derivative>=0.6.2",
"typing_extensions",
]
Expand Down
20 changes: 19 additions & 1 deletion pysindy/optimizers/ssr.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,20 @@
from typing import cast
from typing import NewType
from typing import TypeVar

import numpy as np
from numpy.typing import NBitBase
from sklearn.linear_model import ridge_regression

from .base import BaseOptimizer

# Distinct integer "dimension labels" used to tag array axes in annotations,
# e.g. ``Float2D[Samples, Features]``.
Samples = NewType("Samples", int)
Features = NewType("Features", int)
Targets = NewType("Targets", int)

# Covariant, int-bounded placeholders for the two axes of a 2D array.
Rows = TypeVar("Rows", covariant=True, bound=int)
Cols = TypeVar("Cols", covariant=True, bound=int)

# Generic alias: an ndarray with a (Rows, Cols) shape and a floating dtype.
Float2D = np.ndarray[tuple[Rows, Cols], np.dtype[np.floating[NBitBase]]]


class SSR(BaseOptimizer):
"""Stepwise sparse regression (SSR) greedy algorithm.
Expand Down Expand Up @@ -157,17 +169,23 @@ def _model_residual(self, x, y, coef, inds):
cc[total_ind] = 0.0
return cc, total_ind

def _regress(self, x, y):
def _regress(
self, x: Float2D[Samples, Features], y: Float2D[Samples, Targets]
) -> Float2D[Targets, Features]:
"""Perform the ridge regression"""
kw = self.ridge_kw or {}
coef = ridge_regression(x, y, self.alpha, **kw)
coef = np.atleast_2d(coef) # type: ignore
self.iters += 1
return coef

def _reduce(self, x, y):
"""Performs at most ``self.max_iter`` iterations of the
SSR greedy algorithm.
"""
# Until static typing grows, use cast
x = cast(Float2D[Samples, Features], x)
y = cast(Float2D[Samples, Targets], y)
n_samples, n_features = x.shape
n_targets = y.shape[1]
cond_num = np.linalg.cond(x)
Expand Down
136 changes: 73 additions & 63 deletions pysindy/utils/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import warnings
from functools import wraps
from typing import Callable
from typing import Sequence
from typing import Union
Expand Down Expand Up @@ -152,17 +153,19 @@ def reorder_constraints(arr, n_features, output_order="feature"):
return arr.reshape(starting_shape).transpose([0, 2, 1]).reshape((n_constraints, -1))


def _validate_prox_and_reg_inputs(func, regularization):
def _validate_prox_and_reg_inputs(func):
"""Add guard code to ensure weight and argument have compatible shape/type

Decorates prox and regularization functions.
"""

@wraps(func)
def wrapper(x, regularization_weight):
if regularization[:8] == "weighted":
if not isinstance(regularization_weight, np.ndarray):
raise ValueError(
f"'regularization_weight' must be an array of shape {x.shape}."
)
if isinstance(regularization_weight, np.ndarray):
weight_shape = regularization_weight.shape
if weight_shape != x.shape:
raise ValueError(
f"Invalid shape for 'regularization_weight':"
f"Invalid shape for 'regularization_weight': "
f"{weight_shape}. Must be the same shape as x: {x.shape}."
)
elif not isinstance(regularization_weight, (int, float)):
Expand Down Expand Up @@ -190,36 +193,66 @@ def get_prox(
and returns an array of the same shape
"""

def prox_l0(
x: NDArray[np.float64],
regularization_weight: Union[float, NDArray[np.float64]],
):
threshold = np.sqrt(2 * regularization_weight)
return x * (np.abs(x) > threshold)
prox = {
"l0": _prox_l0,
"weighted_l0": _prox_l0,
"l1": _prox_l1,
"weighted_l1": _prox_l1,
"l2": _prox_l2,
"weighted_l2": _prox_l2,
}
regularization = regularization.lower()
return prox[regularization]

def prox_l1(
x: NDArray[np.float64],
regularization_weight: Union[float, NDArray[np.float64]],
):

return np.sign(x) * np.maximum(np.abs(x) - regularization_weight, 0)
@_validate_prox_and_reg_inputs
def _prox_l0(
    x: NDArray[np.float64],
    regularization_weight: Union[float, NDArray[np.float64]],
):
    """Hard-threshold ``x``.

    Entries whose magnitude does not exceed
    ``sqrt(2 * regularization_weight)`` are zeroed; all other entries
    pass through unchanged.
    """
    keep = np.abs(x) > np.sqrt(2 * regularization_weight)
    return x * keep

def prox_l2(
x: NDArray[np.float64],
regularization_weight: Union[float, NDArray[np.float64]],
):
return x / (1 + 2 * regularization_weight)

prox = {
"l0": prox_l0,
"weighted_l0": prox_l0,
"l1": prox_l1,
"weighted_l1": prox_l1,
"l2": prox_l2,
"weighted_l2": prox_l2,
}
regularization = regularization.lower()
return _validate_prox_and_reg_inputs(prox[regularization], regularization)
@_validate_prox_and_reg_inputs
def _prox_l1(
    x: NDArray[np.float64],
    regularization_weight: Union[float, NDArray[np.float64]],
):
    """Soft-threshold ``x``.

    Each magnitude is reduced by ``regularization_weight`` and clamped at
    zero, then the original sign is reapplied.
    """
    shrunk_magnitude = np.maximum(np.abs(x) - regularization_weight, 0)
    return np.sign(x) * shrunk_magnitude


@_validate_prox_and_reg_inputs
def _prox_l2(
    x: NDArray[np.float64],
    regularization_weight: Union[float, NDArray[np.float64]],
):
    """Shrink ``x`` by dividing it by ``1 + 2 * regularization_weight``."""
    denominator = 1 + 2 * regularization_weight
    return x / denominator


@_validate_prox_and_reg_inputs
def _regularization_l0(
    x: NDArray[np.float64],
    regularization_weight: Union[float, NDArray[np.float64]],
):
    """Return the sum of ``regularization_weight`` over the nonzero
    entries of ``x``.
    """
    is_nonzero = x != 0
    return np.sum(regularization_weight * is_nonzero)


@_validate_prox_and_reg_inputs
def _regularization_l1(
    x: NDArray[np.float64],
    regularization_weight: Union[float, NDArray[np.float64]],
):
    """Return the sum of ``regularization_weight * |x|`` over all entries."""
    magnitudes = np.abs(x)
    return np.sum(regularization_weight * magnitudes)


@_validate_prox_and_reg_inputs
def _regularization_l2(
    x: NDArray[np.float64],
    regularization_weight: Union[float, NDArray[np.float64]],
):
    """Return the sum of ``regularization_weight * x**2`` over all entries."""
    squares = x**2
    return np.sum(regularization_weight * squares)


def get_regularization(
Expand All @@ -238,39 +271,16 @@ def get_regularization(
and returns a float
"""

def regularization_l0(
x: NDArray[np.float64],
regularization_weight: Union[float, NDArray[np.float64]],
):

return np.sum(regularization_weight * (x != 0))

def regularization_l1(
x: NDArray[np.float64],
regularization_weight: Union[float, NDArray[np.float64]],
):

return np.sum(regularization_weight * np.abs(x))

def regularization_l2(
x: NDArray[np.float64],
regularization_weight: Union[float, NDArray[np.float64]],
):

return np.sum(regularization_weight * x**2)

regularization_fn = {
"l0": regularization_l0,
"weighted_l0": regularization_l0,
"l1": regularization_l1,
"weighted_l1": regularization_l1,
"l2": regularization_l2,
"weighted_l2": regularization_l2,
"l0": _regularization_l0,
"weighted_l0": _regularization_l0,
"l1": _regularization_l1,
"weighted_l1": _regularization_l1,
"l2": _regularization_l2,
"weighted_l2": _regularization_l2,
}
regularization = regularization.lower()
return _validate_prox_and_reg_inputs(
regularization_fn[regularization], regularization
)
return regularization_fn[regularization]


def capped_simplex_projection(trimming_array, trimming_fraction):
Expand Down
4 changes: 3 additions & 1 deletion test/test_optimizers/test_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1183,12 +1183,14 @@ def test_remove_and_decrement():
(
(MIOSR, {"target_sparsity": 7}),
(SBR, {"num_warmup": 10, "num_samples": 10}),
(SR3, {}),
(TrappingSR3, {"_n_tgts": 3, "_include_bias": True}),
),
)
def test_pickle(data_lorenz, opt_cls, opt_args):
x, t = data_lorenz
y = PolynomialLibrary(degree=2).fit_transform(x)
opt = opt_cls(**opt_args).fit(x, y)
opt = opt_cls(**opt_args).fit(y, x)
expected = opt.coef_
new_opt = pickle.loads(pickle.dumps(opt))
result = new_opt.coef_
Expand Down
14 changes: 3 additions & 11 deletions test/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,17 +101,9 @@ def test_get_prox_and_regularization_bad_shape(regularization, lam):
prox(data, lam)


@pytest.mark.parametrize(
"regularization", ["weighted_l0", "weighted_l1", "weighted_l2"]
)
@pytest.mark.parametrize(
"lam",
[
np.array([[1, 2]]),
1,
],
)
def test_get_weighted_prox_and_regularization_bad_shape(regularization, lam):
@pytest.mark.parametrize("regularization", ["l0", "l1", "l2"])
def test_get_weighted_prox_and_regularization_bad_shape(regularization):
lam = np.array([[1, 2]])
data = np.array([[-2, 5]]).T
reg = get_regularization(regularization)
with pytest.raises(ValueError):
Expand Down
Loading