Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Merged by Bors] - Perform invlinking in assume rather than implicitly in getindex #360

Closed
wants to merge 107 commits into from
Closed
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
107 commits
Select commit Hold shift + click to select a range
14594d6
performing linking in assume rather than implicitly in getindex
torfjelde Jan 7, 2022
0bc279f
added istrans to SimpleVarInfo
torfjelde Jan 7, 2022
81ee12e
Apply suggestions from code review
torfjelde Jan 7, 2022
d39f87d
added a comment
torfjelde Jan 7, 2022
23f34cc
bump patch version
torfjelde Jan 7, 2022
d3ec108
introduced settrans!!
torfjelde Jan 9, 2022
81782c9
added istrans(vi) and renamed all occurences of trans! to trans!!
torfjelde Jan 9, 2022
12bfb42
exclusively use settrans!! to set the istrans for SimpleVarInfo
torfjelde Jan 10, 2022
c2c2417
removed usage of deprecated method in turing tests
torfjelde Jan 10, 2022
2e2cb5c
added docstring to settrans!!
torfjelde Jan 13, 2022
f6c3fc4
include istrans flag in type of SimpleVarInfo instead
torfjelde Jan 27, 2022
d643a78
deprecated settrans! in favour of settrans!!
torfjelde Jan 27, 2022
3cab7d9
added some tests specifically for istrans
torfjelde Jan 27, 2022
8b870dc
formatting
torfjelde Jan 27, 2022
0b304db
fixed bugs for ThreadSafeVarInfo
torfjelde Jan 27, 2022
b146b11
additional constructor for SimpleVarInfo
torfjelde Jan 27, 2022
d170d92
Update src/DynamicPPL.jl
torfjelde Feb 9, 2022
f46183b
added ConstructionBase.jl as dep
torfjelde Feb 9, 2022
7a78eec
added constraint types and doctests
torfjelde Feb 9, 2022
70b3b70
added DocStringExtensions as a dep
torfjelde Feb 9, 2022
a03e8cf
formatting
torfjelde Feb 9, 2022
793c931
remove redundant maybe_link
torfjelde Feb 9, 2022
a9b12fd
fixed typo
torfjelde Feb 9, 2022
b1d7f9a
Merge branch 'master' into tor/link-improvements
yebai Feb 9, 2022
be98961
moved a docstring
torfjelde Feb 11, 2022
3e1588b
fixed bug in tets
torfjelde Feb 11, 2022
7697fce
Merge branch 'tor/link-improvements' of github.com:TuringLang/Dynamic…
torfjelde Feb 11, 2022
3139c62
version bump
torfjelde Feb 11, 2022
3610658
added missing istrans impl
torfjelde Feb 12, 2022
27171ad
fixed bug with istrans
torfjelde Feb 13, 2022
cd2d9d6
fixed issue with getindex_raw for VarInfo
torfjelde Feb 13, 2022
d948cb9
Update src/varinfo.jl
torfjelde Feb 13, 2022
6a3e18f
Merge branch 'master' into tor/link-improvements
torfjelde Jun 10, 2022
d674478
Merge branch 'master' into tor/link-improvements
torfjelde Jun 17, 2022
26d2dbb
getindex of varinfo implementations now optionally takes a Distributi…
torfjelde Jun 22, 2022
3fcba56
use get_index_raw with dist argument
torfjelde Jun 22, 2022
83a9448
added missing assume implementations for SimpleVarInfo
torfjelde Jun 22, 2022
356fa9c
fixed settrans!! for VarInfo
torfjelde Jun 22, 2022
13f037f
formatting
torfjelde Jun 24, 2022
c7544e0
fixed bug where constrained/unconstrained wasn't preserved in setinde…
torfjelde Jun 24, 2022
d1dccf1
hack to avoid type-instabilities for dot_assume with MultivariateDist…
torfjelde Jun 24, 2022
ff7ff4a
style
torfjelde Jun 26, 2022
2f1a2ff
added keys implementations for the models in TestUtils to make testin…
torfjelde Jun 26, 2022
d6311b7
added additional test model which uses dot-assume on MultivariateDist…
torfjelde Jun 26, 2022
ed2fa69
updated tests for SimpleVarInfo
torfjelde Jun 26, 2022
a82be56
added a no-op reconstruct for UnivariateDistribution
torfjelde Jun 26, 2022
7aacee5
fixed tests for loglikelihoods
torfjelde Jun 27, 2022
96f128f
fixed dot_tilde_assume for LikelihoodContext
torfjelde Jun 27, 2022
2e88d08
removed some now redundant explicit calls to maybe_invlink
torfjelde Jun 27, 2022
0f9765b
added impls of size and length for the wrapper distributions so they …
torfjelde Jun 27, 2022
116c95c
bumped version
torfjelde Jun 28, 2022
d797e99
removed redunant explict call to maybe_invlink
torfjelde Jun 28, 2022
44b2f66
added test model with array on RHS of a .~ statement
torfjelde Jun 29, 2022
81cd881
improved some of the default implementations of dot_assume
torfjelde Jun 29, 2022
2e14abd
removed unnecessary code in tests
torfjelde Jun 29, 2022
12adc83
improved linking usage in assumes for SimpleVarInfo
torfjelde Jun 29, 2022
af3e6ba
Merge branch 'master' into tor/link-improvements
yebai Jun 29, 2022
f7501df
added model for testing dynamic constraints
torfjelde Jun 30, 2022
abcabf4
added logjoint_true_with_logabsdet_jacobian to TestUtils
torfjelde Jun 30, 2022
fdee509
added test for dynamic constraints for SimpleVarInfo
torfjelde Jun 30, 2022
e974c83
fixed keys implementation of SimpleVarInfo
torfjelde Jun 30, 2022
6c6d5f5
reverted unintended change
torfjelde Jun 30, 2022
801bd4c
renamed Base.keys(model) to varnames(model) in TestUtils
torfjelde Jul 1, 2022
46f6f4c
added default implementation and docstring for TestUtils.varnames
torfjelde Jul 1, 2022
bcb767b
replace handwritten by DocStringExtensions
torfjelde Jul 1, 2022
c5be1c2
Apply suggestions from @devmotion
torfjelde Jul 1, 2022
f266929
Update src/context_implementations.jl
torfjelde Jul 1, 2022
c2dbbaf
removed some asserts and use broadcast instead of map
torfjelde Jul 1, 2022
1abb46c
replace map with broadcasting to ensure consistent behavior
torfjelde Jul 1, 2022
1086c6c
Update src/simple_varinfo.jl
torfjelde Jul 1, 2022
f2fb4a5
added a method nodist to allow broadcasting NoDist constructor
torfjelde Jul 1, 2022
490d24e
updated some tests
torfjelde Jul 1, 2022
6350ccd
renamed AbstractConstraint to AbstractTransformation and its subtypes
torfjelde Jul 1, 2022
951e4c3
updated tests
torfjelde Jul 1, 2022
dcd92c9
fixed nodist usage
torfjelde Jul 1, 2022
2922ffa
fixed implementation of nodist
torfjelde Jul 1, 2022
5266a4b
fixed typo
torfjelde Jul 1, 2022
3c38710
formatting
torfjelde Jul 1, 2022
ba92f3f
bump patch version
torfjelde Jul 1, 2022
70c864c
fixed ThreadsafeVarInfo
torfjelde Jul 1, 2022
66f41a9
Apply suggestions from code review
torfjelde Jul 1, 2022
eb2d6b5
allow type-stable settrans!! for SimpleVarInfo
torfjelde Jul 1, 2022
e8cdb91
use maybe_invlink in getindex for VarInfo
torfjelde Jul 1, 2022
359d384
added comment to warn about buggy behavior
torfjelde Jul 1, 2022
ab0a99b
Update src/context_implementations.jl
torfjelde Jul 1, 2022
dd10913
just fix potential bug in getindex for VarInfo
torfjelde Jul 1, 2022
18d28cc
revert previous change because it likely introduces bugs
torfjelde Jul 1, 2022
32b7aab
elaborate in comment regarding potential bug
torfjelde Jul 1, 2022
fb86231
Merge branch 'tor/link-improvements' of github.com:TuringLang/Dynamic…
torfjelde Jul 1, 2022
f782fe2
added error message to dot_assume
torfjelde Jul 1, 2022
7d3493d
added error message to dot_assume again
torfjelde Jul 1, 2022
f0f981b
added _protect_dists method to help with broadcasting of NoDist
torfjelde Jul 2, 2022
1e0b946
simplified show for SimpleVarInfo
torfjelde Jul 2, 2022
faa0e42
styling
torfjelde Jul 2, 2022
9e7f493
fixed bug in show for SimpleVarInfo
torfjelde Jul 3, 2022
0a9383b
Revert "added _protect_dists method to help with broadcasting of NoDist"
torfjelde Jul 3, 2022
d8b0a75
fixed getindex with vector of varnames for AbstractVarInfo
torfjelde Jul 3, 2022
400f90f
Improvements to TestUtils (follow-up from #360) (#415)
torfjelde Jul 20, 2022
9241acd
fixed tests for distribution_wrappers
torfjelde Jul 21, 2022
947e5c6
upper bound Distributions because tests are sooooo slow due to deprec…
torfjelde Jul 21, 2022
6f9be0d
Update bors.toml
yebai Jul 21, 2022
5c5b9ce
Revert "Update bors.toml"
torfjelde Jul 22, 2022
e0797cc
Revert "upper bound Distributions because tests are sooooo slow due t…
torfjelde Jul 22, 2022
4b0e0e1
switch of deprecation warnings from integration tests for now
torfjelde Jul 22, 2022
5a73c87
bump supported Julia version to 1.6
torfjelde Jul 23, 2022
bb43021
added ability to filter varnames to check in TestUtils.test_sampler
torfjelde Jul 23, 2022
d5a48f8
bump minor version
torfjelde Jul 23, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.17.3"
version = "0.17.4"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
75 changes: 49 additions & 26 deletions src/context_implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ end
function tilde_assume(context::PriorContext{<:NamedTuple}, right, vn, vi)
if haskey(context.vars, getsym(vn))
vi = setindex!!(vi, vectorize(right, get(context.vars, vn)), vn)
settrans!(vi, false, vn)
settrans!!(vi, false, vn)
end
return tilde_assume(PriorContext(), right, vn, vi)
end
Expand All @@ -64,15 +64,15 @@ function tilde_assume(
)
if haskey(context.vars, getsym(vn))
vi = setindex!!(vi, vectorize(right, get(context.vars, vn)), vn)
settrans!(vi, false, vn)
settrans!!(vi, false, vn)
end
return tilde_assume(rng, PriorContext(), sampler, right, vn, vi)
end

function tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, vn, vi)
if haskey(context.vars, getsym(vn))
vi = setindex!!(vi, vectorize(right, get(context.vars, vn)), vn)
settrans!(vi, false, vn)
settrans!!(vi, false, vn)
end
return tilde_assume(LikelihoodContext(), right, vn, vi)
end
Expand All @@ -86,7 +86,7 @@ function tilde_assume(
)
if haskey(context.vars, getsym(vn))
vi = setindex!!(vi, vectorize(right, get(context.vars, vn)), vn)
settrans!(vi, false, vn)
settrans!!(vi, false, vn)
end
return tilde_assume(rng, LikelihoodContext(), sampler, right, vn, vi)
end
Expand Down Expand Up @@ -194,7 +194,9 @@ end

# fallback without sampler
function assume(dist::Distribution, vn::VarName, vi)
r = vi[vn]
# x = vi[vn]
r_raw = getindex_raw(vi, vn)
r = maybe_invlink(vi, vn, dist, r_raw)
return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)), vi
Copy link
Member

@yebai yebai Feb 9, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it'll be nice if we can overload the transform used inside logpdf_with_trans in the future.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can actually replace the above with

logp = zero(eltype(vi))
if istrans(vi, vn)
    r, logjac = forward(bijector(dist))
    logp -= logjac
else
    r = r_raw
end

return r, logpdf(dist, r)

right now. But I think if we do this we should remove logpdf_with_trans everywhere, so I wanted to defer that to another PR.

Copy link
Member

@yebai yebai Feb 11, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@phipsgabler this is what I mentioned today. We can now replace logpdf_with_trans with ChangeOfVariables/LogDensityInterface API.

#342

end

Expand All @@ -211,16 +213,23 @@ function assume(
if sampler isa SampleFromUniform || is_flagged(vi, vn, "del")
unset_flag!(vi, vn, "del")
r = init(rng, dist, sampler)
vi[vn] = vectorize(dist, r)
settrans!(vi, false, vn)
vi[vn] = vectorize(dist, maybe_link(vi, vn, dist, r))
setorder!(vi, vn, get_num_produce(vi))
else
r = vi[vn]
# Otherwise we just extract it.
# r = vi[vn]
r_raw = getindex_raw(vi, vn)
r = maybe_invlink(vi, vn, dist, r_raw)
end
else
r = init(rng, dist, sampler)
push!!(vi, vn, r, dist, sampler)
settrans!(vi, false, vn)
if istrans(vi)
push!!(vi, vn, link(dist, r), dist, sampler)
# By default `push!!` sets the transformed flag to `false`.
settrans!!(vi, true, vn)
else
push!!(vi, vn, r, dist, sampler)
end
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
end

return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)), vi
Expand Down Expand Up @@ -286,7 +295,7 @@ function dot_tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, left,
var = get(context.vars, vn)
_right, _left, _vns = unwrap_right_left_vns(right, var, vn)
set_val!(vi, _vns, _right, _left)
settrans!.(Ref(vi), false, _vns)
settrans!!.(Ref(vi), false, _vns)
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
dot_tilde_assume(LikelihoodContext(), _right, _left, _vns, vi)
else
dot_tilde_assume(LikelihoodContext(), right, left, vn, vi)
Expand All @@ -305,7 +314,7 @@ function dot_tilde_assume(
var = get(context.vars, vn)
_right, _left, _vns = unwrap_right_left_vns(right, var, vn)
set_val!(vi, _vns, _right, _left)
settrans!.(Ref(vi), false, _vns)
settrans!!.(Ref(vi), false, _vns)
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
dot_tilde_assume(rng, LikelihoodContext(), sampler, _right, _left, _vns, vi)
else
dot_tilde_assume(rng, LikelihoodContext(), sampler, right, left, vn, vi)
Expand All @@ -326,7 +335,7 @@ function dot_tilde_assume(context::PriorContext{<:NamedTuple}, right, left, vn,
var = get(context.vars, vn)
_right, _left, _vns = unwrap_right_left_vns(right, var, vn)
set_val!(vi, _vns, _right, _left)
settrans!.(Ref(vi), false, _vns)
settrans!!.(Ref(vi), false, _vns)
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
dot_tilde_assume(PriorContext(), _right, _left, _vns, vi)
else
dot_tilde_assume(PriorContext(), right, left, vn, vi)
Expand All @@ -345,7 +354,7 @@ function dot_tilde_assume(
var = get(context.vars, vn)
_right, _left, _vns = unwrap_right_left_vns(right, var, vn)
set_val!(vi, _vns, _right, _left)
settrans!.(Ref(vi), false, _vns)
settrans!!.(Ref(vi), false, _vns)
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
dot_tilde_assume(rng, PriorContext(), sampler, _right, _left, _vns, vi)
else
dot_tilde_assume(rng, PriorContext(), sampler, right, left, vn, vi)
Expand Down Expand Up @@ -390,7 +399,9 @@ function dot_assume(
# m .~ Normal()
#
# in which case `var` will have `undef` elements, even if `m` is present in `vi`.
r = vi[vns]
# r = vi[vns]
r_raw = getindex_raw(vi, vns)
r = maybe_invlink(vi, vn, dist, r_raw)
lp = sum(zip(vns, eachcol(r))) do (vn, ri)
return Bijectors.logpdf_with_trans(dist, ri, istrans(vi, vn))
end
Expand Down Expand Up @@ -423,7 +434,8 @@ function dot_assume(
# m .~ Normal()
#
# in which case `var` will have `undef` elements, even if `m` is present in `vi`.
r = reshape(vi[vec(vns)], size(vns))
r_raw = getindex_raw(vi, vec(vns))
r = reshape(maybe_invlink.(Ref(vi), vns, dists, r_raw), size(vns))
lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans(vi, vns[1])))
return r, lp, vi
end
Expand Down Expand Up @@ -462,19 +474,24 @@ function get_and_set_val!(
r = init(rng, dist, spl, n)
for i in 1:n
vn = vns[i]
vi[vn] = vectorize(dist, r[:, i])
settrans!(vi, false, vn)
vi[vn] = vectorize(dist, maybe_link(vi, vn, dist, r[:, i]))
setorder!(vi, vn, get_num_produce(vi))
end
else
r = vi[vns]
r_raw = getindex_raw(vi, vns)
r = maybe_invlink(vi, vns, dist, r_raw)
end
else
r = init(rng, dist, spl, n)
for i in 1:n
vn = vns[i]
push!!(vi, vn, r[:, i], dist, spl)
settrans!(vi, false, vn)
if istrans(vi)
push!!(vi, vn, maybe_link(vi, vn, dist, r[:, i]), dist, spl)
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
# `push!!` sets the trans-flag to `false` by default.
setttrans!!(vi, true, vn)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems to be another confusing design choice but probably should not be addressed in this PR.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See my comment above.

else
push!!(vi, vn, r[:, i], dist, spl)
end
end
end
return r
Expand All @@ -496,12 +513,13 @@ function get_and_set_val!(
for i in eachindex(vns)
vn = vns[i]
dist = dists isa AbstractArray ? dists[i] : dists
vi[vn] = vectorize(dist, r[i])
settrans!(vi, false, vn)
vi[vn] = vectorize(dist, maybe_link(vi, vn, dist, r[i]))
setorder!(vi, vn, get_num_produce(vi))
end
else
r = reshape(vi[vec(vns)], size(vns))
# r = reshape(vi[vec(vns)], size(vns))
r_raw = getindex_raw(vi, vec(vns))
r = maybe_invlink.(Ref(vi), vns, dists, reshape(r_raw, size(vns)))
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
end
else
f = (vn, dist) -> init(rng, dist, spl)
Expand All @@ -511,8 +529,13 @@ function get_and_set_val!(
# 1. Figure out the broadcast size and use a `foreach`.
# 2. Define an anonymous function which returns `nothing`, which
# we then broadcast. This will allocate a vector of `nothing` though.
push!!.(Ref(vi), vns, r, dists, Ref(spl))
settrans!.(Ref(vi), false, vns)
if istrans(vi)
push!!.(Ref(vi), vns, link.(Ref(vi), vns, dists, r), dists, Ref(spl))
# `push!!` sets the trans-flag to `false` by default.
settrans!!.(Ref(vi), true, vns)
else
push!!.(Ref(vi), vns, r, dists, Ref(spl))
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
end
end
return r
end
Expand Down
103 changes: 48 additions & 55 deletions src/simple_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,20 @@ ERROR: type NamedTuple has no field b
struct SimpleVarInfo{NT,T} <: AbstractVarInfo
values::NT
logp::T
# TODO: Should we put this in the type instead?
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
istrans::Bool
end

SimpleVarInfo{T}(θ) where {T<:Real} = SimpleVarInfo{typeof(θ),T}(θ, zero(T))
SimpleVarInfo{T}(; kwargs...) where {T<:Real} = SimpleVarInfo{T}(NamedTuple(kwargs))
SimpleVarInfo(; kwargs...) = SimpleVarInfo{Float64}(NamedTuple(kwargs))
SimpleVarInfo(values, logp) = SimpleVarInfo(values, logp, false)
function SimpleVarInfo{T}(θ) where {T<:Real}
return SimpleVarInfo{typeof(θ),T}(θ, zero(T), false)
end
function SimpleVarInfo{T}(; kwargs...) where {T<:Real}
return SimpleVarInfo{T}(NamedTuple(kwargs))
end
function SimpleVarInfo(; kwargs...)
return SimpleVarInfo{Float64}(NamedTuple(kwargs))
end
SimpleVarInfo(θ) = SimpleVarInfo{Float64}(θ)

# Constructor from `Model`.
Expand All @@ -158,8 +167,8 @@ function BangBang.empty!!(vi::SimpleVarInfo)
end

getlogp(vi::SimpleVarInfo) = vi.logp
setlogp!!(vi::SimpleVarInfo, logp) = SimpleVarInfo(vi.values, logp)
acclogp!!(vi::SimpleVarInfo, logp) = SimpleVarInfo(vi.values, getlogp(vi) + logp)
setlogp!!(vi::SimpleVarInfo, logp) = Setfield.@set vi.logp = logp
acclogp!!(vi::SimpleVarInfo, logp) = Setfield.@set vi.logp = getlogp(vi) + logp
Comment on lines +235 to +236
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we gain something from not just defining ... = SimpleVarInfo(vi.values, logp, vi.transformation)? Seems slightly overkill here 🤷

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So I just went with @set because at some point I forgot to update the method to also carry the vi.transformation, which lead to silent conversion to "unconstrained" SimpleVarInfo. So I figured I might as well just have Setfield handle this for me going forward to avoid such silly mistakes 🤷 It should be zero overhead anyways, right? Or are you worried about compile-times?


"""
keys(vi::SimpleVarInfo)
Expand All @@ -179,7 +188,7 @@ function acclogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp)
end

function Base.show(io::IO, ::MIME"text/plain", svi::SimpleVarInfo)
return print(io, "SimpleVarInfo(", svi.values, ", ", svi.logp, ")")
return print(io, "SimpleVarInfo(", svi.values, ", ", svi.logp, ", ", svi.istrans, ")")
end

# `NamedTuple`
Expand Down Expand Up @@ -224,6 +233,11 @@ Base.getindex(vi::SimpleVarInfo, spl::SampleFromUniform) = vi.values
# TODO: Should we do better?
Base.getindex(vi::SimpleVarInfo, spl::Sampler) = vi.values

# Since we don't perform any transformations in `getindex` for `SimpleVarInfo`
# we simply call `getindex` in `getindex_raw`.
getindex_raw(vi::SimpleVarInfo, vn::VarName) = vi[vn]
getindex_raw(vi::SimpleVarInfo, vns::Vector{<:VarName}) = vi[vns]

Base.haskey(vi::SimpleVarInfo, vn::VarName) = _haskey(vi.values, vn)
function _haskey(nt::NamedTuple, vn::VarName)
# LHS: Ensure that `nt` indeed has the property we want.
Expand Down Expand Up @@ -337,58 +351,21 @@ function Base.eltype(
end

# Context implementations
function assume(dist::Distribution, vn::VarName, vi::SimpleOrThreadSafeSimple)
left = vi[vn]
return left, Distributions.loglikelihood(dist, left), vi
end

# NOTE: Evaluations, i.e. those without `rng` are shared with other
# implementations of `AbstractVarInfo`.
function assume(
rng::Random.AbstractRNG,
sampler::SampleFromPrior,
sampler::Union{SampleFromPrior,SampleFromUniform},
dist::Distribution,
vn::VarName,
vi::SimpleOrThreadSafeSimple,
)
value = init(rng, dist, sampler)
vi = BangBang.push!!(vi, vn, value, dist, sampler)
return value, Distributions.loglikelihood(dist, value), vi
end

function dot_assume(
dist::MultivariateDistribution,
var::AbstractMatrix,
vns::AbstractVector{<:VarName},
vi::SimpleOrThreadSafeSimple,
)
@assert length(dist) == size(var, 1)
# NOTE: We cannot work with `var` here because we might have a model of the form
#
# m = Vector{Float64}(undef, n)
# m .~ Normal()
#
# in which case `var` will have `undef` elements, even if `m` is present in `vi`.
value = vi[vns]
lp = sum(zip(vns, eachcol(value))) do (vn, val)
return Distributions.logpdf(dist, val)
end
return value, lp, vi
end

function dot_assume(
dists::Union{Distribution,AbstractArray{<:Distribution}},
var::AbstractArray,
vns::AbstractArray{<:VarName},
vi::SimpleOrThreadSafeSimple,
)
# NOTE: We cannot work with `var` here because we might have a model of the form
#
# m = Vector{Float64}(undef, n)
# m .~ Normal()
#
# in which case `var` will have `undef` elements, even if `m` is present in `vi`.
value = vi[vns]
lp = sum(Distributions.logpdf.(dists, value))
return value, lp, vi
# Transform if we're working in unconstrained space.
ist = istrans(vi, vn)
value_raw = ist ? Bijectors.link(dist, value) : value
vi = BangBang.push!!(vi, vn, value_raw, dist, sampler)
return value, Bijectors.logpdf_with_trans(dist, value, ist), vi
end

function dot_assume(
Expand All @@ -401,15 +378,31 @@ function dot_assume(
)
f = (vn, dist) -> init(rng, dist, spl)
value = f.(vns, dists)
vi = BangBang.setindex!!(vi, value, vns)
lp = sum(Distributions.logpdf.(dists, value))

# Transform if we're working in transformed space.
ist = istrans(vi, first(vns))
value_raw = ist ? link.(dist, value) : value

# Update `vi`
vi = BangBang.setindex!!(vi, value_raw, vns)

# Compute logp.
lp = sum(Bijectors.logpdf_with_trans.(dists, value, ist))
return value, lp, vi
end

# HACK: Allows us to re-use the implementation of `dot_tilde`, etc. for literals.
increment_num_produce!(::SimpleOrThreadSafeSimple) = nothing
settrans!(vi::SimpleOrThreadSafeSimple, trans::Bool, vn::VarName) = nothing
istrans(::SimpleVarInfo, vn::VarName) = false

# NOTE: We don't implement `settrans!!(vi, trans, vn)`.
settrans!!(vi::SimpleVarInfo, trans::Bool) = Setfield.@set vi.istrans = trans
function settrans!!(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, trans::Bool)
return Setfield.@set vi.varinfo = settrans!!(vi, trans)
end

istrans(vi::SimpleVarInfo) = vi.istrans
istrans(vi::SimpleVarInfo, vn::VarName) = istrans(vi)
istrans(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, vn::VarName) = istrans(vi.varinfo, vn)

"""
values_as(varinfo[, Type])
Expand Down
Loading