diff --git a/Project.toml b/Project.toml index 5830de123..f7be0257d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.28.4" +version = "0.29" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/docs/Project.toml b/docs/Project.toml index 0746a3b5d..e79e44aba 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -4,6 +4,7 @@ DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" @@ -14,6 +15,7 @@ DataStructures = "0.18" Distributions = "0.25" Documenter = "1" FillArrays = "0.13, 1" +ForwardDiff = "0.10" LogDensityProblems = "2" MCMCChains = "5, 6" StableRNGs = "1" diff --git a/docs/make.jl b/docs/make.jl index 109f28a9c..ebf15df06 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -14,14 +14,18 @@ DocMeta.setdocmeta!(DynamicPPL, :DocTestSetup, :(using DynamicPPL); recursive=tr makedocs(; sitename="DynamicPPL", - format=Documenter.HTML(), + # The API index.html page is fairly large, and violates the default HTML page size + # threshold of 200KiB, so we double that. 
+ format=Documenter.HTML(; size_threshold=2^10 * 400), modules=[DynamicPPL], pages=[ "Home" => "index.md", "API" => "api.md", "Tutorials" => ["tutorials/prob-interface.md"], + "Internals" => ["internals/transformations.md"], ], checkdocs=:exports, + doctest=false, ) deploydocs(; repo="github.com/TuringLang/DynamicPPL.jl.git", push_preview=true) diff --git a/docs/src/api.md b/docs/src/api.md index 2e8081a88..97c48316e 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -218,13 +218,62 @@ Please see the documentation of [AbstractPPL.jl](https://github.com/TuringLang/A ### Data Structures of Variables -DynamicPPL provides different data structures for samples from the model and their log density. -All of them are subtypes of [`AbstractVarInfo`](@ref). +DynamicPPL provides different data structures for storing samples and accumulating the log-probabilities, all of which are subtypes of [`AbstractVarInfo`](@ref). ```@docs AbstractVarInfo ``` +But exactly how an [`AbstractVarInfo`](@ref) stores this information can vary. + +#### `VarInfo` + +```@docs +VarInfo +TypedVarInfo +``` + +One main characteristic of [`VarInfo`](@ref) is that samples are stored in a linearized form. + +```@docs +link! +invlink! +``` + +```@docs +set_flag! +unset_flag! +is_flagged +``` + +For Gibbs sampling the following functions were added. + +```@docs +setgid! +updategid! +``` + +The following functions were used for sequential Monte Carlo methods. + +```@docs +get_num_produce +set_num_produce! +increment_num_produce! +reset_num_produce! +setorder! +set_retained_vns_del_by_spl! +``` + +```@docs +Base.empty! +``` + +#### `SimpleVarInfo` + +```@docs +SimpleVarInfo +``` + ### Common API #### Accumulation of log-probabilities @@ -241,7 +290,7 @@ resetlogp!! ```@docs keys getindex -DynamicPPL.getindex_raw +DynamicPPL.getindex_internal push!! empty!! isempty @@ -269,8 +318,9 @@ DynamicPPL.invlink DynamicPPL.link!! DynamicPPL.invlink!! 
DynamicPPL.default_transformation +DynamicPPL.link_transform +DynamicPPL.invlink_transform DynamicPPL.maybe_invlink_before_eval!! -DynamicPPL.reconstruct ``` #### Utils @@ -283,56 +333,6 @@ DynamicPPL.varname_leaves DynamicPPL.varname_and_value_leaves ``` -#### `SimpleVarInfo` - -```@docs -SimpleVarInfo -``` - -#### `VarInfo` - -Another data structure is [`VarInfo`](@ref). - -```@docs -VarInfo -TypedVarInfo -``` - -One main characteristic of [`VarInfo`](@ref) is that samples are stored in a linearized form. - -```@docs -link! -invlink! -``` - -```@docs -set_flag! -unset_flag! -is_flagged -``` - -For Gibbs sampling the following functions were added. - -```@docs -setgid! -updategid! -``` - -The following functions were used for sequential Monte Carlo methods. - -```@docs -get_num_produce -set_num_produce! -increment_num_produce! -reset_num_produce! -setorder! -set_retained_vns_del_by_spl! -``` - -```@docs -Base.empty! -``` - ### Evaluation Contexts Internally, both sampling and evaluation of log densities are performed with [`AbstractPPL.evaluate!!`](@ref). 
diff --git a/docs/src/assets/images/transformations-assume-without-istrans.dot b/docs/src/assets/images/transformations-assume-without-istrans.dot new file mode 100644 index 000000000..6eb6865c3 --- /dev/null +++ b/docs/src/assets/images/transformations-assume-without-istrans.dot @@ -0,0 +1,17 @@ +digraph { + # `assume` block + subgraph cluster_assume { + label = "assume"; + fontname = "Courier"; + + assume [shape=box, label=< assume(varinfo, @varname(x), Normal())>, fontname="Courier"]; + without_linking_assume [shape=box, label="f = from_internal_transform(varinfo, varname, dist)", fontname="Courier"]; + with_logabsdetjac [shape=box, label="x, logjac = with_logabsdet_jacobian(f, getindex_internal(varinfo, varname))", fontname="Courier"]; + return_assume [shape=box, label=< return x, logpdf(dist, x) + logjac, varinfo >, style=dashed, fontname="Courier"]; + + assume -> without_linking_assume; + without_linking_assume -> with_logabsdetjac; + with_logabsdetjac -> return_assume; + } +} + diff --git a/docs/src/assets/images/transformations-assume-without-istrans.dot.png b/docs/src/assets/images/transformations-assume-without-istrans.dot.png new file mode 100644 index 000000000..f58727ad2 Binary files /dev/null and b/docs/src/assets/images/transformations-assume-without-istrans.dot.png differ diff --git a/docs/src/assets/images/transformations-assume-without-istrans.dot.svg b/docs/src/assets/images/transformations-assume-without-istrans.dot.svg new file mode 100644 index 000000000..be91de2ad --- /dev/null +++ b/docs/src/assets/images/transformations-assume-without-istrans.dot.svg @@ -0,0 +1,88 @@ + + + + + + +%3 + + + +tilde_node + +x ~ Normal() + + + +base_node + + varname = +@varname +(x) +dist = Normal() +x, varinfo = ... 
+ + + +tilde_node->base_node + + +   +@model + + + +assume + +assume(varname, dist, varinfo) + + + +base_node->assume + + +  tilde-pipeline + + + +without_linking + +f = from_internal_transform(varinfo, varname, dist) + + + +assume->without_linking + + + + + +with_logabsdetjac + +x, logjac = with_logabsdet_jacobian(f, getindex_internal(varinfo, varname, dist)) + + + +without_linking->with_logabsdetjac + + + + + +return + + +return + x, logpdf(dist, x) - logjac, varinfo + + + +with_logabsdetjac->return + + + + + diff --git a/docs/src/assets/images/transformations-assume.dot b/docs/src/assets/images/transformations-assume.dot new file mode 100644 index 000000000..f1952b63f --- /dev/null +++ b/docs/src/assets/images/transformations-assume.dot @@ -0,0 +1,22 @@ +digraph { + # `assume` block + subgraph cluster_assume { + label = "assume"; + fontname = "Courier"; + + assume [shape=box, label=< assume(varinfo, @varname(x), Normal())>, fontname="Courier"]; + iflinked_assume [label=< if istrans(varinfo, varname) >, fontname="Courier"]; + without_linking_assume [shape=box, label="f = from_internal_transform(varinfo, varname, dist)", fontname="Courier"]; + with_linking_assume [shape=box, label="f = from_linked_internal_transform(varinfo, varname, dist)", fontname="Courier"]; + with_logabsdetjac [shape=box, label="x, logjac = with_logabsdet_jacobian(f, getindex_internal(varinfo, varname))", fontname="Courier"]; + return_assume [shape=box, label=< return x, logpdf(dist, x) + logjac, varinfo >, style=dashed, fontname="Courier"]; + + assume -> iflinked_assume; + iflinked_assume -> without_linking_assume [label=< false>, fontname="Courier"]; + iflinked_assume -> with_linking_assume [label=< true>, fontname="Courier"]; + without_linking_assume -> with_logabsdetjac; + with_linking_assume -> with_logabsdetjac; + with_logabsdetjac -> return_assume; + } +} + diff --git a/docs/src/assets/images/transformations-assume.dot.png b/docs/src/assets/images/transformations-assume.dot.png new 
file mode 100644 index 000000000..b8b0ec734 Binary files /dev/null and b/docs/src/assets/images/transformations-assume.dot.png differ diff --git a/docs/src/assets/images/transformations-getindex-with-dist.dot b/docs/src/assets/images/transformations-getindex-with-dist.dot new file mode 100644 index 000000000..e44fc0ce6 --- /dev/null +++ b/docs/src/assets/images/transformations-getindex-with-dist.dot @@ -0,0 +1,20 @@ +digraph { + # `getindex` block + subgraph cluster_getindex { + label = "getindex"; + fontname = "Courier"; + + getindex [shape=box, label=< x = getindex(varinfo, @varname(x), Normal()) >, fontname="Courier"]; + iflinked_getindex [label=< if istrans(varinfo, varname) >, fontname="Courier"]; + without_linking_getindex [shape=box, label="f = from_internal_transform(varinfo, varname, dist)", fontname="Courier"]; + with_linking_getindex [shape=box, label="f = from_linked_internal_transform(varinfo, varname, dist)", fontname="Courier"]; + return_getindex [shape=box, label=< return f(getindex_internal(varinfo, varname)) >, style=dashed, fontname="Courier"]; + + getindex -> iflinked_getindex; + iflinked_getindex -> without_linking_getindex [label=< false>, fontname="Courier"]; + iflinked_getindex -> with_linking_getindex [label=< true>, fontname="Courier"]; + without_linking_getindex -> return_getindex; + with_linking_getindex -> return_getindex; + } +} + diff --git a/docs/src/assets/images/transformations-getindex-with-dist.dot.png b/docs/src/assets/images/transformations-getindex-with-dist.dot.png new file mode 100644 index 000000000..381ba45a2 Binary files /dev/null and b/docs/src/assets/images/transformations-getindex-with-dist.dot.png differ diff --git a/docs/src/assets/images/transformations-getindex-without-dist.dot b/docs/src/assets/images/transformations-getindex-without-dist.dot new file mode 100644 index 000000000..38dd296e1 --- /dev/null +++ b/docs/src/assets/images/transformations-getindex-without-dist.dot @@ -0,0 +1,20 @@ +digraph { + # 
`getindex` block + subgraph cluster_getindex { + label = "getindex"; + fontname = "Courier"; + + getindex [shape=box, label=< x = getindex(varinfo, @varname(x)) >, fontname="Courier"]; + iflinked_getindex [label=< if istrans(varinfo, varname) >, fontname="Courier"]; + without_linking_getindex [shape=box, label="f = from_internal_transform(varinfo, varname)", fontname="Courier"]; + with_linking_getindex [shape=box, label="f = from_linked_internal_transform(varinfo, varname)", fontname="Courier"]; + return_getindex [shape=box, label=< return f(getindex_internal(varinfo, varname)) >, style=dashed, fontname="Courier"]; + + getindex -> iflinked_getindex; + iflinked_getindex -> without_linking_getindex [label=< false>, fontname="Courier"]; + iflinked_getindex -> with_linking_getindex [label=< true>, fontname="Courier"]; + without_linking_getindex -> return_getindex; + with_linking_getindex -> return_getindex; + } +} + diff --git a/docs/src/assets/images/transformations-getindex-without-dist.dot.png b/docs/src/assets/images/transformations-getindex-without-dist.dot.png new file mode 100644 index 000000000..a869326d3 Binary files /dev/null and b/docs/src/assets/images/transformations-getindex-without-dist.dot.png differ diff --git a/docs/src/assets/images/transformations.dot b/docs/src/assets/images/transformations.dot new file mode 100644 index 000000000..2ca40ddd0 --- /dev/null +++ b/docs/src/assets/images/transformations.dot @@ -0,0 +1,28 @@ +digraph { + # Nodes. + tilde_node [shape=box, label="x ~ Normal()", fontname="Courier"]; + base_node [shape=box, label=< vn = @varname(x)
dist = Normal()
x, vi = ... >, fontname="Courier"]; + assume [shape=box, label="assume(vn, dist, vi)", fontname="Courier"]; + + iflinked [label=< if istrans(vi, vn) >, fontname="Courier"]; + + without_linking [shape=box, label="f = from_internal_transform(vi, vn, dist)", styled=dashed, fontname="Courier"]; + with_linking [shape=box, label="f = from_linked_internal_transform(vi, vn, dist)", styled=dashed, fontname="Courier"]; + + with_logabsdetjac [shape=box, label="x, logjac = with_logabsdet_jacobian(f, getindex_internal(vi, vn))", styled=dashed, fontname="Courier"]; + return [shape=box, label=< return x, logpdf(dist, x) + logjac, vi >, styled=dashed, fontname="Courier"]; + + # Edges. + tilde_node -> base_node [style=dashed, label=< @model>, fontname="Courier"] + base_node -> assume [style=dashed, label=" tilde-pipeline", fontname="Courier"]; + + assume -> iflinked; + + iflinked -> without_linking [label=< false>, fontname="Courier"]; + iflinked -> with_linking [label=< true>, fontname="Courier"]; + + without_linking -> with_logabsdetjac; + with_linking -> with_logabsdetjac; + + with_logabsdetjac -> return; +} diff --git a/docs/src/assets/images/transformations.dot.png b/docs/src/assets/images/transformations.dot.png new file mode 100644 index 000000000..1343a81e7 Binary files /dev/null and b/docs/src/assets/images/transformations.dot.png differ diff --git a/docs/src/assets/images/transformations.dot.svg b/docs/src/assets/images/transformations.dot.svg new file mode 100644 index 000000000..1e98f612d --- /dev/null +++ b/docs/src/assets/images/transformations.dot.svg @@ -0,0 +1,124 @@ + + + + + + +%3 + + + +tilde_node + +x ~ Normal() + + + +base_node + + vn = +@varname +(x) +dist = Normal() +x, vi = ... 
+ + + +tilde_node->base_node + + +   +@model + + + +assume + +assume(vn, dist, vi) + + + +base_node->assume + + +  tilde-pipeline + + + +iflinked + + +if + istrans(vi, vn) + + + +assume->iflinked + + + + + +without_linking + +f = from_internal_transform(vi, vn, dist) + + + +iflinked->without_linking + + +   +false + + + +with_linking + +f = from_linked_internal_transform(vi, vn, dist) + + + +iflinked->with_linking + + +   +true + + + +with_logabsdetjac + +x, logjac = with_logabsdet_jacobian(f, getindex_internal(vi, vn, dist)) + + + +without_linking->with_logabsdetjac + + + + + +with_linking->with_logabsdetjac + + + + + +return + + +return + x, logpdf(dist, x) - logjac, vi + + + +with_logabsdetjac->return + + + + + diff --git a/docs/src/internals/transformations.md b/docs/src/internals/transformations.md new file mode 100644 index 000000000..d948290ec --- /dev/null +++ b/docs/src/internals/transformations.md @@ -0,0 +1,377 @@ +# Transforming variables + +## Motivation + +In a probabilistic programming language (PPL) such as DynamicPPL.jl, one crucial functionality for enabling a large number of inference algorithms to be implemented, in particular gradient-based ones, is the ability to work with "unconstrained" variables. + +For example, consider the following model: + +```julia +@model function demo() + s ~ InverseGamma(2, 3) + return m ~ Normal(0, √s) +end +``` + +Here we have two variables `s` and `m`, where `s` is constrained to be positive, while `m` can be any real number. + +For certain inference methods, it's necessary / much more convenient to work with an equivalent model to `demo` but where all the variables can take any real values (they're "unconstrained"). + +!!! note + + We write "unconstrained" with quotes because there are many ways to transform a constrained variable to an unconstrained one, *and* DynamicPPL can work with a much broader class of bijective transformations of variables, not just ones that go to the entire real line. 
But for MCMC, unconstraining is the most common transformation so we'll stick with that terminology. + +For a large family of constraints encountered in practice, it is indeed possible to transform a (partially) constrained model to a completely unconstrained one in such a way that sampling in the unconstrained space is equivalent to sampling in the constrained space. + +In DynamicPPL.jl, this is often referred to as *linking* (a term originating in the statistics literature) and is done using transformations from [Bijectors.jl](https://github.com/TuringLang/Bijectors.jl). + +For example, the above model could be transformed into (the following pseudo-code; it's not working code): + +```julia +@model function demo() + log_s ~ log(InverseGamma(2, 3)) + s = exp(log_s) + return m ~ Normal(0, √s) +end +``` + +Here `log_s` is an unconstrained variable, and `s` is a constrained variable that is a deterministic function of `log_s`. + +But to ensure that we stay consistent with what the user expects, DynamicPPL.jl does not actually transform the model as above, but instead makes use of transformed variables internally to achieve the same effect, when desired. + +In the end, we'll end up with something that looks like this: + +```@raw html +
+ +
+``` +Below we'll see how this is done. + +## What do we need? + +There are two aspects to transforming from the internal representation of a variable in a `varinfo` to the representation wanted in the model: + + 1. Different implementations of [`AbstractVarInfo`](@ref) represent realizations of a model in different ways internally, so we need to transform from this internal representation to the desired representation in the model. For example, + + + [`VarInfo`](@ref) represents a realization of a model as a "flattened" / vector representation, regardless of the form of the variable in the model. + + [`SimpleVarInfo`](@ref) represents a realization of a model exactly as in the model (unless it has been transformed; we'll get to that later). + + 2. We need the ability to transform from "constrained space" to "unconstrained space", as we saw in the previous section. + +## Working example + +A good and non-trivial example to keep in mind throughout is the following model: + +```@example transformations-internal +using DynamicPPL, Distributions +@model demo_lkj() = x ~ LKJCholesky(2, 1.0) +``` + +`LKJCholesky` is a `LKJ(2, 1.0)` distribution, a distribution over correlation matrices (covariance matrices but with unit diagonal), but working directly with the Cholesky factorization of the correlation matrix rather than the correlation matrix itself (this is more numerically stable and computationally efficient). + +!!! note + + This is a particularly "annoying" case because the return-value is not a simple `Real` or `AbstractArray{<:Real}`, but rather a `LinearAlgebra.Cholesky` object which wraps a triangular matrix (whether it's upper- or lower-triangular depends on the instance). + +As mentioned, some implementations of `AbstractVarInfo`, e.g. [`VarInfo`](@ref), work with a "flattened" / vector representation of a variable, and so in this case we need two transformations: + + 1. From the `Cholesky` object to a vector representation. + 2. 
From the `Cholesky` object to an "unconstrained" / linked vector representation. + +And similarly, we'll need the inverses of these transformations. + +## From internal representation to model representation + +To go from the internal variable representation of an `AbstractVarInfo` to the variable representation wanted in the model, e.g. from a `Vector{Float64}` to `Cholesky` in the case of [`VarInfo`](@ref) in `demo_lkj`, we have the following methods: + +```@docs +DynamicPPL.to_internal_transform +DynamicPPL.from_internal_transform +``` + +These methods allow us to extract the internal-to-model transformation function depending on the `varinfo`, the variable, and the distribution of the variable: + + - `varinfo` + `vn` defines the internal representation of the variable. + - `dist` defines the representation expected within the model scope. + +!!! note + + If `vn` is not present in `varinfo`, then the internal representation is fully determined by `varinfo` alone. This is used when we're about to add a new variable to the `varinfo` and need to know how to represent it internally. 
+ +Continuing from the example above, we can inspect the internal representation of `x` in `demo_lkj` with [`VarInfo`](@ref) using [`DynamicPPL.getindex_internal`](@ref): + +```@example transformations-internal +model = demo_lkj() +varinfo = VarInfo(model) +x_internal = DynamicPPL.getindex_internal(varinfo, @varname(x)) +``` + +```@example transformations-internal +f_from_internal = DynamicPPL.from_internal_transform( + varinfo, @varname(x), LKJCholesky(2, 1.0) +) +f_from_internal(x_internal) +``` + +Let's confirm that this is the same as `varinfo[@varname(x)]`: + +```@example transformations-internal +x_model = varinfo[@varname(x)] +``` + +Similarly, we can go from the model representation to the internal representation: + +```@example transformations-internal +f_to_internal = DynamicPPL.to_internal_transform(varinfo, @varname(x), LKJCholesky(2, 1.0)) + +f_to_internal(x_model) +``` + +It's also useful to see how this is done in [`SimpleVarInfo`](@ref): + +```@example transformations-internal +simple_varinfo = SimpleVarInfo(varinfo) +DynamicPPL.getindex_internal(simple_varinfo, @varname(x)) +``` + +Here we see that the internal representation is exactly the same as the model representation, and so we'd expect `from_internal_transform` to be the `identity` function: + +```@example transformations-internal +DynamicPPL.from_internal_transform(simple_varinfo, @varname(x), LKJCholesky(2, 1.0)) +``` + +Great! + +## From *unconstrained* internal representation to model representation + +In addition to going from internal representation to model representation of a variable, we also need to be able to go from the *unconstrained* internal representation to the model representation. 
+ +For this, we have the following methods: + +```@docs +DynamicPPL.to_linked_internal_transform +DynamicPPL.from_linked_internal_transform +``` + +These are very similar to [`DynamicPPL.to_internal_transform`](@ref) and [`DynamicPPL.from_internal_transform`](@ref), but here the internal representation is also linked / "unconstrained". + +Continuing from the example above: + +```@example transformations-internal +f_to_linked_internal = DynamicPPL.to_linked_internal_transform( + varinfo, @varname(x), LKJCholesky(2, 1.0) +) + +x_linked_internal = f_to_linked_internal(x_model) +``` + +```@example transformations-internal +f_from_linked_internal = DynamicPPL.from_linked_internal_transform( + varinfo, @varname(x), LKJCholesky(2, 1.0) +) + +f_from_linked_internal(x_linked_internal) +``` + +Here we see a significant difference between the linked representation and the non-linked representation: the linked representation is only of length 1, whereas the non-linked representation is of length 4. This is because we actually only need a single element to represent a 2x2 correlation matrix, as the diagonal elements are always 1 *and* it's symmetric. + +We can also inspect the transforms themselves: + +```@example transformations-internal +f_from_internal +``` + +vs. + +```@example transformations-internal +f_from_linked_internal +``` + +Here we see that `f_from_linked_internal` is a single function taking us directly from the linked representation to the model representation, whereas `f_from_internal` is a composition of a few functions: one reshaping the underlying length-4 array into a 2x2 matrix, and the other converting this matrix into a `Cholesky`, as required to be compatible with `LKJCholesky(2, 1.0)`. + +## Why do we need both `to_internal_transform` and `to_linked_internal_transform`? 
+ +One might wonder why we need both `to_internal_transform` and `to_linked_internal_transform` instead of just a single `to_internal_transform` which returns the "standard" internal representation if the variable is not linked / "unconstrained" and the linked / "unconstrained" internal representation if it is. + +That is, why can't we just do + +```@raw html +
+ +
+``` + +Unfortunately, this is not possible in general. Consider for example the following model: + +```@example transformations-internal +@model function demo_dynamic_constraint() + m ~ Normal() + x ~ truncated(Normal(); lower=m) + + return (m=m, x=x) +end +``` + +Here the variable `x` is constrained to be in the domain `(m, Inf)`, where `m` is sampled according to a `Normal`. + +```@example transformations-internal +model = demo_dynamic_constraint() +varinfo = VarInfo(model) +varinfo[@varname(m)], varinfo[@varname(x)] +``` + +We see that the realization of `x` is indeed greater than `m`, as expected. + +But what if we [`link`](@ref) this `varinfo` so that we end up working on an "unconstrained" space, i.e. both `m` and `x` can take on any values in `(-Inf, Inf)`: + +```@example transformations-internal +varinfo_linked = link(varinfo, model) +varinfo_linked[@varname(m)], varinfo_linked[@varname(x)] +``` + +We still get the same values, as expected, since internally `varinfo` transforms from the linked internal representation to the model representation. + +But what if we change the value of `m`, to, say, a bit larger than `x`? + +```@example transformations-internal +# Update realization for `m` in `varinfo_linked`. +varinfo_linked[@varname(m)] = varinfo_linked[@varname(x)] + 1 +varinfo_linked[@varname(m)], varinfo_linked[@varname(x)] +``` + +Now we see that the constraint `m < x` is no longer satisfied! + +Hence one might expect that if we try to compute, say, the [`logjoint`](@ref) using `varinfo_linked` with this "invalid" realization, we'll get an error: + +```@example transformations-internal +logjoint(model, varinfo_linked) +``` + +But we don't! In fact, if we look at the actual value used within the model + +```@example transformations-internal +first(DynamicPPL.evaluate!!(model, varinfo_linked, DefaultContext())) +``` + +we see that we indeed satisfy the constraint `m < x`, as desired. + +!!! 
warning + + One shouldn't be setting variables in a linked `varinfo` willy-nilly directly like this unless one knows that the value will be compatible with the constraints of the model. + +The reason for this is that internally in a model evaluation, we construct the transformation from the internal to the model representation based on the *current* realizations in the model! That is, we take the `dist` in a `x ~ dist` expression _at model evaluation time_ and use that to construct the transformation, thus allowing it to change between model evaluations without invalidating the transformation. + +But to be able to do this, we need to know whether the variable is linked / "unconstrained" or not, since the transformation is different in the two cases. Hence we need to be able to determine this at model evaluation time. Hence the internals end up looking something like this: + +```julia +if istrans(varinfo, varname) + from_linked_internal_transform(varinfo, varname, dist) +else + from_internal_transform(varinfo, varname, dist) +end +``` + +That is, if the variable is linked / "unconstrained", we use the [`DynamicPPL.from_linked_internal_transform`](@ref), otherwise we use [`DynamicPPL.from_internal_transform`](@ref). + +And so the earlier diagram becomes: + +```@raw html +
+ +
+``` + +!!! note + + If the support of `dist` was constant, this would not be necessary since we could just determine the transformation at the time of `varinfo_linked = link(varinfo, model)` and define this as the `from_internal_transform` for all subsequent evaluations. However, since the support of `dist` is *not* constant in general, we need to be able to determine the transformation at the time of the evaluation *and* thus whether we should construct the transformation from the linked internal representation or the non-linked internal representation. This is annoying, but necessary. + +This is also the reason why we have two definitions of `getindex`: + + - [`getindex(::AbstractVarInfo, ::VarName, ::Distribution)`](@ref): used internally in model evaluations with the `dist` in a `x ~ dist` expression. + - [`getindex(::AbstractVarInfo, ::VarName)`](@ref): used externally by the user to get the realization of a variable. + +For `getindex` we have the following diagram: + +```@raw html +
+ +
+``` + +While if `dist` is not provided, we have: + +```@raw html +
+ +
+``` + +Notice that `dist` is not present here, but otherwise the diagrams are the same. + +!!! warning + + This does mean that the `getindex(varinfo, varname)` might not be the same as the `getindex(varinfo, varname, dist)` that occurs within a model evaluation! This can be confusing, but as outlined above, we do want to allow the `dist` in a `x ~ dist` expression to "override" whatever transformation `varinfo` might have. + +## Other functionalities + +There are also some additional methods for transforming between representations that are all automatically implemented from [`DynamicPPL.from_internal_transform`](@ref), [`DynamicPPL.from_linked_internal_transform`](@ref) and their siblings, and thus don't need to be implemented manually. + +Convenience methods for constructing transformations: + +```@docs +DynamicPPL.from_maybe_linked_internal_transform +DynamicPPL.to_maybe_linked_internal_transform +DynamicPPL.internal_to_linked_internal_transform +DynamicPPL.linked_internal_to_internal_transform +``` + +Convenience methods for transforming between representations without having to explicitly construct the transformation: + +```@docs +DynamicPPL.to_maybe_linked_internal +DynamicPPL.from_maybe_linked_internal +``` + +# Supporting a new distribution + +To support a new distribution, one needs to implement for the desired `AbstractVarInfo` the following methods: + + - [`DynamicPPL.from_internal_transform`](@ref) + - [`DynamicPPL.from_linked_internal_transform`](@ref) + +At the time of writing, [`VarInfo`](@ref) is the one that is most commonly used, whose internal representation is always a `Vector`. In this scenario, one can just implement the following methods instead: + +```@docs +DynamicPPL.from_vec_transform(::Distribution) +DynamicPPL.from_linked_vec_transform(::Distribution) +``` + +These are used internally by [`VarInfo`](@ref). 
+ +Optionally, if `inverse` of the above is expensive to compute, one can also implement: + + - [`DynamicPPL.to_internal_transform`](@ref) + - [`DynamicPPL.to_linked_internal_transform`](@ref) + +And similarly, there are corresponding to-methods for the `from_*_vec_transform` variants too + +```@docs +DynamicPPL.to_vec_transform +DynamicPPL.to_linked_vec_transform +``` + +!!! warning + + Whatever the resulting transformation is, it should be invertible, i.e. implement `InverseFunctions.inverse`, and have a well-defined log-abs-det Jacobian, i.e. implement `ChangesOfVariables.with_logabsdet_jacobian`. + +# TL;DR + + - DynamicPPL.jl has three representations of a variable: the **model representation**, the **internal representation**, and the **linked internal representation**. + + + The **model representation** is the representation of the variable as it appears in the model code / is expected by the `dist` on the right-hand-side of the `~` in the model code. + + The **internal representation** is the representation of the variable as it appears in the `varinfo`, which varies between implementations of [`AbstractVarInfo`](@ref), e.g. a `Vector` in [`VarInfo`](@ref). This can be converted to the model representation by [`DynamicPPL.from_internal_transform`](@ref). + + The **linked internal representation** is the representation of the variable as it appears in the `varinfo` after [`link`](@ref)ing. This can be converted to the model representation by [`DynamicPPL.from_linked_internal_transform`](@ref). + + - Having separation between *internal* and *linked internal* is necessary because transformations might be constructed at the time of model evaluation, and thus we need to know whether to construct the transformation from the internal representation or the linked internal representation. 
diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index fffaa5967..eb027b45b 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -80,11 +80,7 @@ export AbstractVarInfo, # Compiler @model, # Utilities - vectorize, - reconstruct, - reconstruct!, init, - vectorize, OrderedDict, # Model Model, diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 7d9273941..7ddd09b2e 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -144,9 +144,7 @@ Return an iterator over all `vns` in `vi`. Return the current value(s) of `vn` (`vns`) in `vi` in the support of its (their) distribution(s). -If `dist` is specified, the value(s) will be reshaped accordingly. - -See also: [`getindex_raw(vi::AbstractVarInfo, vn::VarName, dist::Distribution)`](@ref) +If `dist` is specified, the value(s) will be massaged into the representation expected by `dist`. """ Base.getindex """ @@ -164,22 +162,14 @@ Base.getindex(vi::AbstractVarInfo, ::Colon) = values_as(vi, Vector) Base.getindex(vi::AbstractVarInfo, ::AbstractSampler) = vi[:] """ - getindex_raw(vi::AbstractVarInfo, vn::VarName[, dist::Distribution]) - getindex_raw(vi::AbstractVarInfo, vns::Vector{<:VarName}[, dist::Distribution]) - -Return the current value(s) of `vn` (`vns`) in `vi`. + getindex_internal(vi::AbstractVarInfo, vn::VarName) + getindex_internal(vi::AbstractVarInfo, vns::Vector{<:VarName}) -If `dist` is specified, the value(s) will be reshaped accordingly. +Return the current value(s) of `vn` (`vns`) in `vi` as represented internally in `vi`. See also: [`getindex(vi::AbstractVarInfo, vn::VarName, dist::Distribution)`](@ref) - -!!! note - The difference between `getindex(vi, vn, dist)` and `getindex_raw` is that - `getindex` will also transform the value(s) to the support of the distribution(s). - This is _not_ the case for `getindex_raw`. 
- """ -function getindex_raw end +function getindex_internal end """ push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution) @@ -570,7 +560,7 @@ If `t` is not provided, `default_transformation(model, vi)` will be used. See also: [`default_transformation`](@ref), [`invlink`](@ref). """ -link(vi::AbstractVarInfo, model::Model) = link(deepcopy(vi), SampleFromPrior(), model) +link(vi::AbstractVarInfo, model::Model) = link(vi, SampleFromPrior(), model) function link(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) return link(t, deepcopy(vi), SampleFromPrior(), model) end @@ -753,110 +743,175 @@ function unflatten(sampler::AbstractSampler, varinfo::AbstractVarInfo, ::Abstrac return unflatten(varinfo, sampler, θ) end -# TODO: Clean up all this linking stuff once and for all! """ - with_logabsdet_jacobian_and_reconstruct([f, ]dist, x) + to_maybe_linked_internal(vi::AbstractVarInfo, vn::VarName, dist, val) -Like `Bijectors.with_logabsdet_jacobian(f, x)`, but also ensures the resulting -value is reconstructed to the correct type and shape according to `dist`. +Return reconstructed `val`, possibly linked if `istrans(vi, vn)` is `true`. """ -function with_logabsdet_jacobian_and_reconstruct(f, dist, x) - x_recon = reconstruct(f, dist, x) - return with_logabsdet_jacobian(f, x_recon) +function to_maybe_linked_internal(vi::AbstractVarInfo, vn::VarName, dist, val) + f = to_maybe_linked_internal_transform(vi, vn, dist) + return f(val) end -# NOTE: Necessary to handle product distributions of `Dirichlet` and similar. -function with_logabsdet_jacobian_and_reconstruct( - f::Bijectors.Inverse{<:Bijectors.SimplexBijector}, dist, y -) - (d, ns...) = size(dist) - yreshaped = reshape(y, d - 1, ns...) - x, logjac = with_logabsdet_jacobian(f, yreshaped) - return x, logjac +""" + from_maybe_linked_internal(vi::AbstractVarInfo, vn::VarName, dist, val) + +Return reconstructed `val`, possibly invlinked if `istrans(vi, vn)` is `true`. 
+""" +function from_maybe_linked_internal(vi::AbstractVarInfo, vn::VarName, dist, val) + f = from_maybe_linked_internal_transform(vi, vn, dist) + return f(val) end -# TODO: Once `(inv)link` isn't used heavily in `getindex(vi, vn)`, we can -# just use `first ∘ with_logabsdet_jacobian` to reduce the maintenance burden. -# NOTE: `reconstruct` is no-op if `val` is already of correct shape. """ - reconstruct_and_link(dist, val) - reconstruct_and_link(vi::AbstractVarInfo, vn::VarName, dist, val) + invlink_with_logpdf(varinfo::AbstractVarInfo, vn::VarName, dist[, x]) -Return linked `val` but reconstruct before linking, if necessary. +Invlink `x` and compute the logpdf under `dist` including correction from +the invlink-transformation. -Note that unlike [`invlink_and_reconstruct`](@ref), this does not necessarily -return a reconstructed value, i.e. a value of the same type and shape as expected -by `dist`. +If `x` is not provided, `getindex_internal(vi, vn)` will be used. -See also: [`invlink_and_reconstruct`](@ref), [`reconstruct`](@ref). +!!! warning + The input value `x` should be according to the internal representation of + `varinfo`, e.g. the value returned by `getindex_internal(vi, vn)`. """ -reconstruct_and_link(f, dist, val) = f(reconstruct(f, dist, val)) -reconstruct_and_link(dist, val) = reconstruct_and_link(link_transform(dist), dist, val) -function reconstruct_and_link(::AbstractVarInfo, ::VarName, dist, val) - return reconstruct_and_link(dist, val) +function invlink_with_logpdf(vi::AbstractVarInfo, vn::VarName, dist) + return invlink_with_logpdf(vi, vn, dist, getindex_internal(vi, vn)) +end +function invlink_with_logpdf(vi::AbstractVarInfo, vn::VarName, dist, y) + f = from_maybe_linked_internal_transform(vi, vn, dist) + x, logjac = with_logabsdet_jacobian(f, y) + return x, logpdf(dist, x) + logjac end +# Legacy code that is currently overloaded for the sake of simplicity. +# TODO: Remove when possible. 
+increment_num_produce!(::AbstractVarInfo) = nothing +setgid!(vi::AbstractVarInfo, gid::Selector, vn::VarName) = nothing + """ - invlink_and_reconstruct(dist, val) - invlink_and_reconstruct(vi::AbstractVarInfo, vn::VarName, dist, val) + from_internal_transform(varinfo::AbstractVarInfo, vn::VarName[, dist]) -Return invlinked and reconstructed `val`. +Return a transformation that transforms from the internal representation of `vn` with `dist` +in `varinfo` to a representation compatible with `dist`. -See also: [`reconstruct_and_link`](@ref), [`reconstruct`](@ref). +If `dist` is not present, then it is assumed that `varinfo` knows the correct output for `vn`. """ -invlink_and_reconstruct(f, dist, val) = f(reconstruct(f, dist, val)) -function invlink_and_reconstruct(dist, val) - return invlink_and_reconstruct(invlink_transform(dist), dist, val) -end -function invlink_and_reconstruct(::AbstractVarInfo, ::VarName, dist, val) - return invlink_and_reconstruct(dist, val) -end +function from_internal_transform end """ - maybe_link_and_reconstruct(vi::AbstractVarInfo, vn::VarName, dist, val) + from_linked_internal_transform(varinfo::AbstractVarInfo, vn::VarName[, dist]) -Return reconstructed `val`, possibly linked if `istrans(vi, vn)` is `true`. +Return a transformation that transforms from the linked internal representation of `vn` with `dist` +in `varinfo` to a representation compatible with `dist`. + +If `dist` is not present, then it is assumed that `varinfo` knows the correct output for `vn`. """ -function maybe_reconstruct_and_link(vi::AbstractVarInfo, vn::VarName, dist, val) - return if istrans(vi, vn) - reconstruct_and_link(vi, vn, dist, val) +function from_linked_internal_transform end + +""" + from_maybe_linked_internal_transform(varinfo::AbstractVarInfo, vn::VarName[, dist]) + +Return a transformation that transforms from the possibly linked internal representation of `vn` with `dist` +in `varinfo` to a representation compatible with `dist`.
+ +If `dist` is not present, then it is assumed that `varinfo` knows the correct output for `vn`. +""" +function from_maybe_linked_internal_transform(varinfo::AbstractVarInfo, vn::VarName, dist) + return if istrans(varinfo, vn) + from_linked_internal_transform(varinfo, vn, dist) else - reconstruct(dist, val) + from_internal_transform(varinfo, vn, dist) + end +end +function from_maybe_linked_internal_transform(varinfo::AbstractVarInfo, vn::VarName) + return if istrans(varinfo, vn) + from_linked_internal_transform(varinfo, vn) + else + from_internal_transform(varinfo, vn) end end """ - maybe_invlink_and_reconstruct(vi::AbstractVarInfo, vn::VarName, dist, val) + to_internal_transform(varinfo::AbstractVarInfo, vn::VarName[, dist]) -Return reconstructed `val`, possibly invlinked if `istrans(vi, vn)` is `true`. +Return a transformation that transforms from a representation compatible with `dist` to the +internal representation of `vn` with `dist` in `varinfo`. + +If `dist` is not present, then it is assumed that `varinfo` knows the correct output for `vn`. """ -function maybe_invlink_and_reconstruct(vi::AbstractVarInfo, vn::VarName, dist, val) - return if istrans(vi, vn) - invlink_and_reconstruct(vi, vn, dist, val) - else - reconstruct(dist, val) - end +function to_internal_transform(varinfo::AbstractVarInfo, vn::VarName, dist) + return inverse(from_internal_transform(varinfo, vn, dist)) +end +function to_internal_transform(varinfo::AbstractVarInfo, vn::VarName) + return inverse(from_internal_transform(varinfo, vn)) end """ - invlink_with_logpdf(vi::AbstractVarInfo, vn::VarName, dist[, x]) + to_linked_internal_transform(varinfo::AbstractVarInfo, vn::VarName[, dist]) -Invlink `x` and compute the logpdf under `dist` including correction from -the invlink-transformation. +Return a transformation that transforms from a representation compatible with `dist` to the +linked internal representation of `vn` with `dist` in `varinfo`. 
-If `x` is not provided, `getval(vi, vn)` will be used. +If `dist` is not present, then it is assumed that `varinfo` knows the correct output for `vn`. """ -function invlink_with_logpdf(vi::AbstractVarInfo, vn::VarName, dist) - return invlink_with_logpdf(vi, vn, dist, getval(vi, vn)) +function to_linked_internal_transform(varinfo::AbstractVarInfo, vn::VarName, dist) + return inverse(from_linked_internal_transform(varinfo, vn, dist)) end -function invlink_with_logpdf(vi::AbstractVarInfo, vn::VarName, dist, y) - # NOTE: Will this cause type-instabilities or will union-splitting save us? - f = istrans(vi, vn) ? invlink_transform(dist) : identity - x, logjac = with_logabsdet_jacobian_and_reconstruct(f, dist, y) - return x, logpdf(dist, x) + logjac +function to_linked_internal_transform(varinfo::AbstractVarInfo, vn::VarName) + return inverse(from_linked_internal_transform(varinfo, vn)) end -# Legacy code that is currently overloaded for the sake of simplicity. -# TODO: Remove when possible. -increment_num_produce!(::AbstractVarInfo) = nothing -setgid!(vi::AbstractVarInfo, gid::Selector, vn::VarName) = nothing +""" + to_maybe_linked_internal_transform(varinfo::AbstractVarInfo, vn::VarName[, dist]) + +Return a transformation that transforms from a representation compatible with `dist` to a +possibly linked internal representation of `vn` with `dist` in `varinfo`. + +If `dist` is not present, then it is assumed that `varinfo` knows the correct output for `vn`. 
+""" +function to_maybe_linked_internal_transform(varinfo::AbstractVarInfo, vn::VarName, dist) + return inverse(from_maybe_linked_internal_transform(varinfo, vn, dist)) +end +function to_maybe_linked_internal_transform(varinfo::AbstractVarInfo, vn::VarName) + return inverse(from_maybe_linked_internal_transform(varinfo, vn)) +end + +""" + internal_to_linked_internal_transform(varinfo::AbstractVarInfo, vn::VarName[, dist]) + +Return a transformation that transforms from the internal representation of `vn` with `dist` +in `varinfo` to a _linked_ internal representation of `vn` with `dist` in `varinfo`. + +If `dist` is not present, then it is assumed that `varinfo` knows the correct output for `vn`. +""" +function internal_to_linked_internal_transform(varinfo::AbstractVarInfo, vn::VarName, dist) + f_from_internal = from_internal_transform(varinfo, vn, dist) + f_to_linked_internal = to_linked_internal_transform(varinfo, vn, dist) + return f_to_linked_internal ∘ f_from_internal +end +function internal_to_linked_internal_transform(varinfo::AbstractVarInfo, vn::VarName) + f_from_internal = from_internal_transform(varinfo, vn) + f_to_linked_internal = to_linked_internal_transform(varinfo, vn) + return f_to_linked_internal ∘ f_from_internal +end + +""" + linked_internal_to_internal_transform(varinfo::AbstractVarInfo, vn::VarName[, dist]) + +Return a transformation that transforms from a _linked_ internal representation of `vn` with `dist` +in `varinfo` to the internal representation of `vn` with `dist` in `varinfo`. + +If `dist` is not present, then it is assumed that `varinfo` knows the correct output for `vn`.
+""" +function linked_internal_to_internal_transform(varinfo::AbstractVarInfo, vn::VarName, dist) + f_from_linked_internal = from_linked_internal_transform(varinfo, vn, dist) + f_to_internal = to_internal_transform(varinfo, vn, dist) + return f_to_internal ∘ f_from_linked_internal +end + +function linked_internal_to_internal_transform(varinfo::AbstractVarInfo, vn::VarName) + f_from_linked_internal = from_linked_internal_transform(varinfo, vn) + f_to_internal = to_internal_transform(varinfo, vn) + return f_to_internal ∘ f_from_linked_internal +end diff --git a/src/compiler.jl b/src/compiler.jl index 55adc534c..90220cbf5 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -647,8 +647,6 @@ function namedtuple_from_splitargs(splitargs) return :(NamedTuple{$names_expr}($vals)) end -is_splat_symbol(s::Symbol) = startswith(string(s), "#splat#") - """ build_output(modeldef, linenumbernode) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 9fcd2e310..13231837f 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -79,7 +79,7 @@ end function tilde_assume(context::PriorContext{<:NamedTuple}, right, vn, vi) if haskey(context.vars, getsym(vn)) - vi = setindex!!(vi, vectorize(right, get(context.vars, vn)), vn) + vi = setindex!!(vi, tovec(get(context.vars, vn)), vn) settrans!!(vi, false, vn) end return tilde_assume(PriorContext(), right, vn, vi) @@ -88,7 +88,7 @@ function tilde_assume( rng::Random.AbstractRNG, context::PriorContext{<:NamedTuple}, sampler, right, vn, vi ) if haskey(context.vars, getsym(vn)) - vi = setindex!!(vi, vectorize(right, get(context.vars, vn)), vn) + vi = setindex!!(vi, tovec(get(context.vars, vn)), vn) settrans!!(vi, false, vn) end return tilde_assume(rng, PriorContext(), sampler, right, vn, vi) @@ -96,7 +96,7 @@ end function tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, vn, vi) if haskey(context.vars, getsym(vn)) - vi = setindex!!(vi, vectorize(right, get(context.vars, vn)), 
vn) + vi = setindex!!(vi, tovec(get(context.vars, vn)), vn) settrans!!(vi, false, vn) end return tilde_assume(LikelihoodContext(), right, vn, vi) @@ -110,7 +110,7 @@ function tilde_assume( vi, ) if haskey(context.vars, getsym(vn)) - vi = setindex!!(vi, vectorize(right, get(context.vars, vn)), vn) + vi = setindex!!(vi, tovec(get(context.vars, vn)), vn) settrans!!(vi, false, vn) end return tilde_assume(rng, LikelihoodContext(), sampler, right, vn, vi) @@ -228,6 +228,7 @@ function assume(dist::Distribution, vn::VarName, vi) return r, logp, vi end +# TODO: Remove this thing. # SampleFromPrior and SampleFromUniform function assume( rng::Random.AbstractRNG, @@ -241,9 +242,8 @@ function assume( if sampler isa SampleFromUniform || is_flagged(vi, vn, "del") unset_flag!(vi, vn, "del") r = init(rng, dist, sampler) - BangBang.setindex!!( - vi, vectorize(dist, maybe_reconstruct_and_link(vi, vn, dist, r)), vn - ) + f = to_maybe_linked_internal_transform(vi, vn, dist) + BangBang.setindex!!(vi, f(r), vn) setorder!(vi, vn, get_num_produce(vi)) else # Otherwise we just extract it. @@ -252,7 +252,8 @@ function assume( else r = init(rng, dist, sampler) if istrans(vi) - push!!(vi, vn, reconstruct_and_link(dist, r), dist, sampler) + f = to_linked_internal_transform(vi, dist) + push!!(vi, vn, f(r), dist, sampler) # By default `push!!` sets the transformed flag to `false`. settrans!!(vi, true, vn) else @@ -491,6 +492,19 @@ function dot_assume(rng, spl::Sampler, ::Any, ::AbstractArray{<:VarName}, ::Any, ) end +# HACK: These methods are only used in the `get_and_set_val!` methods below. +# FIXME: Remove these. 
+function _link_broadcast_new(vi, vn, dist, r) + b = to_linked_internal_transform(vi, dist) + return b(r) +end + +function _maybe_invlink_broadcast(vi, vn, dist) + xvec = getindex_internal(vi, vn) + b = from_maybe_linked_internal_transform(vi, vn, dist) + return b(xvec) +end + function get_and_set_val!( rng, vi::VarInfoOrThreadSafeVarInfo, @@ -506,11 +520,8 @@ function get_and_set_val!( r = init(rng, dist, spl, n) for i in 1:n vn = vns[i] - setindex!!( - vi, - vectorize(dist, maybe_reconstruct_and_link(vi, vn, dist, r[:, i])), - vn, - ) + f_link_maybe = to_maybe_linked_internal_transform(vi, vn, dist) + setindex!!(vi, f_link_maybe(r[:, i]), vn) setorder!(vi, vn, get_num_produce(vi)) end else @@ -521,7 +532,8 @@ function get_and_set_val!( for i in 1:n vn = vns[i] if istrans(vi) - push!!(vi, vn, Bijectors.link(dist, r[:, i]), dist, spl) + ri_linked = _link_broadcast_new(vi, vn, dist, r[:, i]) + push!!(vi, vn, ri_linked, dist, spl) # `push!!` sets the trans-flag to `false` by default. settrans!!(vi, true, vn) else @@ -548,17 +560,13 @@ function get_and_set_val!( for i in eachindex(vns) vn = vns[i] dist = dists isa AbstractArray ? dists[i] : dists - setindex!!( - vi, vectorize(dist, maybe_reconstruct_and_link(vi, vn, dist, r[i])), vn - ) + f_link_maybe = to_maybe_linked_internal_transform(vi, vn, dist) + setindex!!(vi, f_link_maybe(r[i]), vn) setorder!(vi, vn, get_num_produce(vi)) end else - # r = reshape(vi[vec(vns)], size(vns)) - # FIXME: Remove `reconstruct` in `getindex_raw(::VarInfo, ...)` - # and fix the lines below. - r_raw = getindex_raw(vi, vec(vns)) - r = maybe_invlink_and_reconstruct.((vi,), vns, dists, reshape(r_raw, size(vns))) + rs = _maybe_invlink_broadcast.((vi,), vns, dists) + r = reshape(rs, size(vns)) end else f = (vn, dist) -> init(rng, dist, spl) @@ -569,10 +577,10 @@ function get_and_set_val!( # 2. Define an anonymous function which returns `nothing`, which # we then broadcast. This will allocate a vector of `nothing` though. 
if istrans(vi) - push!!.((vi,), vns, reconstruct_and_link.((vi,), vns, dists, r), dists, (spl,)) + push!!.((vi,), vns, _link_broadcast_new.((vi,), vns, dists, r), dists, (spl,)) # NOTE: Need to add the correction. # FIXME: This is not great. - acclogp_assume!!(vi, sum(logabsdetjac.(bijector.(dists), r))) + acclogp_assume!!(vi, sum(logabsdetjac.(link_transform.(dists), r))) # `push!!` sets the trans-flag to `false` by default. settrans!!.((vi,), true, vns) else @@ -602,8 +610,7 @@ function set_val!( ) @assert size(val) == size(vns) foreach(CartesianIndices(val)) do ind - dist = dists isa AbstractArray ? dists[ind] : dists - setindex!!(vi, vectorize(dist, val[ind]), vns[ind]) + setindex!!(vi, tovec(val[ind]), vns[ind]) end return val end diff --git a/src/extract_priors.jl b/src/extract_priors.jl index bb6721a9c..dd5aeeb04 100644 --- a/src/extract_priors.jl +++ b/src/extract_priors.jl @@ -110,9 +110,24 @@ julia> length(extract_priors(rng, model)[@varname(x)]) 9 ``` """ -extract_priors(model::Model) = extract_priors(Random.default_rng(), model) +extract_priors(args::Union{Model,AbstractVarInfo}...) = + extract_priors(Random.default_rng(), args...) function extract_priors(rng::Random.AbstractRNG, model::Model) context = PriorExtractorContext(SamplingContext(rng)) evaluate!!(model, VarInfo(), context) return context.priors end + +""" + extract_priors(model::Model, varinfo::AbstractVarInfo) + +Extract the priors from a model. + +This is done by evaluating the model at the values present in `varinfo` +and recording the distributions that are present at each tilde statement. 
+""" +function extract_priors(model::Model, varinfo::AbstractVarInfo) + context = PriorExtractorContext(DefaultContext()) + evaluate!!(model, deepcopy(varinfo), context) + return context.priors +end diff --git a/src/model.jl b/src/model.jl index a7c48017a..09c0c1be1 100644 --- a/src/model.jl +++ b/src/model.jl @@ -973,6 +973,8 @@ function _evaluate!!(model::Model, varinfo::AbstractVarInfo, context::AbstractCo return model.f(args...; kwargs...) end +is_splat_symbol(s::Symbol) = startswith(string(s), "#splat#") + """ make_evaluate_args_and_kwargs(model, varinfo, context) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index a6b907701..d8afb9cec 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -248,6 +248,16 @@ function SimpleVarInfo{T}( return SimpleVarInfo(values, convert(T, getlogp(vi))) end +function untyped_simple_varinfo(model::Model) + varinfo = SimpleVarInfo(OrderedDict()) + return last(evaluate!!(model, varinfo, SamplingContext())) +end + +function typed_simple_varinfo(model::Model) + varinfo = SimpleVarInfo{Float64}() + return last(evaluate!!(model, varinfo, SamplingContext())) +end + unflatten(svi::SimpleVarInfo, spl::AbstractSampler, x::AbstractVector) = unflatten(svi, x) function unflatten(svi::SimpleVarInfo, x::AbstractVector) logp = getlogp(svi) @@ -295,23 +305,17 @@ function Base.show(io::IO, ::MIME"text/plain", svi::SimpleVarInfo) return print(io, "SimpleVarInfo(", svi.values, ", ", svi.logp, ")") end -# `NamedTuple` function Base.getindex(vi::SimpleVarInfo, vn::VarName, dist::Distribution) - return maybe_invlink_and_reconstruct(vi, vn, dist, getindex(vi, vn)) + return from_maybe_linked_internal(vi, vn, dist, getindex(vi, vn)) end function Base.getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}, dist::Distribution) vals_linked = mapreduce(vcat, vns) do vn getindex(vi, vn, dist) end - return reconstruct(dist, vals_linked, length(vns)) + return recombine(dist, vals_linked, length(vns)) end -Base.getindex(vi::SimpleVarInfo, 
vn::VarName) = get(vi.values, vn) - -# `AbstractDict` -function Base.getindex(vi::SimpleVarInfo{<:AbstractDict}, vn::VarName) - return nested_getindex(vi.values, vn) -end +Base.getindex(vi::SimpleVarInfo, vn::VarName) = getindex_internal(vi, vn) # `SimpleVarInfo` doesn't necessarily vectorize, so we can have arrays other than # just `Vector`. @@ -323,22 +327,12 @@ Base.getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}) = map(Base.Fix1(getinde Base.getindex(svi::SimpleVarInfo, ::Colon) = values_as(svi, Vector) -# Since we don't perform any transformations in `getindex` for `SimpleVarInfo` -# we simply call `getindex` in `getindex_raw`. -getindex_raw(vi::SimpleVarInfo, vn::VarName) = vi[vn] -function getindex_raw(vi::SimpleVarInfo, vn::VarName, dist::Distribution) - return reconstruct(dist, getindex_raw(vi, vn)) -end -getindex_raw(vi::SimpleVarInfo, vns::Vector{<:VarName}) = vi[vns] -function getindex_raw(vi::SimpleVarInfo, vns::Vector{<:VarName}, dist::Distribution) - # `reconstruct` expects a flattened `Vector` regardless of the type of `dist`, so we `vcat` everything. - vals = mapreduce(Base.Fix1(getindex_raw, vi), vcat, vns) - return reconstruct(dist, vals, length(vns)) +getindex_internal(vi::SimpleVarInfo, vn::VarName) = get(vi.values, vn) +# `AbstractDict` +function getindex_internal(vi::SimpleVarInfo{<:AbstractDict}, vn::VarName) + return nested_getindex(vi.values, vn) end -# HACK: because `VarInfo` isn't ready to implement a proper `getindex_raw`. -getval(vi::SimpleVarInfo, vn::VarName) = getindex_raw(vi, vn) - Base.haskey(vi::SimpleVarInfo, vn::VarName) = hasvalue(vi.values, vn) function BangBang.setindex!!(vi::SimpleVarInfo, val, vn::VarName) @@ -484,7 +478,7 @@ function assume( ) value = init(rng, dist, sampler) # Transform if we're working in unconstrained space. 
- value_raw = maybe_reconstruct_and_link(vi, vn, dist, value) + value_raw = to_maybe_linked_internal(vi, vn, dist, value) vi = BangBang.push!!(vi, vn, value_raw, dist, sampler) return value, Bijectors.logpdf_with_trans(dist, value, istrans(vi, vn)), vi end @@ -502,9 +496,9 @@ function dot_assume( # Transform if we're working in transformed space. value_raw = if dists isa Distribution - maybe_reconstruct_and_link.((vi,), vns, (dists,), value) + to_maybe_linked_internal.((vi,), vns, (dists,), value) else - maybe_reconstruct_and_link.((vi,), vns, dists, value) + to_maybe_linked_internal.((vi,), vns, dists, value) end # Update `vi` @@ -531,7 +525,7 @@ function dot_assume( # Update `vi`. for (vn, val) in zip(vns, eachcol(value)) - val_linked = maybe_reconstruct_and_link(vi, vn, dist, val) + val_linked = to_maybe_linked_internal(vi, vn, dist, val) vi = BangBang.setindex!!(vi, val_linked, vn) end @@ -561,7 +555,7 @@ values_as(vi::SimpleVarInfo) = vi.values values_as(vi::SimpleVarInfo{<:T}, ::Type{T}) where {T} = vi.values function values_as(vi::SimpleVarInfo{<:Any,T}, ::Type{Vector}) where {T} isempty(vi) && return T[] - return mapreduce(vectorize, vcat, values(vi.values)) + return mapreduce(tovec, vcat, values(vi.values)) end function values_as(vi::SimpleVarInfo, ::Type{D}) where {D<:AbstractDict} return ConstructionBase.constructorof(D)(zip(keys(vi), values(vi.values))) @@ -697,6 +691,15 @@ function invlink!!( return settrans!!(vi_new, NoTransformation()) end +# With `SimpleVarInfo`, when we're not working with linked variables, there's no need to do anything. +from_internal_transform(vi::SimpleVarInfo, ::VarName) = identity +from_internal_transform(vi::SimpleVarInfo, ::VarName, dist) = identity +# TODO: Should the following methods specialize on the case where we have a `StaticTransformation{<:Bijectors.NamedTransform}`? 
+from_linked_internal_transform(vi::SimpleVarInfo, ::VarName) = identity +function from_linked_internal_transform(vi::SimpleVarInfo, ::VarName, dist) + return invlink_transform(dist) +end + # Threadsafe stuff. # For `SimpleVarInfo` we don't really need `Ref` so let's not use it. function ThreadSafeVarInfo(vi::SimpleVarInfo) diff --git a/src/test_utils.jl b/src/test_utils.jl index bf7be0a9a..6f7481c40 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -18,9 +18,9 @@ using DynamicPPL: varname_leaves, update_values!! Test that `vi[vn]` corresponds to the correct value in `vals` for every `vn` in `vns`. """ -function test_values(vi::AbstractVarInfo, vals::NamedTuple, vns; isequal=isequal, kwargs...) +function test_values(vi::AbstractVarInfo, vals::NamedTuple, vns; compare=isequal, kwargs...) for vn in vns - @test isequal(vi[vn], get(vals, vn); kwargs...) + @test compare(vi[vn], get(vals, vn); kwargs...) end end diff --git a/src/threadsafe.jl b/src/threadsafe.jl index cee369d8b..4fbf0d124 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -159,21 +159,6 @@ function getindex(vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}, dist::D end getindex(vi::ThreadSafeVarInfo, spl::AbstractSampler) = getindex(vi.varinfo, spl) -getindex_raw(vi::ThreadSafeVarInfo, ::Colon) = getindex_raw(vi.varinfo, Colon()) -getindex_raw(vi::ThreadSafeVarInfo, vn::VarName) = getindex_raw(vi.varinfo, vn) -function getindex_raw(vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}) - return getindex_raw(vi.varinfo, vns) -end -function getindex_raw(vi::ThreadSafeVarInfo, vn::VarName, dist::Distribution) - return getindex_raw(vi.varinfo, vn, dist) -end -function getindex_raw( - vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}, dist::Distribution -) - return getindex_raw(vi.varinfo, vns, dist) -end -getindex_raw(vi::ThreadSafeVarInfo, spl::AbstractSampler) = getindex_raw(vi.varinfo, spl) - function BangBang.setindex!!(vi::ThreadSafeVarInfo, val, spl::AbstractSampler) return 
Accessors.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, val, spl) end @@ -221,7 +206,7 @@ end istrans(vi::ThreadSafeVarInfo, vn::VarName) = istrans(vi.varinfo, vn) istrans(vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}) = istrans(vi.varinfo, vns) -getval(vi::ThreadSafeVarInfo, vn::VarName) = getval(vi.varinfo, vn) +getindex_internal(vi::ThreadSafeVarInfo, vn::VarName) = getindex_internal(vi.varinfo, vn) function unflatten(vi::ThreadSafeVarInfo, x::AbstractVector) return Accessors.@set vi.varinfo = unflatten(vi.varinfo, x) @@ -239,3 +224,21 @@ function Base.merge(varinfo_left::ThreadSafeVarInfo, varinfo_right::ThreadSafeVa varinfo_left.varinfo, varinfo_right.varinfo ) end + +function invlink_with_logpdf(vi::ThreadSafeVarInfo, vn::VarName, dist, y) + return invlink_with_logpdf(vi.varinfo, vn, dist, y) +end + +function from_internal_transform(varinfo::ThreadSafeVarInfo, vn::VarName) + return from_internal_transform(varinfo.varinfo, vn) +end +function from_internal_transform(varinfo::ThreadSafeVarInfo, vn::VarName, dist) + return from_internal_transform(varinfo.varinfo, vn, dist) +end + +function from_linked_internal_transform(varinfo::ThreadSafeVarInfo, vn::VarName) + return from_linked_internal_transform(varinfo.varinfo, vn) +end +function from_linked_internal_transform(varinfo::ThreadSafeVarInfo, vn::VarName, dist) + return from_linked_internal_transform(varinfo.varinfo, vn, dist) +end diff --git a/src/utils.jl b/src/utils.jl index 4bf652363..9ddeb6247 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -225,92 +225,127 @@ invlink_transform(dist) = inverse(link_transform(dist)) # Helper functions for vectorize/reconstruct values # ##################################################### -vectorize(d, r) = vectorize(r) -vectorize(r::Real) = [r] -vectorize(r::AbstractArray{<:Real}) = copy(vec(r)) -vectorize(r::Cholesky) = copy(vec(r.UL)) +# Useful transformation going from the flattened representation. 
+struct FromVec{Size} <: Bijectors.Bijector + size::Size +end + +FromVec(x::Union{Real,AbstractArray}) = FromVec(size(x)) -# NOTE: -# We cannot use reconstruct{T} because val is always Vector{Real} then T will be Real. -# However here we would like the result to be specifric type, e.g. Array{Dual{4,Float64}, 2}, -# otherwise we will have error for MatrixDistribution. -# Note this is not the case for MultivariateDistribution so I guess this might be lack of -# support for some types related to matrices (like PDMat). +# TODO: Should we materialize the `reshape`? +(f::FromVec)(x) = reshape(x, f.size) +(f::FromVec{Tuple{}})(x) = only(x) +# TODO: Specialize for `Tuple{<:Any}` since this correspond to a `Vector`. + +Bijectors.with_logabsdet_jacobian(f::FromVec, x) = (f(x), 0) +# We want to use the inverse of `FromVec` so it preserves the size information. +Bijectors.with_logabsdet_jacobian(::Bijectors.Inverse{<:FromVec}, x) = (tovec(x), 0) + +struct ToChol <: Bijectors.Bijector + uplo::Char +end + +Bijectors.with_logabsdet_jacobian(f::ToChol, x) = (Cholesky(Matrix(x), f.uplo, 0), 0) +Bijectors.with_logabsdet_jacobian(::Bijectors.Inverse{<:ToChol}, y::Cholesky) = (y.UL, 0) """ - reconstruct([f, ]dist, val) + from_vec_transform(x) -Reconstruct `val` so that it's compatible with `dist`. +Return the transformation from the vector representation of `x` to original representation. +""" +from_vec_transform(x::Union{Real,AbstractArray}) = from_vec_transform_for_size(size(x)) +from_vec_transform(C::Cholesky) = ToChol(C.uplo) ∘ FromVec(size(C.UL)) -If `f` is also provided, the reconstruct value will be -such that `f(reconstruct_val)` is compatible with `dist`. """ -reconstruct(f, dist, val) = reconstruct(dist, val) + from_vec_transform_for_size(sz::Tuple) -# No-op versions. 
-reconstruct(::UnivariateDistribution, val::Real) = val -reconstruct(::MultivariateDistribution, val::AbstractVector{<:Real}) = copy(val) -reconstruct(::MatrixDistribution, val::AbstractMatrix{<:Real}) = copy(val) -function reconstruct( - ::Distribution{ArrayLikeVariate{N}}, val::AbstractArray{<:Real,N} -) where {N} - return copy(val) -end -reconstruct(::Inverse{Bijectors.VecCorrBijector}, ::LKJ, val::AbstractVector) = copy(val) +Return the transformation from the vector representation of a realization of size `sz` to original representation. +""" +from_vec_transform_for_size(sz::Tuple) = FromVec(sz) +from_vec_transform_for_size(::Tuple{()}) = FromVec(()) +from_vec_transform_for_size(::Tuple{<:Any}) = identity -function reconstruct(dist::LKJCholesky, val::AbstractVector{<:Real}) - return reconstruct(dist, Matrix(reshape(val, size(dist)))) -end -function reconstruct(dist::LKJCholesky, val::AbstractMatrix{<:Real}) - return Cholesky(val, dist.uplo, 0) -end -reconstruct(::LKJCholesky, val::Cholesky) = val +""" + from_vec_transform(dist::Distribution) -function reconstruct( - ::Inverse{Bijectors.VecCholeskyBijector}, ::LKJCholesky, val::AbstractVector -) - return copy(val) -end +Return the transformation from the vector representation of a realization from +distribution `dist` to the original representation compatible with `dist`. +""" +from_vec_transform(dist::Distribution) = from_vec_transform_for_size(size(dist)) +from_vec_transform(dist::LKJCholesky) = ToChol(dist.uplo) ∘ FromVec(size(dist)) -function reconstruct( - ::Inverse{Bijectors.PDVecBijector}, ::MatrixDistribution, val::AbstractVector -) - return copy(val) -end +""" + from_vec_transform(f, size::Tuple) -# TODO: Implement no-op `reconstruct` for general array variates. +Return the transformation from the vector representation of a realization of size `size` to original representation. 
-reconstruct(d::Distribution, val::AbstractVector) = reconstruct(size(d), val) -reconstruct(::Tuple{}, val::AbstractVector) = val[1] -reconstruct(s::NTuple{1}, val::AbstractVector) = copy(val) -reconstruct(s::NTuple{2}, val::AbstractVector) = reshape(copy(val), s) -function reconstruct!(r, d::Distribution, val::AbstractVector) - return reconstruct!(r, d, val) -end -function reconstruct!(r, d::MultivariateDistribution, val::AbstractVector) - r .= val - return r +This is useful when the transformation alters the size of the realization, in which case we need to account for the +size of the realization after pushed through the transformation. +""" +from_vec_transform(f, sz) = from_vec_transform_for_size(Bijectors.output_size(f, sz)) + +""" + from_linked_vec_transform(dist::Distribution) + +Return the transformation from the unconstrained vector to the constrained +realization of distribution `dist`. + +By default, this is just `invlink_transform(dist) ∘ from_vec_transform(dist)`. + +See also: [`DynamicPPL.invlink_transform`](@ref), [`DynamicPPL.from_vec_transform`](@ref). +""" +function from_linked_vec_transform(dist::Distribution) + f_invlink = invlink_transform(dist) + f_vec = from_vec_transform(inverse(f_invlink), size(dist)) + return f_invlink ∘ f_vec end -function reconstruct(d::Distribution, val::AbstractVector, n::Int) - return reconstruct(size(d), val, n) + +# Specializations that circumvent the `from_vec_transform` machinery. +function from_linked_vec_transform(dist::LKJCholesky) + return inverse(Bijectors.VecCholeskyBijector(dist.uplo)) end -function reconstruct(::Tuple{}, val::AbstractVector, n::Int) +from_linked_vec_transform(::LKJ) = inverse(Bijectors.VecCorrBijector()) + +""" + to_vec_transform(x) + +Return the transformation from the original representation of `x` to the vector +representation. 
+""" +to_vec_transform(x) = inverse(from_vec_transform(x)) + +""" + to_linked_vec_transform(dist) + +Return the transformation from the constrained realization of distribution `dist` +to the unconstrained vector. +""" +to_linked_vec_transform(x) = inverse(from_linked_vec_transform(x)) + +# FIXME: When given a `LowerTriangular`, `VarInfo` still stores the full matrix +# flattened, while using `tovec` below flattenes only the necessary entries. +# => Need to either fix how `VarInfo` does things, i.e. use `tovec` everywhere, +# or fix `tovec` to flatten the full matrix instead of using `Bijectors.triu_to_vec`. +tovec(x::Real) = [x] +tovec(x::AbstractArray) = vec(x) +tovec(C::Cholesky) = tovec(Matrix(C.UL)) + +""" + recombine(dist::Union{UnivariateDistribution,MultivariateDistribution}, vals::AbstractVector, n::Int) + +Recombine `vals`, representing a batch of samples from `dist`, so that it's a compatible with `dist`. + +!!! warning + This only supports `UnivariateDistribution` and `MultivariateDistribution`, which are the only two + distribution types which are allowed on the right-hand side of a `.~` statement in a model. +""" +function recombine(::UnivariateDistribution, val::AbstractVector, ::Int) + # This is just a no-op, since we're trying to convert a vector into a vector. return copy(val) end -function reconstruct(s::NTuple{1}, val::AbstractVector, n::Int) - return copy(reshape(val, s[1], n)) -end -function reconstruct(s::NTuple{2}, val::AbstractVector, n::Int) - tmp = reshape(val, s..., n) - orig = [tmp[:, :, i] for i in 1:n] - return orig -end -function reconstruct!(r, d::Distribution, val::AbstractVector, n::Int) - return reconstruct!(r, d, val, n) -end -function reconstruct!(r, d::MultivariateDistribution, val::AbstractVector, n::Int) - r .= val - return r +function recombine(d::MultivariateDistribution, val::AbstractVector, n::Int) + # Here `val` is of the length `length(d) * n` and so we need to reshape it. 
+ return copy(reshape(val, length(d), n)) end # Uniform random numbers with range 4 for robust initializations @@ -360,8 +395,13 @@ end ####################### # Convenience methods # ####################### -collectmaybe(x) = x -collectmaybe(x::Base.AbstractSet) = collect(x) +""" + collect_maybe(x) + +Return `x` if `x` is an array, otherwise return `collect(x)`. +""" +collect_maybe(x) = collect(x) +collect_maybe(x::AbstractArray) = x ####################### # BangBang.jl related # diff --git a/src/varinfo.jl b/src/varinfo.jl index 903789325..2670397d9 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -113,21 +113,44 @@ const VarInfoOrThreadSafeVarInfo{Tmeta} = Union{ transformation(vi::VarInfo) = DynamicTransformation() function VarInfo(old_vi::VarInfo, spl, x::AbstractVector) - md = newmetadata(old_vi.metadata, Val(getspace(spl)), x) + md = replace_values(old_vi.metadata, Val(getspace(spl)), x) return VarInfo( md, Base.RefValue{eltype(x)}(getlogp(old_vi)), Ref(get_num_produce(old_vi)) ) end -function VarInfo( +""" + untyped_varinfo([rng, ]model[, sampler, context]) + +Return an untyped `VarInfo` instance for the model `model`. +""" +function untyped_varinfo( rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior(), context::AbstractContext=DefaultContext(), ) varinfo = VarInfo() - model(rng, varinfo, sampler, context) - return TypedVarInfo(varinfo) + return last(evaluate!!(model, varinfo, SamplingContext(rng, sampler, context))) +end +function untyped_varinfo(model::Model, args::Union{AbstractSampler,AbstractContext}...) + return untyped_varinfo(Random.default_rng(), model, args...) +end + +""" + typed_varinfo([rng, ]model[, sampler, context]) + +Return a typed `VarInfo` instance for the model `model`. +""" +typed_varinfo(args...) 
= TypedVarInfo(untyped_varinfo(args...)) + +function VarInfo( + rng::Random.AbstractRNG, + model::Model, + sampler::AbstractSampler=SampleFromPrior(), + context::AbstractContext=DefaultContext(), +) + return typed_varinfo(rng, model, sampler, context) end VarInfo(model::Model, args...) = VarInfo(Random.default_rng(), model, args...) @@ -142,7 +165,8 @@ function VarInfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext end # TODO: Remove `space` argument when no longer needed. Ref: https://github.com/TuringLang/DynamicPPL.jl/issues/573 -function newmetadata(metadata::Metadata, space, x) +replace_values(metadata::Metadata, space, x) = replace_values(metadata, x) +function replace_values(metadata::Metadata, x) return Metadata( metadata.idcs, metadata.vns, @@ -155,7 +179,7 @@ function newmetadata(metadata::Metadata, space, x) ) end -@generated function newmetadata( +@generated function replace_values( metadata::NamedTuple{names}, ::Val{space}, x ) where {names,space} exprs = [] @@ -164,21 +188,7 @@ end mdf = :(metadata.$f) if inspace(f, space) || length(space) == 0 len = :(sum(length, $mdf.ranges)) - push!( - exprs, - :( - $f = Metadata( - $mdf.idcs, - $mdf.vns, - $mdf.ranges, - x[($offset + 1):($offset + $len)], - $mdf.dists, - $mdf.gids, - $mdf.orders, - $mdf.flags, - ) - ), - ) + push!(exprs, :($f = replace_values($mdf, x[($offset + 1):($offset + $len)]))) offset = :($offset + $len) else push!(exprs, :($f = $mdf)) @@ -400,8 +410,8 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) push!(vns, vn) if vn in vns_left && vn in vns_right # `vals`: only valid if they're the length. 
- vals_left = getval(metadata_left, vn) - vals_right = getval(metadata_right, vn) + vals_left = getindex_internal(metadata_left, vn) + vals_right = getindex_internal(metadata_right, vn) @assert length(vals_left) == length(vals_right) append!(vals, vals_right) # `ranges` @@ -422,7 +432,7 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) elseif vn in vns_left # Just extract the metadata from `metadata_left`. # `vals` - vals_left = getval(metadata_left, vn) + vals_left = getindex_internal(metadata_left, vn) append!(vals, vals_left) # `ranges` r = (offset + 1):(offset + length(vals_left)) @@ -440,7 +450,7 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) else # Just extract the metadata from `metadata_right`. # `vals` - vals_right = getval(metadata_right, vn) + vals_right = getindex_internal(metadata_right, vn) append!(vals, vals_right) # `ranges` r = (offset + 1):(offset + length(vals_right)) @@ -463,25 +473,12 @@ end const VarView = Union{Int,UnitRange,Vector{Int}} -""" - getval(vi::UntypedVarInfo, vview::Union{Int, UnitRange, Vector{Int}}) - -Return a view `vi.vals[vview]`. -""" -getval(vi::UntypedVarInfo, vview::VarView) = view(vi.metadata.vals, vview) - """ setval!(vi::UntypedVarInfo, val, vview::Union{Int, UnitRange, Vector{Int}}) Set the value of `vi.vals[vview]` to `val`. """ setval!(vi::UntypedVarInfo, val, vview::VarView) = vi.metadata.vals[vview] = val -function setval!(vi::UntypedVarInfo, val, vview::Vector{UnitRange}) - if length(vview) > 0 - vi.metadata.vals[[i for arr in vview for i in arr]] = val - end - return val -end """ getmetadata(vi::VarInfo, vn::VarName) @@ -532,15 +529,16 @@ Return the distribution from which `vn` was sampled in `vi`. getdist(vi::VarInfo, vn::VarName) = getdist(getmetadata(vi, vn), vn) getdist(md::Metadata, vn::VarName) = md.dists[getidx(md, vn)] -""" - getval(vi::VarInfo, vn::VarName) - -Return the value(s) of `vn`. 
+getindex_internal(vi::VarInfo, vn::VarName) = getindex_internal(getmetadata(vi, vn), vn) +# TODO(torfjelde): Use `view` instead of `getindex`. Requires addressing type-stability issues though, +# since then we might be returning a `SubArray` rather than an `Array`, which is typically +# what a bijector would result in, even if the input is a view (`SubArray`). +# TODO(torfjelde): An alternative is to implement `view` directly instead. +getindex_internal(md::Metadata, vn::VarName) = getindex(md.vals, getrange(md, vn)) -The values may or may not be transformed to Euclidean space. -""" -getval(vi::VarInfo, vn::VarName) = getval(getmetadata(vi, vn), vn) -getval(md::Metadata, vn::VarName) = view(md.vals, getrange(md, vn)) +function getindex_internal(vi::VarInfo, vns::Vector{<:VarName}) + return mapreduce(Base.Fix1(getindex_internal, vi), vcat, vns) +end """ setval!(vi::VarInfo, val, vn::VarName) @@ -554,18 +552,9 @@ function setval!(md::Metadata, val::AbstractVector, vn::VarName) return md.vals[getrange(md, vn)] = val end function setval!(md::Metadata, val, vn::VarName) - return md.vals[getrange(md, vn)] = vectorize(getdist(md, vn), val) + return md.vals[getrange(md, vn)] = tovec(val) end -""" - getval(vi::VarInfo, vns::Vector{<:VarName}) - -Return the value(s) of `vns`. - -The values may or may not be transformed to Euclidean space. -""" -getval(vi::VarInfo, vns::Vector{<:VarName}) = mapreduce(Base.Fix1(getval, vi), vcat, vns) - """ getall(vi::VarInfo) @@ -573,12 +562,14 @@ Return the values of all the variables in `vi`. The values may or may not be transformed to Euclidean space. """ -getall(vi::UntypedVarInfo) = getall(vi.metadata) +getall(vi::VarInfo) = getall(vi.metadata) # NOTE: `mapreduce` over `NamedTuple` results in worse type-inference. # See for example https://github.com/JuliaLang/julia/pull/46381. 
getall(vi::TypedVarInfo) = reduce(vcat, map(getall, vi.metadata)) function getall(md::Metadata) - return mapreduce(Base.Fix1(getval, md), vcat, md.vns; init=similar(md.vals, 0)) + return mapreduce( + Base.Fix1(getindex_internal, md), vcat, md.vns; init=similar(md.vals, 0) + ) end """ @@ -588,12 +579,13 @@ Set the values of all the variables in `vi` to `val`. The values may or may not be transformed to Euclidean space. """ -function setall!(vi::UntypedVarInfo, val) - for r in vi.metadata.ranges - vi.metadata.vals[r] .= val[r] +setall!(vi::VarInfo, val) = _setall!(vi.metadata, val) + +function _setall!(metadata::Metadata, val) + for r in metadata.ranges + metadata.vals[r] .= val[r] end end -setall!(vi::TypedVarInfo, val) = _setall!(vi.metadata, val) @generated function _setall!(metadata::NamedTuple{names}, val) where {names} expr = Expr(:block) start = :(1) @@ -614,13 +606,17 @@ Return the set of sampler selectors associated with `vn` in `vi`. getgid(vi::VarInfo, vn::VarName) = getmetadata(vi, vn).gids[getidx(vi, vn)] function settrans!!(vi::VarInfo, trans::Bool, vn::VarName) + settrans!!(getmetadata(vi, vn), trans, vn) + return vi +end +function settrans!!(metadata::Metadata, trans::Bool, vn::VarName) if trans - set_flag!(vi, vn, "trans") + set_flag!(metadata, vn, "trans") else - unset_flag!(vi, vn, "trans") + unset_flag!(metadata, vn, "trans") end - return vi + return metadata end function settrans!!(vi::VarInfo, trans::Bool) @@ -754,7 +750,7 @@ end @inline function _getranges(vi::VarInfo, s::Selector, space) return _getranges(vi, _getidcs(vi, s, space)) end -@inline function _getranges(vi::UntypedVarInfo, idcs::Vector{Int}) +@inline function _getranges(vi::VarInfo, idcs::Vector{Int}) return mapreduce(i -> vi.metadata.ranges[i], vcat, idcs; init=Int[]) end @inline _getranges(vi::TypedVarInfo, idcs::NamedTuple) = _getranges(vi.metadata, idcs) @@ -784,7 +780,11 @@ end Set `vn`'s value for `flag` to `true` in `vi`. 
""" function set_flag!(vi::VarInfo, vn::VarName, flag::String) - return getmetadata(vi, vn).flags[flag][getidx(vi, vn)] = true + set_flag!(getmetadata(vi, vn), vn, flag) + return vi +end +function set_flag!(md::Metadata, vn::VarName, flag::String) + return md.flags[flag][getidx(md, vn)] = true end #### @@ -866,7 +866,8 @@ function BangBang.empty!!(vi::VarInfo) reset_num_produce!(vi) return vi end -@inline _empty!(metadata::Metadata) = empty!(metadata) + +_empty!(metadata) = empty!(metadata) @generated function _empty!(metadata::NamedTuple{names}) where {names} expr = Expr(:block) for f in names @@ -875,8 +876,9 @@ end return expr end -# Functions defined only for UntypedVarInfo -Base.keys(vi::UntypedVarInfo) = keys(vi.metadata.idcs) +# `keys` +Base.keys(md::Metadata) = md.vns +Base.keys(vi::VarInfo) = keys(vi.metadata) # HACK: Necessary to avoid returning `Any[]` which won't dispatch correctly # on other methods in the codebase which requires `Vector{<:VarName}`. @@ -886,7 +888,7 @@ Base.keys(vi::TypedVarInfo{<:NamedTuple{()}}) = VarName[] push!(expr.args, :vcat) for n in names - push!(expr.args, :(vi.metadata.$n.vns)) + push!(expr.args, :(keys(vi.metadata.$n))) end return expr @@ -907,7 +909,8 @@ function setgid!(vi::VarInfo, gid::Selector, vn::VarName) return push!(getmetadata(vi, vn).gids[getidx(vi, vn)], gid) end -istrans(vi::VarInfo, vn::VarName) = is_flagged(vi, vn, "trans") +istrans(vi::VarInfo, vn::VarName) = istrans(getmetadata(vi, vn), vn) +istrans(md::Metadata, vn::VarName) = is_flagged(md, vn, "trans") getlogp(vi::VarInfo) = vi.logp[] @@ -1005,7 +1008,9 @@ function _link!(vi::UntypedVarInfo, spl::AbstractSampler) if ~istrans(vi, vns[1]) for vn in vns dist = getdist(vi, vn) - _inner_transform!(vi, vn, dist, link_transform(dist)) + _inner_transform!( + vi, vn, dist, internal_to_linked_internal_transform(vi, vn, dist) + ) settrans!!(vi, true, vn) end else @@ -1033,7 +1038,12 @@ end # Iterate over all `f_vns` and transform for vn in f_vns dist = getdist(vi, 
vn) - _inner_transform!(vi, vn, dist, link_transform(dist)) + _inner_transform!( + vi, + vn, + dist, + internal_to_linked_internal_transform(vi, vn, dist), + ) settrans!!(vi, true, vn) end else @@ -1047,7 +1057,9 @@ end end # R -> X for all variables associated with given sampler -function invlink!!(::DynamicTransformation, vi::VarInfo, spl::AbstractSampler, model::Model) +function invlink!!( + t::DynamicTransformation, vi::VarInfo, spl::AbstractSampler, model::Model +) # Call `_invlink!` instead of `invlink!` to avoid deprecation warning. _invlink!(vi, spl) return vi @@ -1100,7 +1112,9 @@ function _invlink!(vi::UntypedVarInfo, spl::AbstractSampler) if istrans(vi, vns[1]) for vn in vns dist = getdist(vi, vn) - _inner_transform!(vi, vn, dist, invlink_transform(dist)) + _inner_transform!( + vi, vn, dist, linked_internal_to_internal_transform(vi, vn, dist) + ) settrans!!(vi, false, vn) end else @@ -1128,7 +1142,12 @@ end # Iterate over all `f_vns` and transform for vn in f_vns dist = getdist(vi, vn) - _inner_transform!(vi, vn, dist, invlink_transform(dist)) + _inner_transform!( + vi, + vn, + dist, + linked_internal_to_internal_transform(vi, vn, dist), + ) settrans!!(vi, false, vn) end else @@ -1142,9 +1161,12 @@ end end function _inner_transform!(vi::VarInfo, vn::VarName, dist, f) + return _inner_transform!(getmetadata(vi, vn), vi, vn, dist, f) +end + +function _inner_transform!(md::Metadata, vi::VarInfo, vn::VarName, dist, f) # TODO: Use inplace versions to avoid allocations - y, logjac = with_logabsdet_jacobian_and_reconstruct(f, dist, getval(vi, vn)) - yvec = vectorize(dist, y) + yvec, logjac = with_logabsdet_jacobian(f, getindex_internal(vi, vn)) # Determine the new range. start = first(getrange(vi, vn)) # NOTE: `length(yvec)` should never be longer than `getrange(vi, vn)`. @@ -1160,14 +1182,15 @@ end # an empty iterable for `SampleFromPrior`, so we need to override it here. # This is quite hacky, but seems safer than changing the behavior of `_getvns`. 
_getvns_link(varinfo::VarInfo, spl::AbstractSampler) = _getvns(varinfo, spl) -_getvns_link(varinfo::UntypedVarInfo, spl::SampleFromPrior) = nothing +_getvns_link(varinfo::VarInfo, spl::SampleFromPrior) = nothing function _getvns_link(varinfo::TypedVarInfo, spl::SampleFromPrior) return map(Returns(nothing), varinfo.metadata) end function link(::DynamicTransformation, varinfo::VarInfo, spl::AbstractSampler, model::Model) - return _link(varinfo, spl) + return _link(model, varinfo, spl) end + function link( ::DynamicTransformation, varinfo::ThreadSafeVarInfo{<:VarInfo}, @@ -1179,30 +1202,34 @@ function link( return Accessors.@set varinfo.varinfo = link(varinfo.varinfo, spl, model) end -function _link(varinfo::UntypedVarInfo, spl::AbstractSampler) +function _link(model::Model, varinfo::UntypedVarInfo, spl::AbstractSampler) varinfo = deepcopy(varinfo) return VarInfo( - _link_metadata!(varinfo, varinfo.metadata, _getvns_link(varinfo, spl)), + _link_metadata!(model, varinfo, varinfo.metadata, _getvns_link(varinfo, spl)), Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo)), ) end -function _link(varinfo::TypedVarInfo, spl::AbstractSampler) +function _link(model::Model, varinfo::TypedVarInfo, spl::AbstractSampler) varinfo = deepcopy(varinfo) md = _link_metadata_namedtuple!( - varinfo, varinfo.metadata, _getvns_link(varinfo, spl), Val(getspace(spl)) + model, varinfo, varinfo.metadata, _getvns_link(varinfo, spl), Val(getspace(spl)) ) return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) end @generated function _link_metadata_namedtuple!( - varinfo::VarInfo, metadata::NamedTuple{names}, vns::NamedTuple, ::Val{space} + model::Model, + varinfo::VarInfo, + metadata::NamedTuple{names}, + vns::NamedTuple, + ::Val{space}, ) where {names,space} vals = Expr(:tuple) for f in names if inspace(f, space) || length(space) == 0 - push!(vals.args, :(_link_metadata!(varinfo, metadata.$f, vns.$f))) + push!(vals.args, :(_link_metadata!(model, varinfo, metadata.$f, 
vns.$f))) else push!(vals.args, :(metadata.$f)) end @@ -1210,7 +1237,7 @@ end return :(NamedTuple{$names}($vals)) end -function _link_metadata!(varinfo::VarInfo, metadata::Metadata, target_vns) +function _link_metadata!(model::Model, varinfo::VarInfo, metadata::Metadata, target_vns) vns = metadata.vns # Construct the new transformed values, and keep track of their lengths. @@ -1222,12 +1249,12 @@ function _link_metadata!(varinfo::VarInfo, metadata::Metadata, target_vns) end # Transform to constrained space. - x = getval(varinfo, vn) - dist = getdist(varinfo, vn) - f = link_transform(dist) - y, logjac = with_logabsdet_jacobian_and_reconstruct(f, dist, x) + x = getindex_internal(metadata, vn) + dist = getdist(metadata, vn) + f = internal_to_linked_internal_transform(varinfo, vn, dist) + y, logjac = with_logabsdet_jacobian(f, x) # Vectorize value. - yvec = vectorize(dist, y) + yvec = tovec(y) # Accumulate the log-abs-det jacobian correction. acclogp!!(varinfo, -logjac) # Mark as no longer transformed. 
@@ -1261,7 +1288,7 @@ end function invlink( ::DynamicTransformation, varinfo::VarInfo, spl::AbstractSampler, model::Model ) - return _invlink(varinfo, spl) + return _invlink(model, varinfo, spl) end function invlink( ::DynamicTransformation, @@ -1274,30 +1301,34 @@ function invlink( return Accessors.@set varinfo.varinfo = invlink(varinfo.varinfo, spl, model) end -function _invlink(varinfo::UntypedVarInfo, spl::AbstractSampler) +function _invlink(model::Model, varinfo::VarInfo, spl::AbstractSampler) varinfo = deepcopy(varinfo) return VarInfo( - _invlink_metadata!(varinfo, varinfo.metadata, _getvns_link(varinfo, spl)), + _invlink_metadata!(model, varinfo, varinfo.metadata, _getvns_link(varinfo, spl)), Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo)), ) end -function _invlink(varinfo::TypedVarInfo, spl::AbstractSampler) +function _invlink(model::Model, varinfo::TypedVarInfo, spl::AbstractSampler) varinfo = deepcopy(varinfo) md = _invlink_metadata_namedtuple!( - varinfo, varinfo.metadata, _getvns_link(varinfo, spl), Val(getspace(spl)) + model, varinfo, varinfo.metadata, _getvns_link(varinfo, spl), Val(getspace(spl)) ) return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) end @generated function _invlink_metadata_namedtuple!( - varinfo::VarInfo, metadata::NamedTuple{names}, vns::NamedTuple, ::Val{space} + model::Model, + varinfo::VarInfo, + metadata::NamedTuple{names}, + vns::NamedTuple, + ::Val{space}, ) where {names,space} vals = Expr(:tuple) for f in names if inspace(f, space) || length(space) == 0 - push!(vals.args, :(_invlink_metadata!(varinfo, metadata.$f, vns.$f))) + push!(vals.args, :(_invlink_metadata!(model, varinfo, metadata.$f, vns.$f))) else push!(vals.args, :(metadata.$f)) end @@ -1305,7 +1336,7 @@ end return :(NamedTuple{$names}($vals)) end -function _invlink_metadata!(varinfo::VarInfo, metadata::Metadata, target_vns) +function _invlink_metadata!(::Model, varinfo::VarInfo, metadata::Metadata, target_vns) vns = metadata.vns 
# Construct the new transformed values, and keep track of their lengths. @@ -1318,12 +1349,12 @@ function _invlink_metadata!(varinfo::VarInfo, metadata::Metadata, target_vns) end # Transform to constrained space. - y = getval(varinfo, vn) + y = getindex_internal(varinfo, vn) dist = getdist(varinfo, vn) - f = invlink_transform(dist) - x, logjac = with_logabsdet_jacobian_and_reconstruct(f, dist, y) + f = from_linked_internal_transform(varinfo, vn, dist) + x, logjac = with_logabsdet_jacobian(f, y) # Vectorize value. - xvec = vectorize(dist, x) + xvec = tovec(x) # Accumulate the log-abs-det jacobian correction. acclogp!!(varinfo, -logjac) # Mark as no longer transformed. @@ -1420,34 +1451,24 @@ end getindex(vi::VarInfo, vn::VarName) = getindex(vi, vn, getdist(vi, vn)) function getindex(vi::VarInfo, vn::VarName, dist::Distribution) @assert haskey(vi, vn) "[DynamicPPL] attempted to replay unexisting variables in VarInfo" - val = getval(vi, vn) - return maybe_invlink_and_reconstruct(vi, vn, dist, val) + val = getindex_internal(vi, vn) + return from_maybe_linked_internal(vi, vn, dist, val) end + function getindex(vi::VarInfo, vns::Vector{<:VarName}) - # FIXME(torfjelde): Using `getdist(vi, first(vns))` won't be correct in cases - # such as `x .~ [Normal(), Exponential()]`. - # BUT we also can't fix this here because this will lead to "incorrect" - # behavior if `vns` arose from something like `x .~ MvNormal(zeros(2), I)`, - # where by "incorrect" we mean there exists pieces of code expecting this behavior. - return getindex(vi, vns, getdist(vi, first(vns))) + vals_linked = mapreduce(vcat, vns) do vn + getindex(vi, vn) + end + # HACK: I don't like this. 
+ dist = getdist(vi, vns[1]) + return recombine(dist, vals_linked, length(vns)) end function getindex(vi::VarInfo, vns::Vector{<:VarName}, dist::Distribution) @assert haskey(vi, vns[1]) "[DynamicPPL] attempted to replay unexisting variables in VarInfo" vals_linked = mapreduce(vcat, vns) do vn getindex(vi, vn, dist) end - return reconstruct(dist, vals_linked, length(vns)) -end - -getindex_raw(vi::VarInfo, vn::VarName) = getindex_raw(vi, vn, getdist(vi, vn)) -function getindex_raw(vi::VarInfo, vn::VarName, dist::Distribution) - return reconstruct(dist, getval(vi, vn)) -end -function getindex_raw(vi::VarInfo, vns::Vector{<:VarName}) - return getindex_raw(vi, vns, getdist(vi, first(vns))) -end -function getindex_raw(vi::VarInfo, vns::Vector{<:VarName}, dist::Distribution) - return reconstruct(dist, getval(vi, vns), length(vns)) + return recombine(dist, vals_linked, length(vns)) end """ @@ -1457,7 +1478,7 @@ Return the current value(s) of the random variables sampled by `spl` in `vi`. The value(s) may or may not be transformed to Euclidean space. """ -getindex(vi::UntypedVarInfo, spl::Sampler) = copy(getval(vi, _getranges(vi, spl))) +getindex(vi::VarInfo, spl::Sampler) = copy(getindex_internal(vi, _getranges(vi, spl))) function getindex(vi::TypedVarInfo, spl::Sampler) # Gets the ranges as a NamedTuple ranges = _getranges(vi, spl) @@ -1530,16 +1551,19 @@ end return map(vn -> vi[vn], f_vns) end +haskey(metadata::Metadata, vn::VarName) = haskey(metadata.idcs, vn) + """ haskey(vi::VarInfo, vn::VarName) Check whether `vn` has been sampled in `vi`. 
""" -haskey(vi::VarInfo, vn::VarName) = haskey(getmetadata(vi, vn).idcs, vn) +haskey(vi::VarInfo, vn::VarName) = haskey(getmetadata(vi, vn), vn) function haskey(vi::TypedVarInfo, vn::VarName) - metadata = vi.metadata - Tmeta = typeof(metadata) - return getsym(vn) in fieldnames(Tmeta) && haskey(getmetadata(vi, vn).idcs, vn) + md_haskey = map(vi.metadata) do metadata + haskey(metadata, vn) + end + return any(md_haskey) end function Base.show(io::IO, ::MIME"text/plain", vi::UntypedVarInfo) @@ -1564,7 +1588,7 @@ const _MAX_VARS_SHOWN = 4 function _show_varnames(io::IO, vi) md = vi.metadata - vns = md.vns + vns = keys(md) vns_by_name = Dict{Symbol,Vector{VarName}}() for vn in vns @@ -1599,9 +1623,14 @@ function BangBang.push!!( @assert ~(haskey(vi, vn)) "[push!!] attempt to add an exisitng variable $(getsym(vn)) ($(vn)) to TypedVarInfo of syms $(syms(vi)) with dist=$dist, gid=$gidset" end - val = vectorize(dist, r) - meta = getmetadata(vi, vn) + push!(meta, vn, r, dist, gidset, get_num_produce(vi)) + + return vi +end + +function Base.push!(meta::Metadata, vn, r, dist, gidset, num_produce) + val = tovec(r) meta.idcs[vn] = length(meta.idcs) + 1 push!(meta.vns, vn) l = length(meta.vals) @@ -1610,11 +1639,11 @@ function BangBang.push!!( append!(meta.vals, val) push!(meta.dists, dist) push!(meta.gids, gidset) - push!(meta.orders, get_num_produce(vi)) + push!(meta.orders, num_produce) push!(meta.flags["del"], false) push!(meta.flags["trans"], false) - return vi + return meta end """ @@ -1624,12 +1653,13 @@ Set the `order` of `vn` in `vi` to `index`, where `order` is the number of `obse statements run before sampling `vn`. 
""" function setorder!(vi::VarInfo, vn::VarName, index::Int) - metadata = getmetadata(vi, vn) - if metadata.orders[metadata.idcs[vn]] != index - metadata.orders[metadata.idcs[vn]] = index - end + setorder!(getmetadata(vi, vn), vn, index) return vi end +function setorder!(metadata::Metadata, vn::VarName, index::Int) + metadata.orders[metadata.idcs[vn]] = index + return metadata +end """ getorder(vi::VarInfo, vn::VarName) @@ -1662,9 +1692,13 @@ end Set `vn`'s value for `flag` to `false` in `vi`. """ function unset_flag!(vi::VarInfo, vn::VarName, flag::String) - getmetadata(vi, vn).flags[flag][getidx(vi, vn)] = false + unset_flag!(getmetadata(vi, vn), vn, flag) return vi end +function unset_flag!(metadata::Metadata, vn::VarName, flag::String) + metadata.flags[flag][getidx(metadata, vn)] = false + return metadata +end """ set_retained_vns_del_by_spl!(vi::VarInfo, spl::Sampler) @@ -1740,7 +1774,7 @@ end Calls `kernel!(vi, vn, values, keys)` for every `vn` in `vi`. """ function _apply!(kernel!, vi::VarInfoOrThreadSafeVarInfo, values, keys) - keys_strings = map(string, collectmaybe(keys)) + keys_strings = map(string, collect_maybe(keys)) num_indices_seen = 0 for vn in Base.keys(vi) @@ -1762,7 +1796,7 @@ function _apply!(kernel!, vi::VarInfoOrThreadSafeVarInfo, values, keys) end function _apply!(kernel!, vi::TypedVarInfo, values, keys) - return _typed_apply!(kernel!, vi, vi.metadata, values, collectmaybe(keys)) + return _typed_apply!(kernel!, vi, vi.metadata, values, collect_maybe(keys)) end @generated function _typed_apply!( @@ -1798,7 +1832,7 @@ end end function _find_missing_keys(vi::VarInfoOrThreadSafeVarInfo, keys) - string_vns = map(string, collectmaybe(Base.keys(vi))) + string_vns = map(string, collect_maybe(Base.keys(vi))) # If `key` isn't subsumed by any element of `string_vns`, it is not present in `vi`. 
missing_keys = filter(keys) do key !any(Base.Fix2(subsumes_string, key), string_vns) @@ -2011,7 +2045,39 @@ end function values_from_metadata(md::Metadata) return ( - vn => reconstruct(md.dists[md.idcs[vn]], md.vals[md.ranges[md.idcs[vn]]]) for - vn in md.vns + # `copy` to avoid accidentally mutation of internal representation. + vn => copy( + from_internal_transform(md, vn, getdist(md, vn))(getindex_internal(md, vn)) + ) for vn in md.vns ) end + +# Transforming from internal representation to distribution representation. +# Without `dist` argument: base on `dist` extracted from self. +function from_internal_transform(vi::VarInfo, vn::VarName) + return from_internal_transform(getmetadata(vi, vn), vn) +end +function from_internal_transform(md::Metadata, vn::VarName) + return from_internal_transform(md, vn, getdist(md, vn)) +end +# With both `vn` and `dist` arguments: base on provided `dist`. +function from_internal_transform(vi::VarInfo, vn::VarName, dist) + return from_internal_transform(getmetadata(vi, vn), vn, dist) +end +from_internal_transform(::Metadata, ::VarName, dist) = from_vec_transform(dist) + +# Without `dist` argument: base on `dist` extracted from self. +function from_linked_internal_transform(vi::VarInfo, vn::VarName) + return from_linked_internal_transform(getmetadata(vi, vn), vn) +end +function from_linked_internal_transform(md::Metadata, vn::VarName) + return from_linked_internal_transform(md, vn, getdist(md, vn)) +end +# With both `vn` and `dist` arguments: base on provided `dist`. +function from_linked_internal_transform(vi::VarInfo, vn::VarName, dist) + # Dispatch to metadata in case this alters the behavior. 
+ return from_linked_internal_transform(getmetadata(vi, vn), vn, dist) +end +function from_linked_internal_transform(::Metadata, ::VarName, dist) + return from_linked_vec_transform(dist) +end diff --git a/test/linking.jl b/test/linking.jl index 06f6fb6d6..d424a9c2d 100644 --- a/test/linking.jl +++ b/test/linking.jl @@ -44,7 +44,9 @@ function Distributions._logpdf(::MyMatrixDistribution, x::AbstractMatrix{<:Real} end # Skip reconstruction in the inverse-map since it's no longer needed. -DynamicPPL.reconstruct(::TrilFromVec, ::MyMatrixDistribution, x::AbstractVector{<:Real}) = x +function DynamicPPL.from_linked_vec_transform(dist::MyMatrixDistribution) + return TrilFromVec((dist.dim, dist.dim)) +end # Specify the link-transform to use. Bijectors.bijector(dist::MyMatrixDistribution) = TrilToVec((dist.dim, dist.dim)) diff --git a/test/model.jl b/test/model.jl index c8fdf0202..60a8d2461 100644 --- a/test/model.jl +++ b/test/model.jl @@ -353,28 +353,32 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true end end - @testset "Type stability of models" begin - models_to_test = [ - DynamicPPL.TestUtils.DEMO_MODELS..., DynamicPPL.TestUtils.demo_lkjchol(2) - ] - @testset "$(model.f)" for model in models_to_test - vns = DynamicPPL.TestUtils.varnames(model) - example_values = DynamicPPL.TestUtils.rand_prior_true(model) - varinfos = filter( - is_typed_varinfo, - DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns), - ) - @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos - @test (@inferred(DynamicPPL.evaluate!!(model, varinfo, DefaultContext())); - true) - - varinfo_linked = DynamicPPL.link(varinfo, model) - @test ( - @inferred( - DynamicPPL.evaluate!!(model, varinfo_linked, DefaultContext()) - ); - true + if VERSION >= v"1.8" + @testset "Type stability of models" begin + models_to_test = [ + DynamicPPL.TestUtils.DEMO_MODELS..., DynamicPPL.TestUtils.demo_lkjchol(2) + ] + @testset "$(model.f)" for model in models_to_test + vns = 
DynamicPPL.TestUtils.varnames(model) + example_values = DynamicPPL.TestUtils.rand_prior_true(model) + varinfos = filter( + is_typed_varinfo, + DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns), ) + @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos + @test ( + @inferred(DynamicPPL.evaluate!!(model, varinfo, DefaultContext())); + true + ) + + varinfo_linked = DynamicPPL.link(varinfo, model) + @test ( + @inferred( + DynamicPPL.evaluate!!(model, varinfo_linked, DefaultContext()) + ); + true + ) + end end end end @@ -410,4 +414,25 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true # confused and call it the way you are meant to call `a_model`. @test_throws MethodError instance(1.0) end + + @testset "Product distribution with changing support" begin + @model function product_dirichlet() + return x ~ product_distribution(fill(Dirichlet(ones(4)), 2, 3)) + end + model = product_dirichlet() + + varinfos = [ + DynamicPPL.untyped_varinfo(model), + DynamicPPL.typed_varinfo(model), + DynamicPPL.typed_simple_varinfo(model), + DynamicPPL.untyped_simple_varinfo(model), + ] + @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos + varinfo_linked = DynamicPPL.link(varinfo, model) + varinfo_linked_result = last( + DynamicPPL.evaluate!!(model, deepcopy(varinfo_linked), DefaultContext()) + ) + @test getlogp(varinfo_linked) ≈ getlogp(varinfo_linked_result) + end + end end diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 869fb82b3..5ce112941 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -100,7 +100,8 @@ # Should result in same values. 
@test all( - DynamicPPL.getindex_raw(vi_invlinked, vn) ≈ get(values_constrained, vn) for + DynamicPPL.tovec(DynamicPPL.getindex_internal(vi_invlinked, vn)) ≈ + DynamicPPL.tovec(get(values_constrained, vn)) for vn in DynamicPPL.TestUtils.varnames(model) ) end @@ -251,8 +252,11 @@ model, deepcopy(vi_linked), DefaultContext() ) - @test DynamicPPL.getindex_raw(vi_linked, @varname(s)) ≠ retval.s # `s` is unconstrained in original - @test DynamicPPL.getindex_raw(vi_linked_result, @varname(s)) == retval.s # `s` is constrained in result + @test DynamicPPL.tovec(DynamicPPL.getindex_internal(vi_linked, @varname(s))) ≠ + DynamicPPL.tovec(retval.s) # `s` is unconstrained in original + @test DynamicPPL.tovec( + DynamicPPL.getindex_internal(vi_linked_result, @varname(s)) + ) == DynamicPPL.tovec(retval.s) # `s` is constrained in result # `m` should not be transformed. @test vi_linked[@varname(m)] == retval.m @@ -263,9 +267,10 @@ model, retval.s, retval.m ) - # Realizations in `vi_linked` should all be equal to the unconstrained realization. - @test DynamicPPL.getindex_raw(vi_linked, @varname(s)) ≈ retval_unconstrained.s - @test DynamicPPL.getindex_raw(vi_linked, @varname(m)) ≈ retval_unconstrained.m + @test DynamicPPL.tovec(DynamicPPL.getindex_internal(vi_linked, @varname(s))) ≈ + DynamicPPL.tovec(retval_unconstrained.s) + @test DynamicPPL.tovec(DynamicPPL.getindex_internal(vi_linked, @varname(m))) ≈ + DynamicPPL.tovec(retval_unconstrained.m) # The resulting varinfo should hold the correct logp. 
lp = getlogp(vi_linked_result) diff --git a/test/turing/varinfo.jl b/test/turing/varinfo.jl index 30408e598..f1d805505 100644 --- a/test/turing/varinfo.jl +++ b/test/turing/varinfo.jl @@ -14,7 +14,7 @@ elseif is_flagged(vi, vn, "del") unset_flag!(vi, vn, "del") r = rand(dist) - vi[vn] = vectorize(dist, r) + vi[vn] = DynamicPPL.tovec(r) setorder!(vi, vn, get_num_produce(vi)) r else diff --git a/test/utils.jl b/test/utils.jl index 1fcf09ef1..3f435dca4 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -43,9 +43,9 @@ @test getargs_tilde(:(@~ Normal.(μ, σ))) === nothing end - @testset "vectorize" begin + @testset "tovec" begin dist = LKJCholesky(2, 1) x = rand(dist) - @test vectorize(dist, x) == vec(x.UL) + @test DynamicPPL.tovec(x) == vec(x.UL) end end diff --git a/test/varinfo.jl b/test/varinfo.jl index 12387f6a7..6a3d8d2bc 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -309,7 +309,8 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) vi = DynamicPPL.settrans!!(vi, true, vn) # Sample in unconstrained space. vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) - x = Bijectors.invlink(dist, DynamicPPL.getindex_raw(vi, vn)) + f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) + x = f(DynamicPPL.getindex_internal(vi, vn)) @test getlogp(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) ## `TypedVarInfo` @@ -317,7 +318,8 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) vi = DynamicPPL.settrans!!(vi, true, vn) # Sample in unconstrained space. vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) - x = Bijectors.invlink(dist, DynamicPPL.getindex_raw(vi, vn)) + f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) + x = f(DynamicPPL.getindex_internal(vi, vn)) @test getlogp(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) ### `SimpleVarInfo` @@ -325,14 +327,16 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) vi = DynamicPPL.settrans!!(SimpleVarInfo(), true) # Sample in unconstrained space. 
vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) - x = Bijectors.invlink(dist, DynamicPPL.getindex_raw(vi, vn)) + f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) + x = f(DynamicPPL.getindex_internal(vi, vn)) @test getlogp(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) ## `SimpleVarInfo{<:Dict}` vi = DynamicPPL.settrans!!(SimpleVarInfo(Dict()), true) # Sample in unconstrained space. vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) - x = Bijectors.invlink(dist, DynamicPPL.getindex_raw(vi, vn)) + f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) + x = f(DynamicPPL.getindex_internal(vi, vn)) @test getlogp(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) end @@ -413,6 +417,9 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) else DynamicPPL.link(varinfo, model) end + for vn in keys(varinfo) + @test DynamicPPL.istrans(varinfo_linked, vn) + end @test length(varinfo[:]) > length(varinfo_linked[:]) varinfo_linked_unflattened = DynamicPPL.unflatten( varinfo_linked, varinfo_linked[:] @@ -655,7 +662,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) # Should only get the variables subsumed by `@varname(s)`. @test varinfo[spl] == - mapreduce(Base.Fix1(DynamicPPL.getval, varinfo), vcat, vns_s) + mapreduce(Base.Fix1(DynamicPPL.getindex_internal, varinfo), vcat, vns_s) # `link` varinfo_linked = DynamicPPL.link(varinfo, spl, model)