Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add (Prox-)SVRG as a solver for GLMs #184

Merged
merged 198 commits into from
Aug 15, 2024

Conversation

bagibence
Copy link
Collaborator

Add a preliminary implementation of SVRG and Prox-SVRG compatible with the jaxopt API, so nemos can use it as a solver:

  • the run method (used in GLM.fit) assumes the full data is loaded into memory and runs optimization until convergence.
  • the update method (used in GLM.update) updates the parameters on a single data point or batch. Note that currently, when fitting on batches and calling glm.update, some steps of the algorithm have to be implemented outside.

ProxSVRG takes a prox argument for the proximal operator which should be a callable with the same signature as the functions defined in jaxopt.prox (e.g. jaxopt.prox.prox_lasso).
SVRG inherits from ProxSVRG, uses the identity as the proximal operator and adjusts some calls to account for the different number of arguments.

bagibence added 30 commits June 21, 2024 17:22
Copy link
Member

@billbrod billbrod left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This generally looks very good! Mainly some minor changes here, and I responded in a couple of my previous comments, so make sure to check those again.

Also, and I know this is annoying, but test_solvers still has references to xs and df_xs. I think that should be updated to match the new nomenclature, which will make it easier for us in the future.

tests/test_basis.py Outdated Show resolved Hide resolved
src/nemos/solvers.py Outdated Show resolved Hide resolved
src/nemos/solvers.py Outdated Show resolved Hide resolved
src/nemos/solvers.py Show resolved Hide resolved
src/nemos/solvers.py Outdated Show resolved Hide resolved
src/nemos/solvers.py Outdated Show resolved Hide resolved
src/nemos/solvers.py Outdated Show resolved Hide resolved
Copy link
Member

@billbrod billbrod left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Two small issues and then this is ready to merge!

  • Docstring of ProxSVRG mentions hyperparams_prox, which is no longer an argument. So that should be removed.

src/nemos/solvers.py Outdated Show resolved Hide resolved
@BalzaniEdoardo
Copy link
Collaborator

  • hyperparams_prox

The docstrings are ok, because that is referring to any hyperparameter that a general proximal operator gets as input. For example, what we call regularization_strength is one such hyper-parameter for ridge.

@BalzaniEdoardo BalzaniEdoardo merged commit 5684359 into flatironinstitute:development Aug 15, 2024
11 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants