-
Notifications
You must be signed in to change notification settings - Fork 32
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Merged by Bors] - Perform invlinking in assume rather than implicitly in getindex #360
Changes from 9 commits
14594d6
0bc279f
81ee12e
d39f87d
23f34cc
d3ec108
81782c9
12bfb42
c2c2417
2e2cb5c
f6c3fc4
d643a78
3cab7d9
8b870dc
0b304db
b146b11
d170d92
f46183b
7a78eec
70b3b70
a03e8cf
793c931
a9b12fd
b1d7f9a
be98961
3e1588b
7697fce
3139c62
3610658
27171ad
cd2d9d6
d948cb9
6a3e18f
d674478
26d2dbb
3fcba56
83a9448
356fa9c
13f037f
c7544e0
d1dccf1
ff7ff4a
2f1a2ff
d6311b7
ed2fa69
a82be56
7aacee5
96f128f
2e88d08
0f9765b
116c95c
d797e99
44b2f66
81cd881
2e14abd
12adc83
af3e6ba
f7501df
abcabf4
fdee509
e974c83
6c6d5f5
801bd4c
46f6f4c
bcb767b
c5be1c2
f266929
c2dbbaf
1abb46c
1086c6c
f2fb4a5
490d24e
6350ccd
951e4c3
dcd92c9
2922ffa
5266a4b
3c38710
ba92f3f
70c864c
66f41a9
eb2d6b5
e8cdb91
359d384
ab0a99b
dd10913
18d28cc
32b7aab
fb86231
f782fe2
7d3493d
f0f981b
1e0b946
faa0e42
9e7f493
0a9383b
d8b0a75
400f90f
9241acd
947e5c6
6f9be0d
5c5b9ce
e0797cc
4b0e0e1
5a73c87
bb43021
d5a48f8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -55,7 +55,7 @@ end | |
function tilde_assume(context::PriorContext{<:NamedTuple}, right, vn, vi) | ||
if haskey(context.vars, getsym(vn)) | ||
vi = setindex!!(vi, vectorize(right, get(context.vars, vn)), vn) | ||
settrans!(vi, false, vn) | ||
settrans!!(vi, false, vn) | ||
end | ||
return tilde_assume(PriorContext(), right, vn, vi) | ||
end | ||
|
@@ -64,15 +64,15 @@ function tilde_assume( | |
) | ||
if haskey(context.vars, getsym(vn)) | ||
vi = setindex!!(vi, vectorize(right, get(context.vars, vn)), vn) | ||
settrans!(vi, false, vn) | ||
settrans!!(vi, false, vn) | ||
end | ||
return tilde_assume(rng, PriorContext(), sampler, right, vn, vi) | ||
end | ||
|
||
function tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, vn, vi) | ||
if haskey(context.vars, getsym(vn)) | ||
vi = setindex!!(vi, vectorize(right, get(context.vars, vn)), vn) | ||
settrans!(vi, false, vn) | ||
settrans!!(vi, false, vn) | ||
end | ||
return tilde_assume(LikelihoodContext(), right, vn, vi) | ||
end | ||
|
@@ -86,7 +86,7 @@ function tilde_assume( | |
) | ||
if haskey(context.vars, getsym(vn)) | ||
vi = setindex!!(vi, vectorize(right, get(context.vars, vn)), vn) | ||
settrans!(vi, false, vn) | ||
settrans!!(vi, false, vn) | ||
end | ||
return tilde_assume(rng, LikelihoodContext(), sampler, right, vn, vi) | ||
end | ||
|
@@ -194,7 +194,9 @@ end | |
|
||
# fallback without sampler | ||
function assume(dist::Distribution, vn::VarName, vi) | ||
r = vi[vn] | ||
# x = vi[vn] | ||
r_raw = getindex_raw(vi, vn) | ||
r = maybe_invlink(vi, vn, dist, r_raw) | ||
return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)), vi | ||
end | ||
|
||
|
@@ -211,16 +213,23 @@ function assume( | |
if sampler isa SampleFromUniform || is_flagged(vi, vn, "del") | ||
unset_flag!(vi, vn, "del") | ||
r = init(rng, dist, sampler) | ||
vi[vn] = vectorize(dist, r) | ||
settrans!(vi, false, vn) | ||
vi[vn] = vectorize(dist, maybe_link(vi, vn, dist, r)) | ||
setorder!(vi, vn, get_num_produce(vi)) | ||
else | ||
r = vi[vn] | ||
# Otherwise we just extract it. | ||
# r = vi[vn] | ||
r_raw = getindex_raw(vi, vn) | ||
r = maybe_invlink(vi, vn, dist, r_raw) | ||
end | ||
else | ||
r = init(rng, dist, sampler) | ||
push!!(vi, vn, r, dist, sampler) | ||
settrans!(vi, false, vn) | ||
if istrans(vi) | ||
push!!(vi, vn, link(dist, r), dist, sampler) | ||
# By default `push!!` sets the transformed flag to `false`. | ||
settrans!!(vi, true, vn) | ||
else | ||
push!!(vi, vn, r, dist, sampler) | ||
end | ||
torfjelde marked this conversation as resolved.
Show resolved
Hide resolved
|
||
end | ||
|
||
return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)), vi | ||
|
@@ -286,7 +295,7 @@ function dot_tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, left, | |
var = get(context.vars, vn) | ||
_right, _left, _vns = unwrap_right_left_vns(right, var, vn) | ||
set_val!(vi, _vns, _right, _left) | ||
settrans!.(Ref(vi), false, _vns) | ||
settrans!!.(Ref(vi), false, _vns) | ||
torfjelde marked this conversation as resolved.
Show resolved
Hide resolved
|
||
dot_tilde_assume(LikelihoodContext(), _right, _left, _vns, vi) | ||
else | ||
dot_tilde_assume(LikelihoodContext(), right, left, vn, vi) | ||
|
@@ -305,7 +314,7 @@ function dot_tilde_assume( | |
var = get(context.vars, vn) | ||
_right, _left, _vns = unwrap_right_left_vns(right, var, vn) | ||
set_val!(vi, _vns, _right, _left) | ||
settrans!.(Ref(vi), false, _vns) | ||
settrans!!.(Ref(vi), false, _vns) | ||
torfjelde marked this conversation as resolved.
Show resolved
Hide resolved
|
||
dot_tilde_assume(rng, LikelihoodContext(), sampler, _right, _left, _vns, vi) | ||
else | ||
dot_tilde_assume(rng, LikelihoodContext(), sampler, right, left, vn, vi) | ||
|
@@ -326,7 +335,7 @@ function dot_tilde_assume(context::PriorContext{<:NamedTuple}, right, left, vn, | |
var = get(context.vars, vn) | ||
_right, _left, _vns = unwrap_right_left_vns(right, var, vn) | ||
set_val!(vi, _vns, _right, _left) | ||
settrans!.(Ref(vi), false, _vns) | ||
settrans!!.(Ref(vi), false, _vns) | ||
torfjelde marked this conversation as resolved.
Show resolved
Hide resolved
|
||
dot_tilde_assume(PriorContext(), _right, _left, _vns, vi) | ||
else | ||
dot_tilde_assume(PriorContext(), right, left, vn, vi) | ||
|
@@ -345,7 +354,7 @@ function dot_tilde_assume( | |
var = get(context.vars, vn) | ||
_right, _left, _vns = unwrap_right_left_vns(right, var, vn) | ||
set_val!(vi, _vns, _right, _left) | ||
settrans!.(Ref(vi), false, _vns) | ||
settrans!!.(Ref(vi), false, _vns) | ||
torfjelde marked this conversation as resolved.
Show resolved
Hide resolved
|
||
dot_tilde_assume(rng, PriorContext(), sampler, _right, _left, _vns, vi) | ||
else | ||
dot_tilde_assume(rng, PriorContext(), sampler, right, left, vn, vi) | ||
|
@@ -390,7 +399,9 @@ function dot_assume( | |
# m .~ Normal() | ||
# | ||
# in which case `var` will have `undef` elements, even if `m` is present in `vi`. | ||
r = vi[vns] | ||
# r = vi[vns] | ||
r_raw = getindex_raw(vi, vns) | ||
r = maybe_invlink(vi, vn, dist, r_raw) | ||
lp = sum(zip(vns, eachcol(r))) do (vn, ri) | ||
return Bijectors.logpdf_with_trans(dist, ri, istrans(vi, vn)) | ||
end | ||
|
@@ -423,7 +434,8 @@ function dot_assume( | |
# m .~ Normal() | ||
# | ||
# in which case `var` will have `undef` elements, even if `m` is present in `vi`. | ||
r = reshape(vi[vec(vns)], size(vns)) | ||
r_raw = getindex_raw(vi, vec(vns)) | ||
r = reshape(maybe_invlink.(Ref(vi), vns, dists, r_raw), size(vns)) | ||
lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans(vi, vns[1]))) | ||
return r, lp, vi | ||
end | ||
|
@@ -462,19 +474,24 @@ function get_and_set_val!( | |
r = init(rng, dist, spl, n) | ||
for i in 1:n | ||
vn = vns[i] | ||
vi[vn] = vectorize(dist, r[:, i]) | ||
settrans!(vi, false, vn) | ||
vi[vn] = vectorize(dist, maybe_link(vi, vn, dist, r[:, i])) | ||
setorder!(vi, vn, get_num_produce(vi)) | ||
end | ||
else | ||
r = vi[vns] | ||
r_raw = getindex_raw(vi, vns) | ||
r = maybe_invlink(vi, vns, dist, r_raw) | ||
end | ||
else | ||
r = init(rng, dist, spl, n) | ||
for i in 1:n | ||
vn = vns[i] | ||
push!!(vi, vn, r[:, i], dist, spl) | ||
settrans!(vi, false, vn) | ||
if istrans(vi) | ||
push!!(vi, vn, maybe_link(vi, vn, dist, r[:, i]), dist, spl) | ||
torfjelde marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# `push!!` sets the trans-flag to `false` by default. | ||
setttrans!!(vi, true, vn) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This seems to be another confusing design choice but probably should not be addressed in this PR. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. See my comment above. |
||
else | ||
push!!(vi, vn, r[:, i], dist, spl) | ||
end | ||
end | ||
end | ||
return r | ||
|
@@ -496,12 +513,13 @@ function get_and_set_val!( | |
for i in eachindex(vns) | ||
vn = vns[i] | ||
dist = dists isa AbstractArray ? dists[i] : dists | ||
vi[vn] = vectorize(dist, r[i]) | ||
settrans!(vi, false, vn) | ||
vi[vn] = vectorize(dist, maybe_link(vi, vn, dist, r[i])) | ||
setorder!(vi, vn, get_num_produce(vi)) | ||
end | ||
else | ||
r = reshape(vi[vec(vns)], size(vns)) | ||
# r = reshape(vi[vec(vns)], size(vns)) | ||
r_raw = getindex_raw(vi, vec(vns)) | ||
r = maybe_invlink.(Ref(vi), vns, dists, reshape(r_raw, size(vns))) | ||
torfjelde marked this conversation as resolved.
Show resolved
Hide resolved
|
||
end | ||
else | ||
f = (vn, dist) -> init(rng, dist, spl) | ||
|
@@ -511,8 +529,13 @@ function get_and_set_val!( | |
# 1. Figure out the broadcast size and use a `foreach`. | ||
# 2. Define an anonymous function which returns `nothing`, which | ||
# we then broadcast. This will allocate a vector of `nothing` though. | ||
push!!.(Ref(vi), vns, r, dists, Ref(spl)) | ||
settrans!.(Ref(vi), false, vns) | ||
if istrans(vi) | ||
push!!.(Ref(vi), vns, link.(Ref(vi), vns, dists, r), dists, Ref(spl)) | ||
# `push!!` sets the trans-flag to `false` by default. | ||
settrans!!.(Ref(vi), true, vns) | ||
else | ||
push!!.(Ref(vi), vns, r, dists, Ref(spl)) | ||
torfjelde marked this conversation as resolved.
Show resolved
Hide resolved
|
||
end | ||
end | ||
return r | ||
end | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -129,11 +129,20 @@ ERROR: type NamedTuple has no field b | |
struct SimpleVarInfo{NT,T} <: AbstractVarInfo | ||
values::NT | ||
logp::T | ||
# TODO: Should we put this in the type instead? | ||
torfjelde marked this conversation as resolved.
Show resolved
Hide resolved
|
||
istrans::Bool | ||
end | ||
|
||
SimpleVarInfo{T}(θ) where {T<:Real} = SimpleVarInfo{typeof(θ),T}(θ, zero(T)) | ||
SimpleVarInfo{T}(; kwargs...) where {T<:Real} = SimpleVarInfo{T}(NamedTuple(kwargs)) | ||
SimpleVarInfo(; kwargs...) = SimpleVarInfo{Float64}(NamedTuple(kwargs)) | ||
SimpleVarInfo(values, logp) = SimpleVarInfo(values, logp, false) | ||
function SimpleVarInfo{T}(θ) where {T<:Real} | ||
return SimpleVarInfo{typeof(θ),T}(θ, zero(T), false) | ||
end | ||
function SimpleVarInfo{T}(; kwargs...) where {T<:Real} | ||
return SimpleVarInfo{T}(NamedTuple(kwargs)) | ||
end | ||
function SimpleVarInfo(; kwargs...) | ||
return SimpleVarInfo{Float64}(NamedTuple(kwargs)) | ||
end | ||
SimpleVarInfo(θ) = SimpleVarInfo{Float64}(θ) | ||
|
||
# Constructor from `Model`. | ||
|
@@ -158,8 +167,8 @@ function BangBang.empty!!(vi::SimpleVarInfo) | |
end | ||
|
||
getlogp(vi::SimpleVarInfo) = vi.logp | ||
setlogp!!(vi::SimpleVarInfo, logp) = SimpleVarInfo(vi.values, logp) | ||
acclogp!!(vi::SimpleVarInfo, logp) = SimpleVarInfo(vi.values, getlogp(vi) + logp) | ||
setlogp!!(vi::SimpleVarInfo, logp) = Setfield.@set vi.logp = logp | ||
acclogp!!(vi::SimpleVarInfo, logp) = Setfield.@set vi.logp = getlogp(vi) + logp | ||
Comment on lines
+235
to
+236
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we gain something from not just defining There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So I just went with |
||
|
||
""" | ||
keys(vi::SimpleVarInfo) | ||
|
@@ -179,7 +188,7 @@ function acclogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) | |
end | ||
|
||
function Base.show(io::IO, ::MIME"text/plain", svi::SimpleVarInfo) | ||
return print(io, "SimpleVarInfo(", svi.values, ", ", svi.logp, ")") | ||
return print(io, "SimpleVarInfo(", svi.values, ", ", svi.logp, ", ", svi.istrans, ")") | ||
end | ||
|
||
# `NamedTuple` | ||
|
@@ -224,6 +233,11 @@ Base.getindex(vi::SimpleVarInfo, spl::SampleFromUniform) = vi.values | |
# TODO: Should we do better? | ||
Base.getindex(vi::SimpleVarInfo, spl::Sampler) = vi.values | ||
|
||
# Since we don't perform any transformations in `getindex` for `SimpleVarInfo` | ||
# we simply call `getindex` in `getindex_raw`. | ||
getindex_raw(vi::SimpleVarInfo, vn::VarName) = vi[vn] | ||
getindex_raw(vi::SimpleVarInfo, vns::Vector{<:VarName}) = vi[vns] | ||
|
||
Base.haskey(vi::SimpleVarInfo, vn::VarName) = _haskey(vi.values, vn) | ||
function _haskey(nt::NamedTuple, vn::VarName) | ||
# LHS: Ensure that `nt` indeed has the property we want. | ||
|
@@ -337,58 +351,21 @@ function Base.eltype( | |
end | ||
|
||
# Context implementations | ||
function assume(dist::Distribution, vn::VarName, vi::SimpleOrThreadSafeSimple) | ||
left = vi[vn] | ||
return left, Distributions.loglikelihood(dist, left), vi | ||
end | ||
|
||
# NOTE: Evaluations, i.e. those without `rng` are shared with other | ||
# implementations of `AbstractVarInfo`. | ||
function assume( | ||
rng::Random.AbstractRNG, | ||
sampler::SampleFromPrior, | ||
sampler::Union{SampleFromPrior,SampleFromUniform}, | ||
dist::Distribution, | ||
vn::VarName, | ||
vi::SimpleOrThreadSafeSimple, | ||
) | ||
value = init(rng, dist, sampler) | ||
vi = BangBang.push!!(vi, vn, value, dist, sampler) | ||
return value, Distributions.loglikelihood(dist, value), vi | ||
end | ||
|
||
function dot_assume( | ||
dist::MultivariateDistribution, | ||
var::AbstractMatrix, | ||
vns::AbstractVector{<:VarName}, | ||
vi::SimpleOrThreadSafeSimple, | ||
) | ||
@assert length(dist) == size(var, 1) | ||
# NOTE: We cannot work with `var` here because we might have a model of the form | ||
# | ||
# m = Vector{Float64}(undef, n) | ||
# m .~ Normal() | ||
# | ||
# in which case `var` will have `undef` elements, even if `m` is present in `vi`. | ||
value = vi[vns] | ||
lp = sum(zip(vns, eachcol(value))) do (vn, val) | ||
return Distributions.logpdf(dist, val) | ||
end | ||
return value, lp, vi | ||
end | ||
|
||
function dot_assume( | ||
dists::Union{Distribution,AbstractArray{<:Distribution}}, | ||
var::AbstractArray, | ||
vns::AbstractArray{<:VarName}, | ||
vi::SimpleOrThreadSafeSimple, | ||
) | ||
# NOTE: We cannot work with `var` here because we might have a model of the form | ||
# | ||
# m = Vector{Float64}(undef, n) | ||
# m .~ Normal() | ||
# | ||
# in which case `var` will have `undef` elements, even if `m` is present in `vi`. | ||
value = vi[vns] | ||
lp = sum(Distributions.logpdf.(dists, value)) | ||
return value, lp, vi | ||
# Transform if we're working in unconstrained space. | ||
ist = istrans(vi, vn) | ||
value_raw = ist ? Bijectors.link(dist, value) : value | ||
vi = BangBang.push!!(vi, vn, value_raw, dist, sampler) | ||
return value, Bijectors.logpdf_with_trans(dist, value, ist), vi | ||
end | ||
|
||
function dot_assume( | ||
|
@@ -401,15 +378,31 @@ function dot_assume( | |
) | ||
f = (vn, dist) -> init(rng, dist, spl) | ||
value = f.(vns, dists) | ||
vi = BangBang.setindex!!(vi, value, vns) | ||
lp = sum(Distributions.logpdf.(dists, value)) | ||
|
||
# Transform if we're working in transformed space. | ||
ist = istrans(vi, first(vns)) | ||
value_raw = ist ? link.(dist, value) : value | ||
|
||
# Update `vi` | ||
vi = BangBang.setindex!!(vi, value_raw, vns) | ||
|
||
# Compute logp. | ||
lp = sum(Bijectors.logpdf_with_trans.(dists, value, ist)) | ||
return value, lp, vi | ||
end | ||
|
||
# HACK: Allows us to re-use the implementation of `dot_tilde`, etc. for literals. | ||
increment_num_produce!(::SimpleOrThreadSafeSimple) = nothing | ||
settrans!(vi::SimpleOrThreadSafeSimple, trans::Bool, vn::VarName) = nothing | ||
istrans(::SimpleVarInfo, vn::VarName) = false | ||
|
||
# NOTE: We don't implement `settrans!!(vi, trans, vn)`. | ||
settrans!!(vi::SimpleVarInfo, trans::Bool) = Setfield.@set vi.istrans = trans | ||
function settrans!!(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, trans::Bool) | ||
return Setfield.@set vi.varinfo = settrans!!(vi, trans) | ||
end | ||
|
||
istrans(vi::SimpleVarInfo) = vi.istrans | ||
istrans(vi::SimpleVarInfo, vn::VarName) = istrans(vi) | ||
istrans(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, vn::VarName) = istrans(vi.varinfo, vn) | ||
|
||
""" | ||
values_as(varinfo[, Type]) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it'll be nice if we can overload the transform used inside
logpdf_with_trans
in the future.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can actually replace the above with
right now. But I think if we do this we should remove
logpdf_with_trans
everywhere, so I wanted to defer that to another PR.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@phipsgabler this is what I mentioned today. We can now replace
logpdf_with_trans
with ChangeOfVariables/LogDensityInterface API.#342