From 053a2f700a6161672e53999ce1a4a4bd5c5e0cfe Mon Sep 17 00:00:00 2001 From: Mohamed Tarek Date: Sun, 13 Dec 2020 09:14:10 +1100 Subject: [PATCH] Twin PR to DPPL 191 (#1465) Co-authored-by: David Widmann --- Project.toml | 4 ++-- src/inference/Inference.jl | 30 ------------------------------ test/Project.toml | 2 +- 3 files changed, 3 insertions(+), 33 deletions(-) diff --git a/Project.toml b/Project.toml index 57ea7a300..6b4c29ca2 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Turing" uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" -version = "0.15.2" +version = "0.15.3" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" @@ -43,7 +43,7 @@ Bijectors = "0.8" Distributions = "0.23.3, 0.24" DistributionsAD = "0.6" DocStringExtensions = "0.8" -DynamicPPL = "0.10.0" +DynamicPPL = "0.10.2" EllipticalSliceSampling = "0.3" ForwardDiff = "0.10.3" Libtask = "0.4, 0.5" diff --git a/src/inference/Inference.jl b/src/inference/Inference.jl index 712f35c18..3ad81f79f 100644 --- a/src/inference/Inference.jl +++ b/src/inference/Inference.jl @@ -429,36 +429,6 @@ for alg in (:HMC, :HMCDA, :NUTS, :SGLD, :SGHMC) @eval DynamicPPL.getspace(::$alg{<:Any, space}) where {space} = space end -floatof(::Type{T}) where {T <: Real} = typeof(one(T)/one(T)) -floatof(::Type) = Real # fallback if type inference failed - -function get_matching_type( - spl::AbstractSampler, - vi, - ::Type{T}, -) where {T} - return T -end -function get_matching_type( - spl::AbstractSampler, - vi, - ::Type{<:Union{Missing, AbstractFloat}}, -) - return Union{Missing, floatof(eltype(vi, spl))} -end -function get_matching_type( - spl::AbstractSampler, - vi, - ::Type{<:AbstractFloat}, -) - return floatof(eltype(vi, spl)) -end -function get_matching_type(spl::AbstractSampler, vi, ::Type{<:Array{T,N}}) where {T,N} - return Array{get_matching_type(spl, vi, T), N} -end -function get_matching_type(spl::AbstractSampler, vi, ::Type{<:Array{T}}) where T - return Array{get_matching_type(spl, vi, T)} -end function get_matching_type( spl::Sampler{<:Union{PG, SMC}}, vi, diff --git a/test/Project.toml b/test/Project.toml index cbf225e71..6887aeac8 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -37,7 +37,7 @@ CmdStan = "6.0.8" Distributions = "0.23.8, 0.24" DistributionsAD = "0.6.3" DynamicHMC = "2.1.6" -DynamicPPL = "0.10.0" +DynamicPPL = "0.10.2" FiniteDifferences = "0.10.8, 0.11" ForwardDiff = "0.10.12" MCMCChains = "4.0.4"