diff --git a/Project.toml b/Project.toml index f35b1df4..41e11f26 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "AdvancedHMC" uuid = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d" -version = "0.6.3" +version = "0.6.4" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/src/AdvancedHMC.jl b/src/AdvancedHMC.jl index fcaab095..7d7773ed 100644 --- a/src/AdvancedHMC.jl +++ b/src/AdvancedHMC.jl @@ -92,13 +92,13 @@ MassMatrixAdaptor(m::DiagEuclideanMetric{T}) where {T} = MassMatrixAdaptor(m::DenseEuclideanMetric{T}) where {T} = WelfordCov{T}(size(m); cov = copy(m.M⁻¹)) -MassMatrixAdaptor(m::Type{TM}, sz::Tuple{Vararg{Int}} = (2,)) where {TM<:AbstractMetric} = - MassMatrixAdaptor(Float64, m, sz) +MassMatrixAdaptor(::Type{TM}, sz::Dims = (2,)) where {TM<:AbstractMetric} = + MassMatrixAdaptor(Float64, TM, sz) MassMatrixAdaptor( ::Type{T}, ::Type{TM}, - sz::Tuple{Vararg{Int}} = (2,), + sz::Dims = (2,), ) where {T,TM<:AbstractMetric} = MassMatrixAdaptor(TM(T, sz)) # Deprecations diff --git a/src/abstractmcmc.jl b/src/abstractmcmc.jl index 40585909..9403bffe 100644 --- a/src/abstractmcmc.jl +++ b/src/abstractmcmc.jl @@ -308,11 +308,12 @@ function make_initial_params( initial_params, ) T = sampler_eltype(spl) - if initial_params == nothing + if initial_params === nothing d = LogDensityProblems.dimension(logdensity) - initial_params = randn(rng, d) + return randn(rng, T, d) + else + return T.(initial_params) end - return T.(initial_params) end ######### @@ -342,10 +343,10 @@ end function make_step_size( rng::Random.AbstractRNG, integrator::AbstractIntegrator, - T::Type, + ::Type{T}, hamiltonian::Hamiltonian, initial_params, -) +) where {T} if integrator.ϵ > 0 ϵ = integrator.ϵ else @@ -358,10 +359,10 @@ end function make_step_size( rng::Random.AbstractRNG, integrator::Symbol, - T::Type, + ::Type{T}, hamiltonian::Hamiltonian, initial_params, -) +) where {T} ϵ = find_good_stepsize(rng, hamiltonian, initial_params) @info string("Found initial step size ", ϵ) return T(ϵ) @@ -370,21 +371,33 @@ end make_integrator(spl::HMCSampler, ϵ::Real) = spl.κ.τ.integrator make_integrator(spl::AbstractHMCSampler, ϵ::Real) = make_integrator(spl.integrator, ϵ) make_integrator(i::AbstractIntegrator, ϵ::Real) = i -make_integrator(i::Symbol, ϵ::Real) = make_integrator(Val(i), ϵ) -make_integrator(@nospecialize(i), ::Real) = error("Integrator $i not supported.") -make_integrator(i::Val{:leapfrog}, ϵ::Real) = Leapfrog(ϵ) -make_integrator(i::Val{:jitteredleapfrog}, ϵ::T) where {T<:Real} = - JitteredLeapfrog(ϵ, T(0.1ϵ)) -make_integrator(i::Val{:temperedleapfrog}, ϵ::T) where {T<:Real} = TemperedLeapfrog(ϵ, T(1)) +function make_integrator(i::Symbol, ϵ::Real) + float_ϵ = AbstractFloat(ϵ) + if i === :leapfrog + return Leapfrog(float_ϵ) + elseif i === :jitteredleapfrog + return JitteredLeapfrog(float_ϵ, float_ϵ / 10) + elseif i === :temperedleapfrog + return TemperedLeapfrog(float_ϵ, oneunit(float_ϵ)) + else + error("Integrator $i not supported.") + end +end ######### -make_metric(@nospecialize(i), T::Type, d::Int) = error("Metric $(typeof(i)) not supported.") -make_metric(i::Symbol, T::Type, d::Int) = make_metric(Val(i), T, d) -make_metric(i::AbstractMetric, T::Type, d::Int) = i -make_metric(i::Val{:diagonal}, T::Type, d::Int) = DiagEuclideanMetric(T, d) -make_metric(i::Val{:unit}, T::Type, d::Int) = UnitEuclideanMetric(T, d) -make_metric(i::Val{:dense}, T::Type, d::Int) = DenseEuclideanMetric(T, d) +make_metric(i::AbstractMetric, ::Type, ::Int) = i +function make_metric(i::Symbol, ::Type{T}, d::Int) where {T} + if i === :diagonal + return DiagEuclideanMetric(T, d) + elseif i === :unit + return UnitEuclideanMetric(T, d) + elseif i === :dense + return DenseEuclideanMetric(T, d) + else + error("Metric $i not supported.") + end +end function make_metric(spl::AbstractHMCSampler, logdensity) d = LogDensityProblems.dimension(logdensity) diff --git a/src/metric.jl b/src/metric.jl index 2afbd629..2af44e70 100644 --- a/src/metric.jl +++ b/src/metric.jl @@ -92,12 +92,6 @@ Base.size(e::DenseEuclideanMetric, dim...) = size(e._temp, dim...) Base.show(io::IO, dem::DenseEuclideanMetric) = print(io, "DenseEuclideanMetric(diag=$(_string_M⁻¹(dem.M⁻¹)))") -# getname functions -for T in (UnitEuclideanMetric, DiagEuclideanMetric, DenseEuclideanMetric) - @eval getname(::Type{<:$T}) = $T -end -getname(m::T) where {T<:AbstractMetric} = getname(T) - # `rand` functions for `metric` types. function _rand(