formatting
torfjelde committed Jan 31, 2024
1 parent 95dc8e3 commit 8930f9c
Showing 1 changed file with 43 additions and 45 deletions.
88 changes: 43 additions & 45 deletions docs/src/internals/transformations.md
@@ -9,7 +9,7 @@ For example, consider the following model:
```julia
@model function demo()
    s ~ InverseGamma(2, 3)
    return m ~ Normal(0, s)
end
```

@@ -18,6 +18,7 @@ Here we have two variables `s` and `m`, where `s` is constrained to be positive,
For certain inference methods, it's necessary / much more convenient to work with an equivalent model to `demo` but where all the variables can take any real values (they're "unconstrained").

!!! note

We write "unconstrained" with quotes because there are many ways to transform a constrained variable to an unconstrained one, *and* DynamicPPL can work with a much broader class of bijective transformations of variables, not just ones that go to the entire real line. But for MCMC, unconstraining is the most common transformation so we'll stick with that terminology.

For a large family of constraints encountered in practice, it is indeed possible to transform a (partially) constrained model to a completely unconstrained one in such a way that sampling in the unconstrained space is equivalent to sampling in the constrained space.
@@ -30,7 +31,7 @@ For example, the above model could be transformed into (the following pseudo-code
@model function demo()
    log_s ~ log(InverseGamma(2, 3))
    s = exp(log_s)
    return m ~ Normal(0, s)
end
end
```
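
The pseudo-code line `log_s ~ log(InverseGamma(2, 3))` above is shorthand for a change of variables. As a worked sketch, write ``p_s`` for the `InverseGamma(2, 3)` density and ``y`` for `log_s` (both symbols are introduced here purely for illustration). The density of ``y = \log s`` then picks up a Jacobian factor:

```math
p_y(y) = p_s\left(e^{y}\right) \left| \frac{\mathrm{d}}{\mathrm{d}y} e^{y} \right| = p_s\left(e^{y}\right) e^{y},
\qquad
\log p_y(y) = \log p_s\left(e^{y}\right) + y .
```

Drawing ``y`` from ``p_y`` and setting ``s = e^{y}`` gives ``s`` distributed according to ``p_s``, which is the sense in which sampling in the unconstrained space is equivalent to sampling in the constrained space.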

@@ -51,12 +52,12 @@ Below we'll see how this is done.
## What do we need?

There are two aspects to transforming from the internal representation of a variable in a `varinfo` to the representation wanted in the model:

1. Different implementations of [`AbstractVarInfo`](@ref) represent realizations of a model in different ways internally, so we need to transform from this internal representation to the desired representation in the model. For example,

   + [`VarInfo`](@ref) represents a realization of a model in a "flattened" / vector representation, regardless of the form of the variable in the model.
   + [`SimpleVarInfo`](@ref) represents a realization of a model exactly as in the model (unless it has been transformed; we'll get to that later).
2. We need the ability to transform from "constrained space" to "unconstrained space", as we saw in the previous section.

## Working example

@@ -70,11 +71,13 @@ using DynamicPPL, Distributions
`LKJCholesky` is a `LKJ(2, 1.0)` distribution, a distribution over correlation matrices (covariance matrices but with unit diagonal), but working directly with the Cholesky factorization of the correlation matrix rather than the correlation matrix itself (this is more numerically stable and computationally efficient).

!!! note

This is a particularly "annoying" case because the return-value is not a simple `Real` or `AbstractArray{<:Real}`, but rather a `LinearAlgebra.Cholesky` object which wraps a triangular matrix (whether it's upper- or lower-triangular depends on the instance).

As mentioned, some implementations of `AbstractVarInfo`, e.g. [`VarInfo`](@ref), work with a "flattened" / vector representation of a variable, and so in this case we need two transformations:

1. From the `Cholesky` object to a vector representation.
2. From the `Cholesky` object to an "unconstrained" / linked vector representation.

And similarly, we'll need the inverses of these transformations.
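
To give a feel for the first of these transformations and its inverse, here is a minimal sketch using only the `LinearAlgebra` standard library. The helpers `cholesky_to_vec` and `vec_to_cholesky` are hypothetical names made up for this illustration; they are not DynamicPPL functions, and the transformations DynamicPPL actually uses may differ (for example in which entries of the factor are stored).

```julia
using LinearAlgebra

# Hypothetical helper (illustration only): flatten the lower-triangular
# factor of a `Cholesky` object into a plain vector, column by column.
cholesky_to_vec(C::Cholesky) = vec(Matrix(C.L))

# Hypothetical inverse: rebuild a `Cholesky` object from the flattened
# lower-triangular factor.
function vec_to_cholesky(v::AbstractVector{<:Real})
    d = isqrt(length(v))
    d^2 == length(v) || throw(ArgumentError("length of `v` must be a perfect square"))
    # `'L'` marks the stored factor as lower-triangular; `0` means "no error".
    return Cholesky(reshape(collect(v), d, d), 'L', 0)
end

C = cholesky([1.0 0.5; 0.5 1.0])
v = cholesky_to_vec(C)            # vector representation
C_roundtrip = vec_to_cholesky(v)  # back to a `Cholesky` object
```

The round trip recovers the same factor, so `C_roundtrip.L ≈ C.L` holds. The second, "linked" transformation would additionally have to map the constrained entries of the factor to unconstrained real numbers, which this sketch does not attempt.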

@@ -88,10 +91,12 @@ DynamicPPL.from_internal_transform
```

These methods allow us to extract the internal-to-model transformation function depending on the `varinfo`, the variable, and the distribution of the variable:

- `varinfo` + `vn` defines the internal representation of the variable.
- `dist` defines the representation expected within the model scope.

!!! note

If `vn` is not present in `varinfo`, then the internal representation is fully determined by `varinfo` alone. This is used when we're about to add a new variable to the `varinfo` and need to know how to represent it internally.

Continuing from the example above, we can inspect the internal representation of `x` in `demo_lkj` with [`VarInfo`](@ref) using [`DynamicPPL.getindex_internal`](@ref):
@@ -104,9 +109,7 @@ x_internal = DynamicPPL.getindex_internal(varinfo, @varname(x))

```@example transformations-internal
f_from_internal = DynamicPPL.from_internal_transform(
    varinfo, @varname(x), LKJCholesky(2, 1.0)
)
f_from_internal(x_internal)
```
@@ -120,11 +123,7 @@ x_model = varinfo[@varname(x)]
Similarly, we can go from the model representation to the internal representation:

```@example transformations-internal
f_to_internal = DynamicPPL.to_internal_transform(varinfo, @varname(x), LKJCholesky(2, 1.0))
f_to_internal(x_model)
```
@@ -139,11 +138,7 @@ DynamicPPL.getindex_internal(simple_varinfo, @varname(x))
Here we see that the internal representation is exactly the same as the model representation, and so we'd expect `from_internal_transform` to be the `identity` function:

```@example transformations-internal
DynamicPPL.from_internal_transform(simple_varinfo, @varname(x), LKJCholesky(2, 1.0))
```

Great!
@@ -165,19 +160,15 @@ Continuing from the example above:

```@example transformations-internal
f_to_linked_internal = DynamicPPL.to_linked_internal_transform(
    varinfo, @varname(x), LKJCholesky(2, 1.0)
)
x_linked_internal = f_to_linked_internal(x_model)
```

```@example transformations-internal
f_from_linked_internal = DynamicPPL.from_linked_internal_transform(
    varinfo, @varname(x), LKJCholesky(2, 1.0)
)
f_from_linked_internal(x_linked_internal)
@@ -216,7 +207,7 @@ Unfortunately, this is not possible in general. Consider for example the followi
```@example transformations-internal
@model function demo_dynamic_constraint()
    m ~ Normal()
    x ~ truncated(Normal(); lower=m)
    return (m=m, x=x)
end
@@ -266,6 +257,7 @@ first(DynamicPPL.evaluate!!(model, varinfo_linked, DefaultContext()))
we see that we indeed satisfy the constraint `m < x`, as desired.

!!! warning

One shouldn't be setting variables in a linked `varinfo` willy-nilly directly like this unless one knows that the value will be compatible with the constraints of the model.

The reason for this is that internally in a model evaluation, we construct the transformation from the internal to the model representation based on the *current* realizations in the model! That is, we take the `dist` in a `x ~ dist` expression _at model evaluation time_ and use that to construct the transformation, thus allowing it to change between model evaluations without invalidating the transformation.
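
To make this concrete, here is a minimal sketch in plain Julia of why the transformation has to be rebuilt from the current `dist`. The helpers `to_unconstrained` and `from_unconstrained` are hypothetical, hand-written stand-ins for a lower-bounded transformation; they are not DynamicPPL or Bijectors functions, and the transformation actually used for `truncated(Normal(); lower=m)` may differ.

```julia
# Hypothetical helpers (illustration only) for a variable constrained to
# `x > lower`. Note that the bound is an argument of the transformation.
to_unconstrained(x::Real, lower::Real) = log(x - lower)
from_unconstrained(y::Real, lower::Real) = lower + exp(y)

# One fixed unconstrained value, as it might be stored in a linked `varinfo`.
y = -0.3

# The same `y` maps to a different constrained value for every bound `m`,
# and each of them satisfies the constraint `m < x`.
for m in (-2.0, 0.0, 5.0)
    x = from_unconstrained(y, m)
    @assert x > m
    @assert to_unconstrained(x, m) ≈ y
    println("m = $m  =>  x = $x")
end
```

Because the bound enters the transformation itself, a transformation fixed once at linking time would encode a stale value of `m`, whereas one built at evaluation time always uses the current one.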
@@ -291,11 +283,13 @@ And so the earlier diagram becomes:
```

!!! note

If the support of `dist` was constant, this would not be necessary since we could just determine the transformation at the time of `varinfo_linked = link(varinfo, model)` and define this as the `from_internal_transform` for all subsequent evaluations. However, since the support of `dist` is *not* constant in general, we need to be able to determine the transformation at the time of the evaluation *and* thus whether we should construct the transformation from the linked internal representation or the non-linked internal representation. This is annoying, but necessary.

This is also the reason why we have two definitions of `getindex`:

- [`getindex(::AbstractVarInfo, ::VarName, ::Distribution)`](@ref): used internally in model evaluations with the `dist` in a `x ~ dist` expression.
- [`getindex(::AbstractVarInfo, ::VarName)`](@ref): used externally by the user to get the realization of a variable.

For `getindex` we have the following diagram:

@@ -316,8 +310,9 @@ While if `dist` is not provided, we have:
Notice that `dist` is not present here, but otherwise the diagrams are the same.

!!! warning

This does mean that the `getindex(varinfo, varname)` might not be the same as the `getindex(varinfo, varname, dist)` that occurs within a model evaluation! This can be confusing, but as outlined above, we do want to allow the `dist` in a `x ~ dist` expression to "override" whatever transformation `varinfo` might have.

## Other functionalities

There are also some additional methods for transforming between representations that are all automatically implemented from [`DynamicPPL.from_internal_transform`](@ref), [`DynamicPPL.from_linked_internal_transform`](@ref) and their siblings, and thus don't need to be implemented manually.
@@ -341,8 +336,9 @@ DynamicPPL.from_maybe_linked_internal
# Supporting a new distribution

To support a new distribution, one needs to implement for the desired `AbstractVarInfo` the following methods:

- [`DynamicPPL.from_internal_transform`](@ref)
- [`DynamicPPL.from_linked_internal_transform`](@ref)

At the time of writing, [`VarInfo`](@ref) is the one that is most commonly used, whose internal representation is always a `Vector`. In this scenario, one can just implement the following methods instead:

@@ -354,8 +350,9 @@ DynamicPPL.from_linked_vec_transform(::Distribution)
These are used internally by [`VarInfo`](@ref).

Optionally, if `inverse` of the above is expensive to compute, one can also implement:

- [`DynamicPPL.to_internal_transform`](@ref)
- [`DynamicPPL.to_linked_internal_transform`](@ref)

And similarly, there are corresponding to-methods for the `from_*_vec_transform` variants too:

@@ -365,13 +362,14 @@ DynamicPPL.to_linked_vec_transform
```

!!! warning

Whatever the resulting transformation is, it should be invertible, i.e. implement `InverseFunctions.inverse`, and have a well-defined log-abs-det Jacobian, i.e. implement `ChangesOfVariables.with_logabsdet_jacobian`.
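
As a rough sketch of what satisfying these two requirements can look like, the following defines a pair of callable structs and implements `InverseFunctions.inverse` and `ChangesOfVariables.with_logabsdet_jacobian` for them. The types `ShiftedExp` and `ShiftedLog` are hypothetical and exist only for this illustration; the example assumes the `InverseFunctions` and `ChangesOfVariables` packages are installed and says nothing about how DynamicPPL's own transformations are written.

```julia
using InverseFunctions: InverseFunctions
using ChangesOfVariables: ChangesOfVariables

# Hypothetical transformation (illustration only): maps an unconstrained
# `y` to a value `x > lower` via `x = lower + exp(y)`.
struct ShiftedExp{T<:Real}
    lower::T
end
(f::ShiftedExp)(y::Real) = f.lower + exp(y)

# Its inverse: maps `x > lower` back to the real line.
struct ShiftedLog{T<:Real}
    lower::T
end
(f::ShiftedLog)(x::Real) = log(x - f.lower)

# Invertibility: each transformation knows its inverse.
InverseFunctions.inverse(f::ShiftedExp) = ShiftedLog(f.lower)
InverseFunctions.inverse(f::ShiftedLog) = ShiftedExp(f.lower)

# Log-abs-det Jacobian: d/dy (lower + exp(y)) = exp(y), whose log is `y`.
function ChangesOfVariables.with_logabsdet_jacobian(f::ShiftedExp, y::Real)
    return f(y), y
end

# d/dx log(x - lower) = 1 / (x - lower), whose log is `-log(x - lower)`.
function ChangesOfVariables.with_logabsdet_jacobian(f::ShiftedLog, x::Real)
    return f(x), -log(x - f.lower)
end

f = ShiftedExp(1.5)
x, logjac = ChangesOfVariables.with_logabsdet_jacobian(f, 0.0)
y = InverseFunctions.inverse(f)(x)
```

Using structs rather than closures is a design choice here: it gives the inverse something to dispatch on, so `inverse(inverse(f))` returns the original transformation.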

# TL;DR

- DynamicPPL.jl has three representations of a variable: the **model representation**, the **internal representation**, and the **linked internal representation**.
  + The **model representation** is the representation of the variable as it appears in the model code / is expected by the `dist` on the right-hand-side of the `~` in the model code.
  + The **internal representation** is the representation of the variable as it appears in the `varinfo`, which varies between implementations of [`AbstractVarInfo`](@ref), e.g. a `Vector` in [`VarInfo`](@ref). This can be converted to the model representation by [`DynamicPPL.from_internal_transform`](@ref).
  + The **linked internal representation** is the representation of the variable as it appears in the `varinfo` after [`link`](@ref)ing. This can be converted to the model representation by [`DynamicPPL.from_linked_internal_transform`](@ref).
- Having separation between *internal* and *linked internal* is necessary because transformations might be constructed at the time of model evaluation, and thus we need to know whether to construct the transformation from the internal representation or the linked internal representation.
