Skip to content

Commit

Permalink
remove unused input in scan
Browse files Browse the repository at this point in the history
  • Loading branch information
junpenglao committed Oct 23, 2023
1 parent 51cf08a commit ae4ec10
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions blackjax/adaptation/window_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ def window_adaptation(
)

def one_step(carry, xs):
_, rng_key, adaptation_stage = xs
rng_key, adaptation_stage = xs
state, adaptation_state = carry

new_state, info = mcmc_kernel(
Expand Down Expand Up @@ -335,7 +335,7 @@ def run(rng_key: PRNGKey, position: ArrayLikeTree, num_steps: int = 1000):
last_state, info = jax.lax.scan(
one_step_,
(init_state, init_adaptation_state),
(jnp.arange(num_steps), keys, schedule),
(keys, schedule),
)
last_chain_state, last_warmup_state, *_ = last_state

Expand Down

0 comments on commit ae4ec10

Please sign in to comment.