window_adaptation excessive memory usage #667

Closed
andrewdipper opened this issue May 11, 2024 · 2 comments

@andrewdipper
Contributor

Describe the issue as clearly as possible:

By default, the scan in window_adaptation saves the AdaptationInfo for every sample along the way. This results in memory usage many times in excess of (num_samples)*(num_variables) and leads to out-of-memory issues. However, it looks like the last states are the only information necessary to perform the window adaptation.

As such it'd be helpful to disable / select what info to store along the way such that the auxiliary info doesn't cause out of memory issues. Removing it altogether also doesn't seem ideal. I'd be happy to give a PR if you have an idea of what/how to best include/exclude the extra info.

I believe #529 is the result of the same thing: The extra buffers are likely for storing the sample by sample info - I get similar outputs.
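The per-step storage described above can be illustrated with a toy `jax.lax.scan`. This is only a minimal sketch, not BlackJAX code: the step functions, sizes, and names (`step_keep_all`, `step_keep_none`, `NUM_STEPS`, `NUM_VARIABLES`) are assumptions made up for the illustration. It shows that whatever the step function returns as its second output is stacked along a leading axis of length `NUM_STEPS`, while returning `None` stores nothing per step.

```python
import jax
import jax.numpy as jnp

# Illustrative sizes only; not taken from the issue.
NUM_STEPS, NUM_VARIABLES = 1000, 50


def step_keep_all(carry, x):
    # Toy update standing in for one MCMC step.
    new_carry = carry + x
    # The second return value is stacked by scan across all steps.
    return new_carry, new_carry


def step_keep_none(carry, x):
    new_carry = carry + x
    # Returning None means scan keeps no per-step output.
    return new_carry, None


init = jnp.zeros(NUM_VARIABLES)
xs = jnp.ones((NUM_STEPS, NUM_VARIABLES))

last_all, history = jax.lax.scan(step_keep_all, init, xs)
last_none, no_history = jax.lax.scan(step_keep_none, init, xs)

# history has shape (NUM_STEPS, NUM_VARIABLES); no_history is None.
# The final carry is identical in both cases.
```

In the real adaptation loop the stacked output is a nested structure (state, kernel info, adaptation state), so the per-step footprint can be several times larger than the samples alone.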

Thanks

Steps/code to reproduce the bug:

# A rough alternative one_step function (in window_adaptation.py) that completely
# removes the extra info. Returning None is not the ideal solution here.

    def one_step(carry, xs):
        _, rng_key, adaptation_stage = xs
        state, adaptation_state = carry

        new_state, info = mcmc_kernel(
            rng_key,
            state,
            logdensity_fn,
            adaptation_state.step_size,
            adaptation_state.inverse_mass_matrix,
            **extra_parameters,
        )
        new_adaptation_state = adapt_step(
            adaptation_state,
            adaptation_stage,
            new_state.position,
            info.acceptance_rate,
        )

        return (
            (new_state, new_adaptation_state),
            None,  # removed: AdaptationInfo(new_state, info, new_adaptation_state)
        )

Expected result:

...

Error message:

...

Blackjax/JAX/jaxlib/Python version information:

BlackJAX 0.1.dev494+g40efb6c.d20240511  # this is 1.2.1 with the jnp.clip PR removed
Python 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
Jax 0.4.25
Jaxlib 0.4.25

Context for the issue:

I found this trying to reduce memory consumption for a pymc model sampled with blackjax - there's a similar issue there with storing extra info during the actual sampling process. With both fixes memory consumption and performance are initially looking more than comparable with pymc's numpyro sampler.

@junpenglao
Member

junpenglao commented May 12, 2024

> As such it'd be helpful to disable / select what info to store along the way such that the auxiliary info doesn't cause out of memory issues. Removing it altogether also doesn't seem ideal. I'd be happy to give a PR if you have an idea of what/how to best include/exclude the extra info.

Agreed. I think numpyro also has a flag to control what gets exposed. I think we will need to add a kwarg to window_adaptation:

def window_adaptation(
    algorithm,
    logdensity_fn: Callable,
    is_mass_matrix_diagonal: bool = True,
    initial_step_size: float = 1.0,
    target_acceptance_rate: float = 0.80,
    progress_bar: bool = False,
    **extra_parameters,
) -> AdaptationAlgorithm:

def return_all_adapt_info(state, info, adaptation_state):
    return AdaptationInfo(state, info, adaptation_state)

def window_adaptation(
    algorithm,
    logdensity_fn: Callable,
    is_mass_matrix_diagonal: bool = True,
    initial_step_size: float = 1.0,
    target_acceptance_rate: float = 0.80,
    progress_bar: bool = False,
    adaptation_info_fn: Callable = return_all_adapt_info,
    **extra_parameters,
)

And then add some utility function for filtering common info (e.g., return_warmup_sample=True to return state)
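A filtering function of the kind suggested above might look like the following. This is a hedged sketch, not the BlackJAX implementation: `AdaptationInfo` is a simplified stand-in with array fields, `return_state_only` and `run_warmup` are hypothetical names, and the body of `one_step` uses toy arithmetic in place of `mcmc_kernel` and `adapt_step`. It only shows how an `adaptation_info_fn` argument could decide what the scan stores per step.

```python
from typing import Callable, NamedTuple, Optional

import jax
import jax.numpy as jnp


class AdaptationInfo(NamedTuple):
    # Simplified stand-in for the tuple built in one_step.
    state: Optional[jnp.ndarray]
    info: Optional[jnp.ndarray]
    adaptation_state: Optional[jnp.ndarray]


def return_all_adapt_info(state, info, adaptation_state):
    # Default from the proposal: keep everything.
    return AdaptationInfo(state, info, adaptation_state)


def return_state_only(state, info, adaptation_state):
    # Hypothetical filter: keep the warmup samples, drop the rest.
    return AdaptationInfo(state, None, None)


def run_warmup(num_steps: int, adaptation_info_fn: Callable = return_all_adapt_info):
    def one_step(carry, x):
        state, adaptation_state = carry
        new_state = state + x  # toy stand-in for mcmc_kernel
        info = jnp.tanh(new_state)  # toy stand-in for the kernel info
        # Toy stand-in for adapt_step.
        new_adaptation_state = 0.9 * adaptation_state + 0.1 * new_state.mean()
        return (
            (new_state, new_adaptation_state),
            adaptation_info_fn(new_state, info, new_adaptation_state),
        )

    init = (jnp.zeros(3), jnp.zeros(()))
    xs = jnp.ones((num_steps, 3))
    return jax.lax.scan(one_step, init, xs)


_, full = run_warmup(5)
_, slim = run_warmup(5, return_state_only)
```

Because `None` fields carry no arrays, `slim` holds only the stacked states, while `full` holds all three stacked fields. The final carry is the same either way, which is the only part the adaptation itself needs.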

@junpenglao
Member

Really good point and thank you for the deep dive!! Feel free to send a PR.
