Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pickle SR3 and subordinate classes #586

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ repos:
hooks:
- id: end-of-file-fixer
exclude: (.txt|^docs/JOSS1|^docs/JOSS2|^examples/data/)
stages: [commit, merge-commit, push, prepare-commit-msg, commit-msg, post-checkout, post-commit, post-merge, post-rewrite]
stages: [pre-commit, pre-merge-commit, pre-push, prepare-commit-msg, commit-msg, post-checkout, post-commit, post-merge, post-rewrite]
- id: trailing-whitespace
stages: [commit, merge-commit, push, prepare-commit-msg, commit-msg, post-checkout, post-commit, post-merge, post-rewrite]
stages: [pre-commit, pre-merge-commit, pre-push, prepare-commit-msg, commit-msg, post-checkout, post-commit, post-merge, post-rewrite]
exclude: (.txt|^docs/JOSS1|^docs/JOSS2|^examples/data/)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ classifiers = [
readme = "README.rst"
dependencies = [
"jax>=0.4,<0.5",
"scikit-learn>=1.1, !=1.5.0",
"scikit-learn>=1.1, !=1.5.0, !=1.6.0",
"derivative>=0.6.2",
"typing_extensions",
]
Expand Down
136 changes: 73 additions & 63 deletions pysindy/utils/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import warnings
from functools import wraps
from typing import Callable
from typing import Sequence
from typing import Union
Expand Down Expand Up @@ -152,17 +153,19 @@ def reorder_constraints(arr, n_features, output_order="feature"):
return arr.reshape(starting_shape).transpose([0, 2, 1]).reshape((n_constraints, -1))


def _validate_prox_and_reg_inputs(func, regularization):
def _validate_prox_and_reg_inputs(func):
"""Add guard code to ensure weight and argument have compatible shape/type

Decorates prox and regularization functions.
"""

@wraps(func)
def wrapper(x, regularization_weight):
if regularization[:8] == "weighted":
if not isinstance(regularization_weight, np.ndarray):
raise ValueError(
f"'regularization_weight' must be an array of shape {x.shape}."
)
if isinstance(regularization_weight, np.ndarray):
weight_shape = regularization_weight.shape
if weight_shape != x.shape:
raise ValueError(
f"Invalid shape for 'regularization_weight':"
f"Invalid shape for 'regularization_weight': "
f"{weight_shape}. Must be the same shape as x: {x.shape}."
)
elif not isinstance(regularization_weight, (int, float)):
Expand Down Expand Up @@ -190,36 +193,66 @@ def get_prox(
and returns an array of the same shape
"""

def prox_l0(
    x: NDArray[np.float64],
    regularization_weight: Union[float, NDArray[np.float64]],
):
    """Proximal operator of the (weighted) l0 penalty: hard thresholding.

    Zeroes every entry of ``x`` whose magnitude does not exceed
    ``sqrt(2 * regularization_weight)``; other entries pass through
    unchanged.  Elementwise when the weight is an array.
    """
    threshold = np.sqrt(2 * regularization_weight)
    return x * (np.abs(x) > threshold)
prox = {
"l0": _prox_l0,
"weighted_l0": _prox_l0,
"l1": _prox_l1,
"weighted_l1": _prox_l1,
"l2": _prox_l2,
"weighted_l2": _prox_l2,
}
regularization = regularization.lower()
return prox[regularization]

def prox_l1(
    x: NDArray[np.float64],
    regularization_weight: Union[float, NDArray[np.float64]],
):
    # Proximal operator of the (weighted) l1 penalty: soft thresholding.
    # Shrinks each entry of x toward zero by the weight, clipping at zero.
    return np.sign(x) * np.maximum(np.abs(x) - regularization_weight, 0)
@_validate_prox_and_reg_inputs
def _prox_l0(
    x: NDArray[np.float64],
    regularization_weight: Union[float, NDArray[np.float64]],
):
    """Hard-thresholding proximal operator for the (weighted) l0 penalty.

    An entry of ``x`` survives only when its magnitude strictly exceeds
    ``sqrt(2 * regularization_weight)``; every other entry is zeroed.
    Operates elementwise when the weight is an array.
    """
    keep = np.abs(x) > np.sqrt(2 * regularization_weight)
    return x * keep

def prox_l2(
    x: NDArray[np.float64],
    regularization_weight: Union[float, NDArray[np.float64]],
):
    # Proximal operator of the (weighted) l2 (ridge) penalty: uniform
    # shrinkage of x by the factor 1 / (1 + 2 * weight).
    return x / (1 + 2 * regularization_weight)

prox = {
"l0": prox_l0,
"weighted_l0": prox_l0,
"l1": prox_l1,
"weighted_l1": prox_l1,
"l2": prox_l2,
"weighted_l2": prox_l2,
}
regularization = regularization.lower()
return _validate_prox_and_reg_inputs(prox[regularization], regularization)
@_validate_prox_and_reg_inputs
def _prox_l1(
    x: NDArray[np.float64],
    regularization_weight: Union[float, NDArray[np.float64]],
):
    """Soft-thresholding proximal operator for the (weighted) l1 penalty.

    Shrinks each entry of ``x`` toward zero by ``regularization_weight``,
    clipping at zero, while preserving the entry's sign.
    """
    shrunk_magnitude = np.maximum(np.abs(x) - regularization_weight, 0)
    return np.sign(x) * shrunk_magnitude


@_validate_prox_and_reg_inputs
def _prox_l2(
    x: NDArray[np.float64],
    regularization_weight: Union[float, NDArray[np.float64]],
):
    """Scaling proximal operator for the (weighted) l2 (ridge) penalty.

    Uniformly shrinks ``x`` by the factor ``1 / (1 + 2 * weight)``,
    elementwise when the weight is an array.
    """
    denominator = 2 * regularization_weight + 1
    return x / denominator


@_validate_prox_and_reg_inputs
def _regularization_l0(
    x: NDArray[np.float64],
    regularization_weight: Union[float, NDArray[np.float64]],
):
    """(Weighted) l0 penalty value: weighted count of nonzero entries."""
    nonzero_mask = x != 0
    return np.sum(regularization_weight * nonzero_mask)


@_validate_prox_and_reg_inputs
def _regularization_l1(
    x: NDArray[np.float64],
    regularization_weight: Union[float, NDArray[np.float64]],
):
    """(Weighted) l1 penalty value: weighted sum of absolute entries."""
    weighted_magnitudes = regularization_weight * np.abs(x)
    return np.sum(weighted_magnitudes)


@_validate_prox_and_reg_inputs
def _regularization_l2(
    x: NDArray[np.float64],
    regularization_weight: Union[float, NDArray[np.float64]],
):
    """(Weighted) l2 penalty value: weighted sum of squared entries."""
    weighted_squares = regularization_weight * np.square(x)
    return np.sum(weighted_squares)


def get_regularization(
Expand All @@ -238,39 +271,16 @@ def get_regularization(
and returns a float
"""

def regularization_l0(
x: NDArray[np.float64],
regularization_weight: Union[float, NDArray[np.float64]],
):

return np.sum(regularization_weight * (x != 0))

def regularization_l1(
x: NDArray[np.float64],
regularization_weight: Union[float, NDArray[np.float64]],
):

return np.sum(regularization_weight * np.abs(x))

def regularization_l2(
x: NDArray[np.float64],
regularization_weight: Union[float, NDArray[np.float64]],
):

return np.sum(regularization_weight * x**2)

regularization_fn = {
"l0": regularization_l0,
"weighted_l0": regularization_l0,
"l1": regularization_l1,
"weighted_l1": regularization_l1,
"l2": regularization_l2,
"weighted_l2": regularization_l2,
"l0": _regularization_l0,
"weighted_l0": _regularization_l0,
"l1": _regularization_l1,
"weighted_l1": _regularization_l1,
"l2": _regularization_l2,
"weighted_l2": _regularization_l2,
}
regularization = regularization.lower()
return _validate_prox_and_reg_inputs(
regularization_fn[regularization], regularization
)
return regularization_fn[regularization]


def capped_simplex_projection(trimming_array, trimming_fraction):
Expand Down
20 changes: 20 additions & 0 deletions test/test_optimizers/debug.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this debug file be included?

Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import pickle
from functools import wraps


def foo(func):
    """Wrap *func* so each invocation is announced on stdout first.

    ``functools.wraps`` copies the wrapped function's metadata
    (``__name__``, ``__qualname__``, ...) onto the wrapper, which is what
    lets the decorated function be located — and therefore pickled — by
    its qualified name.
    """

    @wraps(func)
    def announced(*args, **kwargs):
        print(f"Called {func}")
        return func(*args, **kwargs)

    return announced


@foo
def bar(a, b):
    # Minimal decorated function used to exercise pickling of a
    # functools.wraps-wrapped callable; side-effect only (returns None).
    print(a + b)


# Round-trip the decorated function through pickle.  Because the decorator
# uses functools.wraps, ``bar``'s metadata points back to this module-level
# name, so pickling by qualified name should succeed — TODO confirm this is
# intentional test scaffolding (reviewer asked whether this debug file
# belongs in the repo at all).
bars = pickle.dumps(bar)
barl = pickle.loads(bars)
4 changes: 3 additions & 1 deletion test/test_optimizers/test_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1183,12 +1183,14 @@ def test_remove_and_decrement():
(
(MIOSR, {"target_sparsity": 7}),
(SBR, {"num_warmup": 10, "num_samples": 10}),
(SR3, {}),
(TrappingSR3, {"_n_tgts": 3, "_include_bias": True}),
),
)
def test_pickle(data_lorenz, opt_cls, opt_args):
x, t = data_lorenz
y = PolynomialLibrary(degree=2).fit_transform(x)
opt = opt_cls(**opt_args).fit(x, y)
opt = opt_cls(**opt_args).fit(y, x)
expected = opt.coef_
new_opt = pickle.loads(pickle.dumps(opt))
result = new_opt.coef_
Expand Down
14 changes: 3 additions & 11 deletions test/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,17 +101,9 @@ def test_get_prox_and_regularization_bad_shape(regularization, lam):
prox(data, lam)


@pytest.mark.parametrize(
"regularization", ["weighted_l0", "weighted_l1", "weighted_l2"]
)
@pytest.mark.parametrize(
"lam",
[
np.array([[1, 2]]),
1,
],
)
def test_get_weighted_prox_and_regularization_bad_shape(regularization, lam):
@pytest.mark.parametrize("regularization", ["l0", "l1", "l2"])
def test_get_weighted_prox_and_regularization_bad_shape(regularization):
lam = np.array([[1, 2]])
data = np.array([[-2, 5]]).T
reg = get_regularization(regularization)
with pytest.raises(ValueError):
Expand Down
Loading