Update `run_inference_algorithm` to split `initial_position` and `initial_state` #672
Conversation
The title was changed from "`run_inference_algorithm` to remove `initial_position` as valid input" to "`run_inference_algorithm` to split `initial_position` and `initial_state`".
Actually, a few more nits:
progress_bar: bool = False,
transform: Callable = lambda x: x,
) -> tuple[State, State, Info]:
return_state_history=True,
expectation: Callable = lambda x: x,
Should we combine the kwargs of `transform` and `expectation` for now? Otherwise we need a better name for `expectation`.
`expectation` seems to me like an appropriate name, since the value calculated is indeed an expectation. For example, if `expectation=lambda x: x**2`, then you get back an estimate of E[x**2] over the chain.

By contrast, `transform` operates on the full history of samples. In the future, I think it would make sense for it to also take `Info`, so that the user can choose to dispense with (part of) the diagnostic info. That is why I thought it was better to keep them separate.

I think it would in theory be possible to make them the same, if necessary.
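To make the distinction concrete, here is a minimal sketch (not the actual blackjax code; `sketch` is a made-up function for this example) in which `expectation` is averaged over the chain while `transform` is applied sample-by-sample to the stored history:

```python
import jax
import jax.numpy as jnp


# Illustrative sketch only, not the blackjax implementation: `transform` is
# applied to every stored sample and its full history is returned, while
# `expectation` is averaged over the chain, so only an estimate of
# E[expectation(x)] is kept rather than a per-sample array.
def sketch(samples, transform=lambda x: x, expectation=lambda x: x):
    transformed_history = jax.vmap(transform)(samples)
    expectation_estimate = jnp.mean(jax.vmap(expectation)(samples), axis=0)
    return transformed_history, expectation_estimate


samples = jnp.linspace(-1.0, 1.0, 101)  # stand-in for a chain of scalar samples
_, second_moment = sketch(samples, expectation=lambda x: x**2)  # roughly E[x**2]
```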
That makes sense, but looking at the implementation, the return state is always `transform(state)`, which means that for an expectation transformation, we should probably do `expectation(transform(x))`?
> That makes sense, but looking at the implementation, the return state is always `transform(state)`, which means that for an expectation transformation, we should probably do `expectation(transform(x))`?
The other argument returned, …

What I meant here is that: …

Does this make sense?
Yeah, that makes sense. Updated.
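As a rough illustration of the order settled on here (a sketch only; `streaming_average_update` is a hypothetical helper, not a blackjax function), the running estimate would accumulate `expectation(transform(state))`:

```python
import jax.numpy as jnp


# Hypothetical helper, not a blackjax function: a running average of
# expectation(transform(state)), i.e. the expectation is taken *after* the
# transform, as agreed above.
def streaming_average_update(running_mean, step, state, transform, expectation):
    value = expectation(transform(state))
    return running_mean + (value - running_mean) / (step + 1)


# Example: accumulate an estimate of E[x**2] of the (identity-)transformed states.
mean = 0.0
for step, state in enumerate(jnp.arange(5.0)):
    mean = streaming_average_update(mean, step, state, lambda s: s, lambda x: x**2)
```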
Thank you!
Improvement to `run_inference_algorithm`

Description:

`run_inference_algorithm` currently uses a try-except clause to allow the user to provide either an initial position or an initial state. This has led to some problems when the `except` clause fails to trigger. `run_inference_algorithm` also produces an array of shape `[num_steps, n_dims]`, which for long chains on high-dimensional problems can be prohibitively large.

Solution:

Make `run_inference_algorithm` take `initial_state` or `initial_position` explicitly, so that the caller of `run_inference_algorithm` is responsible for providing `initial_state` or `initial_position`, rather than deferring this task.

Add a `streaming` mode for `run_inference_algorithm`.
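For illustration only, here is a minimal, self-contained sketch of the proposed calling convention (the names `run_inference_sketch`, `State`, and `rw_step` are made up for this example and are not the blackjax API): the caller builds the initial state explicitly and passes it in, rather than relying on a try-except over `initial_position` versus `initial_state`.

```python
import jax
import jax.numpy as jnp
from typing import Callable, NamedTuple


class State(NamedTuple):
    position: jnp.ndarray


def run_inference_sketch(rng_key, initial_state: State, step_fn: Callable, num_steps: int):
    # The caller passes a fully built state, so no try/except dispatch on the
    # type of the input is needed inside the loop runner.
    keys = jax.random.split(rng_key, num_steps)

    def one_step(state, key):
        state = step_fn(key, state)
        return state, state.position

    final_state, position_history = jax.lax.scan(one_step, initial_state, keys)
    return final_state, position_history


def rw_step(key, state: State) -> State:
    # Toy random-walk update standing in for a real inference algorithm's step.
    return State(state.position + jax.random.normal(key, state.position.shape))


initial_position = jnp.zeros(2)
initial_state = State(position=initial_position)  # built explicitly by the caller
final_state, history = run_inference_sketch(
    jax.random.PRNGKey(0), initial_state, rw_step, num_steps=100
)
```

This sketch still materializes the full `[num_steps, n_dims]` history; the streaming mode described above would instead fold a running expectation into the scan carry so that only the estimate is stored.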
A few important guidelines and requirements before we can merge your PR:
- the branch is rebased on the latest `main` commit;
- `pre-commit` is installed and configured on your machine, and you ran it before opening the PR;