diff --git a/Project.toml b/Project.toml
index 5bf0cc690..af13bce1d 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
 name = "Turing"
 uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
-version = "0.15.17"
+version = "0.15.18"
 
 [deps]
 AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
diff --git a/src/contrib/inference/sghmc.jl b/src/contrib/inference/sghmc.jl
index 52010209c..d026b3656 100644
--- a/src/contrib/inference/sghmc.jl
+++ b/src/contrib/inference/sghmc.jl
@@ -41,9 +41,9 @@ function SGHMC{AD}(
     return SGHMC{AD,space,typeof(_learning_rate)}(_learning_rate, _momentum_decay)
 end
 
-struct SGHMCState{V<:AbstractVarInfo}
+struct SGHMCState{V<:AbstractVarInfo, T<:AbstractVector{<:Real}}
     vi::V
-    velocity::Vector{Float64}
+    velocity::T
 end
 
 function DynamicPPL.initialstep(
@@ -84,7 +84,7 @@ function AbstractMCMC.step(
     θ .+= v
     η = spl.alg.learning_rate
     α = spl.alg.momentum_decay
-    newv = (1 - α) .* v .+ η .* grad .+ sqrt(2 * η * α) .* randn(rng, length(v))
+    newv = (1 - α) .* v .+ η .* grad .+ sqrt(2 * η * α) .* randn(rng, eltype(v), length(v))
 
     # Save new variables and recompute log density.
     vi[spl] = θ
@@ -229,7 +229,7 @@ function AbstractMCMC.step(
    _, grad = gradient_logp(θ, vi, model, spl)
     step = state.step
     stepsize = spl.alg.stepsize(step)
-    θ .+= (stepsize / 2) .* grad .+ sqrt(stepsize) .* randn(rng, length(θ))
+    θ .+= (stepsize / 2) .* grad .+ sqrt(stepsize) .* randn(rng, eltype(θ), length(θ))
 
     # Save new variables and recompute log density.
     vi[spl] = θ
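
The intent of the `eltype`-aware `randn` calls and the parametric `velocity` field appears to be type stability: the injected noise inherits the element type of the sampler state instead of always being `Float64`. A minimal sketch in plain Julia (nothing Turing-specific assumed; `v` is a hypothetical Float32 velocity vector) illustrating the difference:

    using Random

    rng = Random.default_rng()
    v = randn(Float32, 4)                        # hypothetical Float32 velocity vector

    noise_old = randn(rng, length(v))            # Vector{Float64}: forces promotion
    noise_new = randn(rng, eltype(v), length(v)) # Vector{Float32}: preserves the eltype

    @assert eltype(v .+ noise_old) == Float64
    @assert eltype(v .+ noise_new) == Float32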