-
Notifications
You must be signed in to change notification settings - Fork 8
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add (Prox-)SVRG as a solver for GLMs #184
Add (Prox-)SVRG as a solver for GLMs #184
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This generally looks very good! Mainly some minor changes here, and I responded in a couple of my previous comments, so make sure to check those again.
Also, and I know this is annoying, but test_solvers
still has references to xs
and df_xs
. I think that should be updated to match the new nomenclature, which will make it easier for us in the future.
Co-authored-by: William F. Broderick <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Two small issues and then this is ready to merge!
- Docstring of
ProxSVRG
mentionshyperparams_prox
, which is no longer an argument. So that should be removed.
Co-authored-by: William F. Broderick <[email protected]>
The docstrings are ok, because that is referring to any hyperparameter that a general proximal operator gets as input. For example, what we call regularization_strength is one such hyper-parameter for ridge. |
5684359
into
flatironinstitute:development
Add a preliminary implementation of SVRG and Prox-SVRG compatible with the jaxopt API, so nemos can use it as a solver:
run
method (used inGLM.fit
) assumes the full data is loaded into memory and runs optimization until convergence.update
method (used inGLM.update
) updates the parameters on a single data point or batch. Note that currently, when fitting on batches and callingglm.update
, some steps of the algorithm have to be implemented outside.ProxSVRG
takes aprox
argument for the proximal operator which should be a callable with the same signature as the functions defined injaxopt.prox
(e.g.jaxopt.prox.prox_lasso
).SVRG
inherits fromProxSVRG
, uses the identity as the proximal operator and adjusts some calls to account for the different number of arguments.