Skip to content

Commit

Permalink
Use values_as_in_model to extract the parameters from a Transition (
Browse files Browse the repository at this point in the history
#2202)

* use `values_as_in_model` to extract the parameters from a `Transition`
rather than `invlink` + `values_as`

* bump BangBang compat entry

* added test from #2195 + added HypothesisTests.jl so we can compare
chains properly

* deepcopy varinfo before calling `values_as_in_model` to avoid mutating
the original logprob computations, etc.

* bump patch version

* fixed tests

---------

Co-authored-by: Hong Ge <[email protected]>
  • Loading branch information
torfjelde and yebai authored May 7, 2024
1 parent 8ced922 commit 56f64ec
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 6 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.31.3"
version = "0.31.4"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
14 changes: 9 additions & 5 deletions src/mcmc/Inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -317,11 +317,15 @@ Return a named tuple of parameters.
"""
getparams(model, t) = t.θ
function getparams(model::DynamicPPL.Model, vi::DynamicPPL.VarInfo)
# Want the end-user to receive parameters in constrained space, so we `link`.
vi = DynamicPPL.invlink(vi, model)

# Extract parameter values in a simple form from the `VarInfo`.
vals = DynamicPPL.values_as(vi, OrderedDict)
# NOTE: In the past, `invlink(vi, model)` + `values_as(vi, OrderedDict)` was used.
# Unfortunately, using `invlink` can cause issues in scenarios where the constraints
# of the parameters change depending on the realizations. Hence we have to use
# `values_as_in_model`, which re-runs the model and extracts the parameters
# as they are seen in the model, i.e. in the constrained space. Moreover,
# this means that the code below will work both of linked and invlinked `vi`.
# Ref: https://github.com/TuringLang/Turing.jl/issues/2195
# NOTE: We need to `deepcopy` here to avoid modifying the original `vi`.
vals = DynamicPPL.values_as_in_model(model, deepcopy(vi))

# Obtain an iterator over the flattened parameter names and values.
iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
Expand Down
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb"
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
Expand Down Expand Up @@ -43,6 +44,7 @@ DynamicHMC = "2.1.6, 3.0"
DynamicPPL = "0.25.1"
FiniteDifferences = "0.10.8, 0.11, 0.12"
ForwardDiff = "0.10.12 - 0.10.32, 0.10"
HypothesisTests = "0.11"
LogDensityProblems = "2"
LogDensityProblemsAD = "1.4"
MCMCChains = "5, 6"
Expand Down
42 changes: 42 additions & 0 deletions test/mcmc/hmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -261,4 +261,46 @@
@test mean(Array(chain)) 0.2
end
end

@turing_testset "issue: #2195" begin
@model function buggy_model()
lb ~ Uniform(0, 1)
ub ~ Uniform(1.5, 2)

# HACK: Necessary to avoid NUTS failing during adaptation.
try
x ~ transformed(Normal(0, 1), inverse(Bijectors.Logit(lb, ub)))
catch e
if e isa DomainError
Turing.@addlogprob! -Inf
return nothing
else
rethrow()
end
end
end

model = buggy_model();
num_samples = 1_000;

chain = sample(
model,
NUTS(),
num_samples;
initial_params=[0.5, 1.75, 1.0]
)
chain_prior = sample(model, Prior(), num_samples)

# Extract the `x` like this because running `generated_quantities` was how
# the issue was discovered, hence we also want to make sure that it works.
results = generated_quantities(model, chain)
results_prior = generated_quantities(model, chain_prior)

# Make sure none of the samples in the chains resulted in errors.
@test all(!isnothing, results)

# The discrepancies in the chains are in the tails, so we can't just compare the mean, etc.
# KS will compare the empirical CDFs, which seems like a reasonable thing to do here.
@test pvalue(ApproximateTwoSampleKSTest(vec(results), vec(results_prior))) > 0.01
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ using ReverseDiff
using SpecialFunctions
using StatsBase
using StatsFuns
using HypothesisTests
using Tracker
using Turing
using Turing.Inference
Expand Down

2 comments on commit 56f64ec

@torfjelde
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register

Release notes:

Features

  • It is now possible to use := to track deterministic quantities in the model.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/106363

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.31.4 -m "<description of version>" 56f64ec5909cec4a5ded4e28555c2b289020bbe1
git push origin v0.31.4

Please sign in to comment.