Libtask Integration (#1766)
* New Turing-libtask integration  (#1757)

* Update Project.toml

* Update Project.toml

* Update Project.toml

* trace down into functions calling produce

* trace into functions in testcases

* update to the latest version

* run tests against new libtask

* temporarily disable 1.3 for testing

* Update AdvancedSMC.jl

* Update AdvancedSMC.jl

* Update AdvancedSMC.jl

* Update AdvancedSMC.jl

* Update AdvancedSMC.jl

* copy Trace on tape

* Implement simplified evaluator for TracedModel

* Remove some unnecessary trace functions.

* Minor bugfix in TracedModel evaluator.

* Update .github/workflows/TuringCI.yml

* Minor bugfix in TracedModel evaluator.

* Update container.jl

* Update Project.toml

* Commented out tests related to control flow.  TuringLang/Libtask.jl/issues/96

* Commented out tests related to control flow. TuringLang/Libtask.jl/issues/96

* Update Project.toml

* Update src/essential/container.jl

* Update AdvancedSMC.jl

Co-authored-by: KDr2 <[email protected]>

* CompatHelper: add new compat entry for Libtask at version 0.6 for package test, (keep existing compat) (#1765)

Co-authored-by: CompatHelper Julia <[email protected]>

* Fix for HMCs `dot_assume` (#1758)

* fixed dot_assume for hmc

* copy-pasted tests from dynamicppl integration tests

* inspecting what in the world is going on with tests

* trying again

* skip failing test for TrackerAD

* bump patch version

* fixed typo in tests

* Rename `Turing.Core` to `Turing.Essential`

* Deprecate Turing.Core

Co-authored-by: Tor Erlend Fjelde <[email protected]>

* fixed a numerical test

* version bump

Co-authored-by: David Widmann <[email protected]>
Co-authored-by: David Widmann <[email protected]>

* Minor fixes.

* Minor fixes.

* Minor fix.

* Update Julia version in CI

* Merge branch 'libtask-integration' of github.com:TuringLang/Turing.jl into libtask-integration

* Update Inference.jl

* Minor fixes.

* Add back `imm` test.

* Minor tweaks to make single distribution tests more robust.

* Update Project.toml

Co-authored-by: David Widmann <[email protected]>

* Apply suggestions from code review

Co-authored-by: David Widmann <[email protected]>

* Update Project.toml

* Switch to StableRNGs for broken tests.

* Apply suggestions from code review

Co-authored-by: David Widmann <[email protected]>

* Minor tweaks.

* Use StableRNG for GMM test.

* Update Project.toml

Co-authored-by: KDr2 <[email protected]>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: CompatHelper Julia <[email protected]>
Co-authored-by: David Widmann <[email protected]>
Co-authored-by: David Widmann <[email protected]>
6 people authored Jan 31, 2022
1 parent 125b11c commit fca2fb0
Showing 16 changed files with 126 additions and 78 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/Documentation.yml
@@ -15,7 +15,7 @@ jobs:
- name: Set up Julia
uses: julia-actions/setup-julia@v1
with:
version: '1.6'
version: '1'
- name: Set up Ruby 2.6
uses: actions/setup-ruby@v1
with:
2 changes: 1 addition & 1 deletion .github/workflows/DynamicHMC.yml
@@ -12,8 +12,8 @@ jobs:
strategy:
matrix:
version:
- '1.3'
- '1.6'
- '1'
os:
- ubuntu-latest
arch:
2 changes: 1 addition & 1 deletion .github/workflows/Numerical.yml
@@ -12,8 +12,8 @@ jobs:
strategy:
matrix:
version:
- '1.3'
- '1.6'
- '1'
os:
- ubuntu-latest
arch:
2 changes: 1 addition & 1 deletion .github/workflows/StanCI.yml
@@ -12,8 +12,8 @@ jobs:
strategy:
matrix:
version:
- '1.3'
- '1.6'
- '1'
os:
- ubuntu-latest
arch:
2 changes: 1 addition & 1 deletion .github/workflows/TuringCI.yml
@@ -13,8 +13,8 @@ jobs:
strategy:
matrix:
version:
- '1.3'
- '1.6'
- '1'
os:
- ubuntu-latest
arch:
8 changes: 4 additions & 4 deletions Project.toml
@@ -1,6 +1,6 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.19.5"
version = "0.20"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
@@ -37,7 +37,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
AbstractMCMC = "3.2"
AdvancedHMC = "0.3.0"
AdvancedMH = "0.6"
AdvancedPS = "0.2.4"
AdvancedPS = "0.3.3"
AdvancedVI = "0.1"
BangBang = "0.3"
Bijectors = "0.8, 0.9, 0.10"
@@ -48,7 +48,7 @@ DocStringExtensions = "0.8"
DynamicPPL = "0.17.2"
EllipticalSliceSampling = "0.4"
ForwardDiff = "0.10.3"
Libtask = "0.4, 0.5.3"
Libtask = "0.6.6"
MCMCChains = "5"
NamedArrays = "0.9"
Reexport = "0.2, 1"
@@ -59,4 +59,4 @@ StatsBase = "0.32, 0.33"
StatsFuns = "0.8, 0.9"
Tracker = "0.2.3"
ZygoteRules = "0.2"
julia = "1.3, 1.4, 1.5, 1.6"
julia = "1.6"
37 changes: 30 additions & 7 deletions src/essential/container.jl
@@ -1,19 +1,43 @@
struct TracedModel{S<:AbstractSampler,V<:AbstractVarInfo,M<:Model}
struct TracedModel{S<:AbstractSampler,V<:AbstractVarInfo,M<:Model,E<:Tuple}
model::M
sampler::S
varinfo::V
evaluator::E
end

# needed?
function TracedModel{SampleFromPrior}(
function TracedModel(
model::Model,
sampler::AbstractSampler,
varinfo::AbstractVarInfo,
)
return TracedModel(model, SampleFromPrior(), varinfo)
)
# evaluate!!(m.model, varinfo, SamplingContext(Random.AbstractRNG, m.sampler, DefaultContext()))
context = SamplingContext(DynamicPPL.Random.GLOBAL_RNG, sampler, DefaultContext())
evaluator = _get_evaluator(model, varinfo, context)
return TracedModel{AbstractSampler,AbstractVarInfo,Model,Tuple}(model, sampler, varinfo, evaluator)
end

(f::TracedModel)() = f.model(f.varinfo, f.sampler)
# Similar to `evaluate!!` except that we return the evaluator signature without executing it.
# TODO: maybe move to DynamicPPL
@generated function _get_evaluator(
model::Model{_F,argnames}, varinfo, context
) where {_F,argnames}
unwrap_args = [
:($DynamicPPL.matchingvalue(context_new, varinfo, model.args.$var)) for var in argnames
]
# We want to give `context` precedence over `model.context` while also
# preserving the leaf context of `context`. We can do this by
# 1. Set the leaf context of `model.context` to `leafcontext(context)`.
# 2. Set leaf context of `context` to the context resulting from (1).
# The result is:
# `context` -> `childcontext(context)` -> ... -> `model.context`
# -> `childcontext(model.context)` -> ... -> `leafcontext(context)`
return quote
context_new = DynamicPPL.setleafcontext(
context, DynamicPPL.setleafcontext(model.context, DynamicPPL.leafcontext(context))
)
(model.f, model, DynamicPPL.resetlogp!!(varinfo), context_new, $(unwrap_args...))
end
end
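The stitching that the comment inside `_get_evaluator` describes is easier to picture with concrete contexts. The lines below are an illustration only, not part of this diff; the contexts are hypothetical stand-ins built from DynamicPPL's constructors:

using Random
using DynamicPPL

# Stand-in for `model.context`: one wrapper around a DefaultContext leaf.
model_context = DynamicPPL.PrefixContext{:a}(DynamicPPL.DefaultContext())
# Stand-in for `context`: a SamplingContext over a DefaultContext leaf.
context = DynamicPPL.SamplingContext(Random.GLOBAL_RNG, DynamicPPL.SampleFromPrior(), DynamicPPL.DefaultContext())

# Step 1: re-root `model.context` on the leaf of `context`.
inner = DynamicPPL.setleafcontext(model_context, DynamicPPL.leafcontext(context))
# Step 2: hang the result below `context`.
context_new = DynamicPPL.setleafcontext(context, inner)
# context_new now chains: SamplingContext -> PrefixContext{:a} -> DefaultContext()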

function Base.copy(trace::AdvancedPS.Trace{<:TracedModel})
f = trace.f
@@ -46,4 +70,3 @@ function AdvancedPS.reset_logprob!(f::TracedModel)
DynamicPPL.resetlogp!!(f.varinfo)
return
end
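To see how the new pieces fit together, here is a minimal construction sketch. It is not code from this commit: the toy model is ours, and the module path `Turing.Essential.TracedModel` is assumed from this file's location under `src/essential/`; only the `TracedModel(model, sampler, varinfo)` signature and the shape of the evaluator tuple come from the diff above.

using Turing, DynamicPPL

# Hypothetical toy model for illustration.
@model function demo(x)
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
    x ~ Normal(m, sqrt(s))
end

model = demo(1.5)
varinfo = DynamicPPL.VarInfo(model)
tmodel = Turing.Essential.TracedModel(model, DynamicPPL.SampleFromPrior(), varinfo)

# `tmodel.evaluator` holds the call signature returned by `_get_evaluator`,
# i.e. (model.f, model, resetlogp!!(varinfo), context_new, model arguments...),
# which can be invoked later without going through `evaluate!!`.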

14 changes: 12 additions & 2 deletions src/inference/AdvancedSMC.jl
@@ -322,9 +322,19 @@ function DynamicPPL.assume(
spl::Sampler{<:Union{PG,SMC}},
dist::Distribution,
vn::VarName,
::Any
__vi__::AbstractVarInfo
)
vi = AdvancedPS.current_trace().f.varinfo
local vi
try
vi = AdvancedPS.current_trace().f.varinfo
catch e
# NOTE: this heuristic allows Libtask to evaluate a model outside a `Trace`.
if e == KeyError(:__trace) || current_task().storage isa Nothing
vi = __vi__
else
rethrow(e)
end
end
if inspace(vn, spl)
if ~haskey(vi, vn)
r = rand(rng, dist)
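The try/catch added to `assume` is a small heuristic; restated as a stand-alone helper it reads as below. This is a sketch only: the name `particle_varinfo_or` is invented for illustration, and only the body mirrors the diff above.

using AdvancedPS

# Prefer the VarInfo carried by the current AdvancedPS trace (i.e. when the
# model is executing inside a particle). If Libtask is evaluating the model
# outside a `Trace`, the task-local `:__trace` key is missing, so fall back to
# the VarInfo that was passed to `assume`.
function particle_varinfo_or(fallback)
    try
        return AdvancedPS.current_trace().f.varinfo
    catch e
        if e == KeyError(:__trace) || current_task().storage isa Nothing
            return fallback
        end
        rethrow(e)
    end
end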
4 changes: 3 additions & 1 deletion test/Project.toml
@@ -27,11 +27,12 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"

[compat]
AbstractMCMC = "3.2.1"
AdvancedMH = "0.6"
AdvancedPS = "0.2"
AdvancedPS = "0.3"
AdvancedVI = "0.1"
Clustering = "0.14"
CmdStan = "6.0.8"
@@ -53,4 +54,5 @@ StatsBase = "0.33"
StatsFuns = "0.9.5"
Tracker = "0.2.11"
Zygote = "0.5.4, 0.6"
StableRNGs = "1"
julia = "1.3"
18 changes: 9 additions & 9 deletions test/contrib/inference/sghmc.jl
@@ -1,11 +1,4 @@
@testset "sghmc.jl" begin
@numerical_testset "sghmc inference" begin
Random.seed!(125)

alg = SGHMC(; learning_rate=0.02, momentum_decay=0.5)
chain = sample(gdemo_default, alg, 10_000)
check_gdemo(chain, atol = 0.1)
end
@turing_testset "sghmc constructor" begin
alg = SGHMC(; learning_rate=0.01, momentum_decay=0.1)
@test alg isa SGHMC
@@ -22,6 +15,13 @@
sampler = Turing.Sampler(alg)
@test sampler isa Turing.Sampler{<:SGHMC}
end
@numerical_testset "sghmc inference" begin
rng = StableRNG(123)

alg = SGHMC(; learning_rate=0.02, momentum_decay=0.5)
chain = sample(rng, gdemo_default, alg, 10_000)
check_gdemo(chain, atol = 0.1)
end
end

@testset "sgld.jl" begin
@@ -42,9 +42,9 @@ end
@test sampler isa Turing.Sampler{<:SGLD}
end
@numerical_testset "sgld inference" begin
Random.seed!(125)
rng = StableRNG(1)

chain = sample(gdemo_default, SGLD(; stepsize = PolynomialStepsize(0.5)), 10_000)
chain = sample(rng, gdemo_default, SGLD(; stepsize = PolynomialStepsize(0.5)), 20_000)
check_gdemo(chain, atol = 0.2)

# Weight samples by step sizes (cf section 4.2 in the paper by Welling and Teh)
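These test edits follow one pattern: seed an explicit StableRNG and pass it as the first argument to `sample`, so the draws are reproducible across Julia releases. A self-contained sketch of that pattern (the coin-flip model is a placeholder, not from this commit):

using StableRNGs
using Turing

@model function coinflip(y)
    p ~ Beta(1, 1)
    y .~ Bernoulli(p)
end

rng = StableRNG(123)
chain = sample(rng, coinflip([true, true, false, true]), NUTS(), 1_000)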
21 changes: 11 additions & 10 deletions test/inference/gibbs_conditional.jl
@@ -1,5 +1,5 @@
@turing_testset "gibbs conditionals" begin
Random.seed!(100)
@turing_testset "gibbs conditionals.jl" begin
Random.seed!(1000); rng = StableRNG(123)

@turing_testset "gdemo" begin
# We consider the model
@@ -40,7 +40,7 @@
GibbsConditional(:m, cond_m),
GibbsConditional(:s, _ -> Normal(s_posterior_mean, 0)),
)
chain = sample(gdemo_default, sampler1, 10_000)
chain = sample(rng, gdemo_default, sampler1, 10_000)
cond_m_mean = mean(cond_m((s = s_posterior_mean,)))
check_numerical(chain, [:m, :s], [cond_m_mean, s_posterior_mean])
@test all(==(s_posterior_mean), chain[:s][2:end])
@@ -50,18 +50,19 @@
GibbsConditional(:m, _ -> Normal(m_posterior_mean, 0)),
GibbsConditional(:s, cond_s),
)
chain = sample(gdemo_default, sampler2, 10_000)
chain = sample(rng, gdemo_default, sampler2, 10_000)
cond_s_mean = mean(cond_s((m = m_posterior_mean,)))
check_numerical(chain, [:m, :s], [m_posterior_mean, cond_s_mean])
@test all(==(m_posterior_mean), chain[:m][2:end])

# and one for both using the conditional
sampler3 = Gibbs(GibbsConditional(:m, cond_m), GibbsConditional(:s, cond_s))
chain = sample(gdemo_default, sampler3, 10_000)
chain = sample(rng, gdemo_default, sampler3, 10_000)
check_gdemo(chain)
end

@turing_testset "GMM" begin
Random.seed!(1000); rng = StableRNG(123)
# We consider the model
# ```math
# μₖ ~ Normal(m, σ_μ), k = 1, …, K,
@@ -77,9 +78,9 @@
N = 20 # number of observations

# We generate data
μ_data = rand(Normal(m, sqrt(σ²_μ)), K)
z_data = rand(Categorical(π), N)
x_data = rand(MvNormal(μ_data[z_data], σ²_x * I))
μ_data = rand(rng, Normal(m, sqrt(σ²_μ)), K)
z_data = rand(rng, Categorical(π), N)
x_data = rand(rng, MvNormal(μ_data[z_data], σ²_x * I))

@model function mixture(x)
μ ~ $(MvNormal(fill(m, K), σ²_μ * I))
@@ -132,14 +133,14 @@
sampler2 = Gibbs(GibbsConditional(:z, cond_z), MH())
sampler3 = Gibbs(GibbsConditional(:z, cond_z), HMC(0.01, 7, ))
for sampler in (sampler1, sampler2, sampler3)
chain = sample(model, sampler, 10_000)
chain = sample(rng, model, sampler, 10_000)

μ_hat = estimate(chain, )
lμ_hat, uμ_hat = extrema(μ_hat)
@test isapprox([lμ_data, uμ_data], [lμ_hat, uμ_hat], atol=0.1)

z_hat = estimatez(chain, :z, 1:2)
ari, _, _, _ = randindex(z_data, Int.(z_hat))
ari, _, _, _ = Clustering.randindex(z_data, Int.(z_hat))
@test isapprox(ari, 1, atol=0.1)
end
end