Skip to content

Commit

Permalink
Default float type to float(Real), not Real (TuringLang#685) (TuringL…
Browse files Browse the repository at this point in the history
…ang#686)

* Default float type to float(Real), not Real (TuringLang#685)

* Default float type to float(Real), not Real

Closes TuringLang#684

* Trigger CI on backport branches/PRs

* Add integration test for TuringLang#684

* Bump Turing version to 0.34 in test subfolder

* Bump minimum Julia version to 1.10

* Bump patch version

* Bump patch again
  • Loading branch information
penelopeysm authored Nov 7, 2024
1 parent 09e997b commit d6e2147
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 4 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@ on:
push:
branches:
- master
- backport-*
pull_request:
branches:
- master
- backport-*
merge_group:
types: [checks_requested]

Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ Requires = "1"
ReverseDiff = "1"
Test = "1.6"
ZygoteRules = "0.2"
julia = "~1.6.6, 1.7.3"
julia = "1.10"

[extras]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
6 changes: 3 additions & 3 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -941,10 +941,10 @@ end
"""
float_type_with_fallback(x)
Return type corresponding to `float(typeof(x))` if possible; otherwise return `Real`.
Return type corresponding to `float(typeof(x))` if possible; otherwise return `float(Real)`.
"""
float_type_with_fallback(::Type) = Real
float_type_with_fallback(::Type{Union{}}) = Real
float_type_with_fallback(::Type) = float(Real)
float_type_with_fallback(::Type{Union{}}) = float(Real)
float_type_with_fallback(::Type{T}) where {T<:Real} = float(T)

"""
Expand Down
15 changes: 15 additions & 0 deletions test/turing/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -342,4 +342,19 @@
model = state_space(y, length(t))
@test size(sample(model, NUTS(; adtype=AutoReverseDiff(true)), n), 1) == n
end

if Threads.nthreads() > 1
@testset "DynamicPPL#684: OrderedDict with multiple types when multithreaded" begin
@model function f(x)
ns ~ filldist(Normal(0, 2.0), 3)
m ~ Uniform(0, 1)
return x ~ Normal(m, 1)
end
model = f(1)
chain = sample(model, NUTS(), MCMCThreads(), 10, 2)
loglikelihood(model, chain)
logprior(model, chain)
logjoint(model, chain)
end
end
end

0 comments on commit d6e2147

Please sign in to comment.