Skip to content

Commit

Permalink
Don't overload setparams\!\! with VarInfo
Browse files Browse the repository at this point in the history
  • Loading branch information
mhauru committed Nov 4, 2024
1 parent d52af52 commit 508ac61
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 48 deletions.
77 changes: 32 additions & 45 deletions src/mcmc/gibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -475,73 +475,58 @@ function DynamicPPL.setlogp!!(state::TuringState, logp)
return TuringState(setlogp!!(state.state, logp), logp)

Check warning on line 475 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L474-L475

Added lines #L474 - L475 were not covered by tests
end

# Some samplers use a VarInfo directly as the state. In that case, there's little to do in
# `setparams!!`.
function AbstractMCMC.setparams!!(state::VarInfo, params::AbstractVector)
return DynamicPPL.unflatten(state, params)
end

function AbstractMCMC.setparams!!(state::VarInfo, params::AbstractVarInfo)
return params
end
"""
setparams_varinfo!!(model, sampler::Sampler, state, params::AbstractVarInfo)
function AbstractMCMC.setparams!!(
model::DynamicPPL.Model,
state::TuringState,
params::Union{AbstractVector,AbstractVarInfo},
)
new_inner_state = AbstractMCMC.setparams!!(model, state.state, params)
return TuringState(new_inner_state, state.logdensity)
end
A lot like AbstractMCMC.setparams!!, but instead of taking a vector of parameters, takes an
`AbstractVarInfo` object. Also takes the `sampler` as an argument. By default, falls back to
`AbstractMCMC.setparams!!(model, state, params[:])`.
# Unless some other treatment has been specified for this state type, just flatten the
# AbstractVarInfo. This method exists because some sampler types need to override this
# behavior.
function AbstractMCMC.setparams!!(model::DynamicPPL.Model, state, params::AbstractVarInfo)
`model` is typically a `DynamicPPL.Model`, but can also be e.g. an
`AbstractMCMC.LogDensityModel`.
"""
function setparams_varinfo!!(model, ::Sampler, state, params::AbstractVarInfo)
return AbstractMCMC.setparams!!(model, state, params[:])

Check warning on line 489 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L488-L489

Added lines #L488 - L489 were not covered by tests
end

function AbstractMCMC.setparams!!(
model::DynamicPPL.Model, state::HMCState, params::AbstractVarInfo
# Some samplers use a VarInfo directly as the state. In that case, there's little to do in
# `setparams_varinfo!!`.
function setparams_varinfo!!(
model::DynamicPPL.Model, sampler::Sampler, state::VarInfo, params::AbstractVarInfo
)
θ_new = params[:]
hamiltonian = get_hamiltonian(model, state.sampler, params, state, length(θ_new))
return params
end

# Update the parameter values in `state.z`.
# TODO: Avoid mutation
z = state.z
resize!(z.θ, length(θ_new))
z.θ .= θ_new
return HMCState(
params, state.i, state.kernel, hamiltonian, z, state.adaptor, state.sampler
function setparams_varinfo!!(

Check warning on line 500 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L500

Added line #L500 was not covered by tests
model::DynamicPPL.Model, sampler::Sampler, state::TuringState, params::AbstractVarInfo
)
logdensity = DynamicPPL.setmodel(state.logdensity, model, sampler.alg.adtype)
new_inner_state = setparams_varinfo!!(

Check warning on line 504 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L503-L504

Added lines #L503 - L504 were not covered by tests
AbstractMCMC.LogDensityModel(logdensity), sampler, state.state, params
)
return TuringState(new_inner_state, logdensity)

Check warning on line 507 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L507

Added line #L507 was not covered by tests
end

function AbstractMCMC.setparams!!(
model::DynamicPPL.Model, state::HMCState, params::AbstractVector
function setparams_varinfo!!(

Check warning on line 510 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L510

Added line #L510 was not covered by tests
model::DynamicPPL.Model, sampler::Sampler, state::HMCState, params::AbstractVarInfo
)
θ_new = params
vi = DynamicPPL.unflatten(state.vi, params)
hamiltonian = get_hamiltonian(model, state.sampler, vi, state, length(θ_new))
θ_new = params[:]
hamiltonian = get_hamiltonian(model, sampler, params, state, length(θ_new))

Check warning on line 514 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L513-L514

Added lines #L513 - L514 were not covered by tests

# Update the parameter values in `state.z`.
# TODO: Avoid mutation
z = state.z
resize!(z.θ, length(θ_new))
z.θ .= θ_new
return HMCState(vi, state.i, state.kernel, hamiltonian, z, state.adaptor, state.sampler)
return HMCState(params, state.i, state.kernel, hamiltonian, z, state.adaptor)

Check warning on line 521 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L518-L521

Added lines #L518 - L521 were not covered by tests
end

function AbstractMCMC.setparams!!(
model::DynamicPPL.Model, state::PGState, params::AbstractVarInfo
function setparams_varinfo!!(
model::DynamicPPL.Model, sampler::Sampler, state::PGState, params::AbstractVarInfo
)
return PGState(params, state.rng)
end

function AbstractMCMC.setparams!!(state::PGState, params::AbstractVector)
return PGState(DynamicPPL.unflatten(state.vi, params), state.rng)
end

function gibbs_step_inner(
rng::Random.AbstractRNG,
model::DynamicPPL.Model,
Expand All @@ -566,7 +551,9 @@ function gibbs_step_inner(

# Set the state of the current sampler, accounting for any changes made by other
# samplers.
state_local = AbstractMCMC.setparams!!(model_local, state_local, varinfo_local)
state_local = setparams_varinfo!!(
model_local, sampler_local, state_local, varinfo_local
)
if gibbs_requires_recompute_logprob(
model_local, sampler_local, sampler_previous, state_local, state_previous
)
Expand Down
5 changes: 2 additions & 3 deletions src/mcmc/hmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ struct HMCState{
hamiltonian::THam
z::PhType
adaptor::TAdapt
sampler::Sampler{<:Hamiltonian}
end

###
Expand Down Expand Up @@ -230,7 +229,7 @@ function DynamicPPL.initialstep(
end

transition = Transition(model, vi, t)
state = HMCState(vi, 1, kernel, hamiltonian, t.z, adaptor, spl)
state = HMCState(vi, 1, kernel, hamiltonian, t.z, adaptor)

return transition, state
end
Expand Down Expand Up @@ -276,7 +275,7 @@ function AbstractMCMC.step(

# Compute next transition and state.
transition = Transition(model, vi, t)
newstate = HMCState(vi, i, kernel, hamiltonian, t.z, state.adaptor, spl)
newstate = HMCState(vi, i, kernel, hamiltonian, t.z, state.adaptor)

return transition, newstate
end
Expand Down

0 comments on commit 508ac61

Please sign in to comment.