Perform invlinking in assume rather than implicitly in getindex (#360)
Currently, in `assume`, etc., `invlink` is called implicitly in `getindex` using the distribution extracted from `vi`.

This has a couple of drawbacks:
1. We can only use the distribution stored in `vi` for a particular `vn`, i.e. the one obtained during the initial run. This means that we can't even run models where the distributions have dynamic domains, i.e. where the domain of a particular random variable depends on the realizations of other random variables.
2. We have to store the distribution for each `vn` in `vi`. This was fine when we only had `VarInfo`, because we also needed it for other functionality, but this is not the case for `SimpleVarInfo` (nor will it be).
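
To make the first drawback concrete, here is a minimal, package-free sketch. It assumes the usual scaled-logit transform for a variable supported on an interval; `invlink_uniform` and all numbers are hypothetical and only illustrate why reusing bounds remembered from the initial run is unsafe when the support depends on another variable.

```julia
# Hypothetical stand-in for the inverse link of a variable supported on (lower, upper).
logistic(y) = 1 / (1 + exp(-y))
invlink_uniform(lower, upper, y) = lower + (upper - lower) * logistic(y)

# Think of `b ~ Uniform(0, a)`, where `a` is itself a random variable.
y = 0.0          # unconstrained value stored for `b`
a_initial = 0.8  # value of `a` seen during the initial run
a_current = 0.2  # value of `a` in the current evaluation

# Using the bounds remembered from the initial run lands outside (0, a_current).
b_stale = invlink_uniform(0.0, a_initial, y)
# Using the bounds of the distribution passed in now stays inside the support.
b_fresh = invlink_uniform(0.0, a_current, y)

println((b_stale, b_fresh))
```

With the stale bounds the recovered `b` exceeds the current upper bound, which is exactly the situation that passing the current distribution avoids.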

So in this PR we introduce `getindex_raw`, which behaves like `getindex` but does not apply `invlink` when the variable is linked. We use it within `assume`, etc., where we now use the distributions that are passed to `assume` rather than those stored in `vi`.
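
The split between raw access and access with a distribution can be pictured with a toy container. This is only a sketch under assumptions: `ToyVarInfo`, `ToyPositive`, `toy_invlink`, `toy_getindex_raw` and `toy_getindex` are hypothetical names made up for illustration, not the DynamicPPL implementation.

```julia
# Toy container: values are stored in unconstrained space when `linked` is true.
struct ToyVarInfo
    values::Dict{Symbol,Float64}
    linked::Bool
end

# Hypothetical stand-in for a distribution with positive support.
struct ToyPositive end
toy_invlink(::ToyPositive, y) = exp(y)

# Raw access: return whatever is stored, never transforming it.
toy_getindex_raw(vi::ToyVarInfo, name::Symbol) = vi.values[name]

# Access with a distribution: undo the link using the distribution supplied by the caller.
function toy_getindex(vi::ToyVarInfo, name::Symbol, dist)
    raw = toy_getindex_raw(vi, name)
    return vi.linked ? toy_invlink(dist, raw) : raw
end

vi_toy = ToyVarInfo(Dict(:x => 0.0), true)
raw_value = toy_getindex_raw(vi_toy, :x)                   # stored, unconstrained value
constrained_value = toy_getindex(vi_toy, :x, ToyPositive())  # value in the distribution's support
println((raw_value, constrained_value))
```

The point of the design is that the container no longer has to remember a distribution per variable; the caller supplies the one that is valid for the current evaluation.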

E.g. the following now works:

``` julia
julia> @model demo() = x ~ InverseGamma(2, 3)
demo (generic function with 2 methods)

julia> model = demo();

julia> vi = SimpleVarInfo((x = 10.0, ), true)
SimpleVarInfo((x = 10.0,), 0.0, true)

julia> _, vi = DynamicPPL.evaluate!!(model, vi, DefaultContext())
(22026.465794806718, SimpleVarInfo{NamedTuple{(:x,), Tuple{Float64}}, Float64}((x = 10.0,), -17.80291162245307, true))
```
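
The numbers in this output can be checked by hand. The sketch below assumes that the link for a positive-support distribution such as `InverseGamma` is `log`, so the stored value `10.0` is read as living in unconstrained space. `inversegamma_logpdf` is a hand-written helper for integer shape parameters, not a library function.

```julia
# Log-density of InverseGamma(α, β) for integer α, written out by hand so that
# the sketch needs no packages: α*log(β) - log((α-1)!) - (α+1)*log(x) - β/x.
function inversegamma_logpdf(α::Integer, β::Real, x::Real)
    return α * log(β) - log(factorial(α - 1)) - (α + 1) * log(x) - β / x
end

y = 10.0    # value stored in the linked (unconstrained) space
x = exp(y)  # assumed inverse link for a positive-support distribution

# Density in unconstrained space: add the log-Jacobian log|dx/dy| = y.
logp = inversegamma_logpdf(2, 3.0, x) + y

println((x, logp))
```

Under these assumptions `x` agrees with the returned value and `logp` agrees with the log-probability stored in the resulting `SimpleVarInfo`.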

Co-authored-by: Hong Ge <[email protected]>
torfjelde and yebai committed Jul 23, 2022
1 parent 9ecf3dc commit 9937cb3
Showing 20 changed files with 1,068 additions and 329 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
@@ -16,7 +16,7 @@ jobs:
strategy:
matrix:
version:
- '1.3' # minimum supported version
- '1.6' # minimum supported version
- '1' # current stable version
os:
- ubuntu-latest
2 changes: 1 addition & 1 deletion .github/workflows/IntegrationTest.yml
@@ -40,7 +40,7 @@ jobs:
# force it to use this PR's version of the package
Pkg.develop(PackageSpec(path=".")) # resolver may fail with main deps
Pkg.update()
Pkg.test() # resolver may fail with test time deps
Pkg.test(julia_args=["--depwarn=no"]) # resolver may fail with test time deps
catch err
err isa Pkg.Resolve.ResolverError || rethrow()
# If we can't resolve that means this is incompatible by SemVer and this is fine
8 changes: 6 additions & 2 deletions Project.toml
@@ -1,14 +1,16 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.19.3"
version = "0.20.0"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -22,8 +24,10 @@ AbstractPPL = "0.5.1"
BangBang = "0.3"
Bijectors = "0.5.2, 0.6, 0.7, 0.8, 0.9, 0.10"
ChainRulesCore = "0.9.7, 0.10, 1"
ConstructionBase = "1"
Distributions = "0.23.8, 0.24, 0.25"
DocStringExtensions = "0.8"
MacroTools = "0.5.6"
Setfield = "0.7.1, 0.8"
ZygoteRules = "0.2"
julia = "1.3"
julia = "1.6"
22 changes: 21 additions & 1 deletion docs/src/api.md
@@ -103,8 +103,14 @@ NamedDist
DynamicPPL provides several demo models and helpers for testing samplers in the `DynamicPPL.TestUtils` submodule.

```@docs
DynamicPPL.TestUtils.test_sampler_demo_models
DynamicPPL.TestUtils.test_sampler
DynamicPPL.TestUtils.test_sampler_on_demo_models
DynamicPPL.TestUtils.test_sampler_continuous
DynamicPPL.TestUtils.marginal_mean_of_samples
```

```@docs
DynamicPPL.TestUtils.DEMO_MODELS
```

For every demo model, one can define the true log prior, log likelihood, and log joint probabilities.
@@ -115,6 +121,20 @@ DynamicPPL.TestUtils.loglikelihood_true
DynamicPPL.TestUtils.logjoint_true
```

And in the case where the model includes constrained variables, it can also be useful to define

```@docs
DynamicPPL.TestUtils.logprior_true_with_logabsdet_jacobian
DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian
```

Finally, the following methods can also be of use:

```@docs
DynamicPPL.TestUtils.varnames
DynamicPPL.TestUtils.posterior_mean
```

## Advanced

### Variable names
2 changes: 2 additions & 0 deletions src/DynamicPPL.jl
@@ -12,6 +12,8 @@ using MacroTools: MacroTools
using Setfield: Setfield
using ZygoteRules: ZygoteRules

using DocStringExtensions

using Random: Random

import Base:
94 changes: 56 additions & 38 deletions src/context_implementations.jl
@@ -55,7 +55,7 @@ end
function tilde_assume(context::PriorContext{<:NamedTuple}, right, vn, vi)
if haskey(context.vars, getsym(vn))
vi = setindex!!(vi, vectorize(right, get(context.vars, vn)), vn)
settrans!(vi, false, vn)
settrans!!(vi, false, vn)
end
return tilde_assume(PriorContext(), right, vn, vi)
end
@@ -64,15 +64,15 @@ function tilde_assume(
)
if haskey(context.vars, getsym(vn))
vi = setindex!!(vi, vectorize(right, get(context.vars, vn)), vn)
settrans!(vi, false, vn)
settrans!!(vi, false, vn)
end
return tilde_assume(rng, PriorContext(), sampler, right, vn, vi)
end

function tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, vn, vi)
if haskey(context.vars, getsym(vn))
vi = setindex!!(vi, vectorize(right, get(context.vars, vn)), vn)
settrans!(vi, false, vn)
settrans!!(vi, false, vn)
end
return tilde_assume(LikelihoodContext(), right, vn, vi)
end
@@ -86,7 +86,7 @@ function tilde_assume(
)
if haskey(context.vars, getsym(vn))
vi = setindex!!(vi, vectorize(right, get(context.vars, vn)), vn)
settrans!(vi, false, vn)
settrans!!(vi, false, vn)
end
return tilde_assume(rng, LikelihoodContext(), sampler, right, vn, vi)
end
@@ -194,7 +194,7 @@ end

# fallback without sampler
function assume(dist::Distribution, vn::VarName, vi)
r = vi[vn]
r = vi[vn, dist]
return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)), vi
end

@@ -211,16 +211,21 @@ function assume(
if sampler isa SampleFromUniform || is_flagged(vi, vn, "del")
unset_flag!(vi, vn, "del")
r = init(rng, dist, sampler)
vi[vn] = vectorize(dist, r)
settrans!(vi, false, vn)
vi[vn] = vectorize(dist, maybe_link(vi, vn, dist, r))
setorder!(vi, vn, get_num_produce(vi))
else
r = vi[vn]
# Otherwise we just extract it.
r = vi[vn, dist]
end
else
r = init(rng, dist, sampler)
push!!(vi, vn, r, dist, sampler)
settrans!(vi, false, vn)
if istrans(vi)
push!!(vi, vn, link(dist, r), dist, sampler)
# By default `push!!` sets the transformed flag to `false`.
settrans!!(vi, true, vn)
else
push!!(vi, vn, r, dist, sampler)
end
end

return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)), vi
@@ -286,7 +291,7 @@ function dot_tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, left,
var = get(context.vars, vn)
_right, _left, _vns = unwrap_right_left_vns(right, var, vn)
set_val!(vi, _vns, _right, _left)
settrans!.(Ref(vi), false, _vns)
settrans!!.((vi,), false, _vns)
dot_tilde_assume(LikelihoodContext(), _right, _left, _vns, vi)
else
dot_tilde_assume(LikelihoodContext(), right, left, vn, vi)
@@ -305,19 +310,20 @@ function dot_tilde_assume(
var = get(context.vars, vn)
_right, _left, _vns = unwrap_right_left_vns(right, var, vn)
set_val!(vi, _vns, _right, _left)
settrans!.(Ref(vi), false, _vns)
settrans!!.((vi,), false, _vns)
dot_tilde_assume(rng, LikelihoodContext(), sampler, _right, _left, _vns, vi)
else
dot_tilde_assume(rng, LikelihoodContext(), sampler, right, left, vn, vi)
end
end

function dot_tilde_assume(context::LikelihoodContext, right, left, vn, vi)
return dot_assume(NoDist.(right), left, vn, vi)
return dot_assume(nodist(right), left, vn, vi)
end
function dot_tilde_assume(
rng::Random.AbstractRNG, context::LikelihoodContext, sampler, right, left, vn, vi
)
return dot_assume(rng, sampler, NoDist.(right), vn, left, vi)
return dot_assume(rng, sampler, nodist(right), vn, left, vi)
end

# `PriorContext`
@@ -326,7 +332,7 @@ function dot_tilde_assume(context::PriorContext{<:NamedTuple}, right, left, vn,
var = get(context.vars, vn)
_right, _left, _vns = unwrap_right_left_vns(right, var, vn)
set_val!(vi, _vns, _right, _left)
settrans!.(Ref(vi), false, _vns)
settrans!!.((vi,), false, _vns)
dot_tilde_assume(PriorContext(), _right, _left, _vns, vi)
else
dot_tilde_assume(PriorContext(), right, left, vn, vi)
@@ -345,7 +351,7 @@ function dot_tilde_assume(
var = get(context.vars, vn)
_right, _left, _vns = unwrap_right_left_vns(right, var, vn)
set_val!(vi, _vns, _right, _left)
settrans!.(Ref(vi), false, _vns)
settrans!!.((vi,), false, _vns)
dot_tilde_assume(rng, PriorContext(), sampler, _right, _left, _vns, vi)
else
dot_tilde_assume(rng, PriorContext(), sampler, right, left, vn, vi)
@@ -383,14 +389,14 @@ function dot_assume(
vns::AbstractVector{<:VarName},
vi::AbstractVarInfo,
)
@assert length(dist) == size(var, 1)
@assert length(dist) == size(var, 1) "dimensionality of `var` ($(size(var, 1))) is incompatible with dimensionality of `dist` $(length(dist))"
# NOTE: We cannot work with `var` here because we might have a model of the form
#
# m = Vector{Float64}(undef, n)
# m .~ Normal()
#
# in which case `var` will have `undef` elements, even if `m` is present in `vi`.
r = vi[vns]
r = vi[vns, dist]
lp = sum(zip(vns, eachcol(r))) do (vn, ri)
return Bijectors.logpdf_with_trans(dist, ri, istrans(vi, vn))
end
@@ -412,19 +418,21 @@ function dot_assume(
end

function dot_assume(
dists::Union{Distribution,AbstractArray{<:Distribution}},
dist::Distribution, var::AbstractArray, vns::AbstractArray{<:VarName}, vi
)
r = getindex.((vi,), vns, (dist,))
lp = sum(Bijectors.logpdf_with_trans.((dist,), r, istrans.((vi,), vns)))
return r, lp, vi
end

function dot_assume(
dists::AbstractArray{<:Distribution},
var::AbstractArray,
vns::AbstractArray{<:VarName},
vi,
)
# NOTE: We cannot work with `var` here because we might have a model of the form
#
# m = Vector{Float64}(undef, n)
# m .~ Normal()
#
# in which case `var` will have `undef` elements, even if `m` is present in `vi`.
r = reshape(vi[vec(vns)], size(vns))
lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans(vi, vns[1])))
r = getindex.((vi,), vns, dists)
lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans.((vi,), vns)))
return r, lp, vi
end

@@ -438,7 +446,7 @@ function dot_assume(
)
r = get_and_set_val!(rng, vi, vns, dists, spl)
# Make sure `r` is not a matrix for multivariate distributions
lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans(vi, vns[1])))
lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans.((vi,), vns)))
return r, lp, vi
end
function dot_assume(rng, spl::Sampler, ::Any, ::AbstractArray{<:VarName}, ::Any, ::Any)
@@ -462,19 +470,23 @@ function get_and_set_val!(
r = init(rng, dist, spl, n)
for i in 1:n
vn = vns[i]
vi[vn] = vectorize(dist, r[:, i])
settrans!(vi, false, vn)
vi[vn] = vectorize(dist, maybe_link(vi, vn, dist, r[:, i]))
setorder!(vi, vn, get_num_produce(vi))
end
else
r = vi[vns]
r = vi[vns, dist]
end
else
r = init(rng, dist, spl, n)
for i in 1:n
vn = vns[i]
push!!(vi, vn, r[:, i], dist, spl)
settrans!(vi, false, vn)
if istrans(vi)
push!!(vi, vn, Bijectors.link(dist, r[:, i]), dist, spl)
# `push!!` sets the trans-flag to `false` by default.
settrans!!(vi, true, vn)
else
push!!(vi, vn, r[:, i], dist, spl)
end
end
end
return r
@@ -496,12 +508,13 @@ function get_and_set_val!(
for i in eachindex(vns)
vn = vns[i]
dist = dists isa AbstractArray ? dists[i] : dists
vi[vn] = vectorize(dist, r[i])
settrans!(vi, false, vn)
vi[vn] = vectorize(dist, maybe_link(vi, vn, dist, r[i]))
setorder!(vi, vn, get_num_produce(vi))
end
else
r = reshape(vi[vec(vns)], size(vns))
# r = reshape(vi[vec(vns)], size(vns))
r_raw = getindex_raw(vi, vec(vns))
r = maybe_invlink.((vi,), vns, dists, reshape(r_raw, size(vns)))
end
else
f = (vn, dist) -> init(rng, dist, spl)
@@ -511,8 +524,13 @@ function get_and_set_val!(
# 1. Figure out the broadcast size and use a `foreach`.
# 2. Define an anonymous function which returns `nothing`, which
# we then broadcast. This will allocate a vector of `nothing` though.
push!!.(Ref(vi), vns, r, dists, Ref(spl))
settrans!.(Ref(vi), false, vns)
if istrans(vi)
push!!.((vi,), vns, link.((vi,), vns, dists, r), dists, (spl,))
# `push!!` sets the trans-flag to `false` by default.
settrans!!.((vi,), true, vns)
else
push!!.((vi,), vns, r, dists, (spl,))
end
end
return r
end
31 changes: 27 additions & 4 deletions src/distribution_wrappers.jl
@@ -13,6 +13,9 @@ end

NamedDist(dist::Distribution, name::Symbol) = NamedDist(dist, VarName{name}())

Base.length(dist::NamedDist) = Base.length(dist.dist)
Base.size(dist::NamedDist) = Base.size(dist.dist)

Distributions.logpdf(dist::NamedDist, x::Real) = Distributions.logpdf(dist.dist, x)
function Distributions.logpdf(dist::NamedDist, x::AbstractArray{<:Real})
return Distributions.logpdf(dist.dist, x)
@@ -24,12 +27,20 @@ function Distributions.loglikelihood(dist::NamedDist, x::AbstractArray{<:Real})
return Distributions.loglikelihood(dist.dist, x)
end

Bijectors.bijector(d::NamedDist) = Bijectors.bijector(d.dist)

struct NoDist{variate,support,Td<:Distribution{variate,support}} <:
Distribution{variate,support}
dist::Td
end
NoDist(dist::NamedDist) = NamedDist(NoDist(dist.dist), dist.name)

nodist(dist::Distribution) = NoDist(dist)
nodist(dists::AbstractArray) = nodist.(dists)

Base.length(dist::NoDist) = Base.length(dist.dist)
Base.size(dist::NoDist) = Base.size(dist.dist)

Distributions.rand(rng::Random.AbstractRNG, d::NoDist) = rand(rng, d.dist)
Distributions.logpdf(d::NoDist{<:Univariate}, ::Real) = 0
Distributions.logpdf(d::NoDist{<:Multivariate}, ::AbstractVector{<:Real}) = 0
@@ -40,9 +51,21 @@ Distributions.logpdf(d::NoDist{<:Matrixvariate}, ::AbstractMatrix{<:Real}) = 0
Distributions.minimum(d::NoDist) = minimum(d.dist)
Distributions.maximum(d::NoDist) = maximum(d.dist)

Bijectors.logpdf_with_trans(d::NoDist{<:Univariate}, ::Real) = 0
Bijectors.logpdf_with_trans(d::NoDist{<:Multivariate}, ::AbstractVector{<:Real}) = 0
function Bijectors.logpdf_with_trans(d::NoDist{<:Multivariate}, x::AbstractMatrix{<:Real})
Bijectors.logpdf_with_trans(d::NoDist{<:Univariate}, ::Real, ::Bool) = 0
function Bijectors.logpdf_with_trans(
d::NoDist{<:Multivariate}, ::AbstractVector{<:Real}, ::Bool
)
return 0
end
function Bijectors.logpdf_with_trans(
d::NoDist{<:Multivariate}, x::AbstractMatrix{<:Real}, ::Bool
)
return zeros(Int, size(x, 2))
end
Bijectors.logpdf_with_trans(d::NoDist{<:Matrixvariate}, ::AbstractMatrix{<:Real}) = 0
function Bijectors.logpdf_with_trans(
d::NoDist{<:Matrixvariate}, ::AbstractMatrix{<:Real}, ::Bool
)
return 0
end

Bijectors.bijector(d::NoDist) = Bijectors.bijector(d.dist)

2 comments on commit 9937cb3

@torfjelde

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/64823

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.20.0 -m "<description of version>" 9937cb3ab1ba3cee9b8620bb82188608d78b1153
git push origin v0.20.0
