Skip to content

Commit

Permalink
Retain reference particle properly (#1599)
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion authored May 2, 2021
1 parent 46c0589 commit 065c676
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 17 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.15.18"
version = "0.15.19"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
10 changes: 7 additions & 3 deletions src/inference/AdvancedSMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -281,17 +281,21 @@ function AbstractMCMC.step(
)
# Reset the VarInfo before new sweep.
reset_num_produce!(vi)
set_retained_vns_del_by_spl!(vi, spl)
resetlogp!(vi)

# Create reference particle for which the samples will be retained.
reference = AdvancedPS.forkr(AdvancedPS.Trace(model, spl, vi))

# For all other particles, do not retain the variables but resample them.
set_retained_vns_del_by_spl!(vi, spl)

# Create a new set of particles.
num_particles = spl.alg.nparticles
x = map(1:num_particles) do i
if i != num_particles
return AdvancedPS.Trace(model, spl, vi)
else
# Create reference particle.
return AdvancedPS.forkr(AdvancedPS.Trace(model, spl, vi))
return reference
end
end
particles = AdvancedPS.ParticleContainer(x)
Expand Down
7 changes: 7 additions & 0 deletions test/inference/AdvancedSMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,13 @@ end
@test all(isone, chains_pg[:x])
@test chains_pg.logevidence -2 * log(2) atol = 0.01
end

# https://github.com/TuringLang/Turing.jl/issues/1598
@turing_testset "reference particle" begin
c = sample(gdemo_default, PG(1), 1_000)
@test length(unique(c[:m])) == 1
@test length(unique(c[:s])) == 1
end
end

# @testset "pmmh.jl" begin
Expand Down
22 changes: 11 additions & 11 deletions test/inference/Inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@

alg1 = HMCDA(1000, 0.65, 0.15)
alg2 = PG(20)
alg3 = Gibbs(PG(30, :s), HMCDA(500, 0.65, 0.05, :m))
alg3 = Gibbs(PG(30, :s), HMC(0.2, 4, :m))

chn1 = sample(gdemo_default, alg1, 5000; save_state=true)
check_gdemo(chn1)
Expand All @@ -73,31 +73,31 @@
chn2_contd = sample(gdemo_default, alg2, 1000; resume_from=chn2)
check_gdemo(chn2_contd)

chn3 = sample(gdemo_default, alg3, 1000; save_state=true)
chn3 = sample(gdemo_default, alg3, 5000; save_state=true)
check_gdemo(chn3)

chn3_contd = sample(gdemo_default, alg3, 1000; resume_from=chn3)
check_gdemo(chn3_contd; atol=0.25)
check_gdemo(chn3_contd)
end
@testset "Contexts" begin
# Test LikelihoodContext
@model testmodel(x) = begin
@model function testmodel1(x)
a ~ Beta()
lp1 = getlogp(_varinfo)
x[1] ~ Bernoulli(a)
global loglike = getlogp(_varinfo) - lp1
lp1 = getlogp(__varinfo__)
x[1] ~ Bernoulli(a)
global loglike = getlogp(__varinfo__) - lp1
end
model = testmodel([1.0])
model = testmodel1([1.0])
varinfo = Turing.VarInfo(model)
model(varinfo, Turing.SampleFromPrior(), Turing.LikelihoodContext())
@test getlogp(varinfo) == loglike

# Test MiniBatchContext
@model testmodel(x) = begin
@model function testmodel2(x)
a ~ Beta()
x[1] ~ Bernoulli(a)
x[1] ~ Bernoulli(a)
end
model = testmodel([1.0])
model = testmodel2([1.0])
varinfo1 = Turing.VarInfo(model)
varinfo2 = deepcopy(varinfo1)
model(varinfo1, Turing.SampleFromPrior(), Turing.LikelihoodContext())
Expand Down
4 changes: 2 additions & 2 deletions test/inference/gibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@
alg = Gibbs(
CSMC(10, :s),
HMC(0.2, 4, :m))
chain = sample(gdemo(1.5, 2.0), alg, 3000)
check_numerical(chain, [:s, :m], [49/24, 7/6], atol=0.1)
chain = sample(gdemo(1.5, 2.0), alg, 1_500)
check_numerical(chain, [:s, :m], [49/24, 7/6], atol=0.15)

Random.seed!(100)

Expand Down

2 comments on commit 065c676

@devmotion
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/35824

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.15.19 -m "<description of version>" 065c67628133b1765c05138654d19b5d9df245ab
git push origin v0.15.19

Please sign in to comment.