Skip to content

Commit

Permalink
Rework of how transformations are handled (#575)
Browse files Browse the repository at this point in the history
* initial implementation of VarNameVector

* added some hacky getval and getdist get things to work for VarInfo

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* added arbitrary metadata field as discussed

* renamed idcs to varname_to_index

* renamed vns to varnames for VarNameVector

* added keys impl for Metadata

* added push! and update! for VarNameVector

* added getindex_raw! and setindex_raw! for VarNameVector

* added `iterate` and `convert` (for `AbstractDict) impls for `VarNameVector`

* make the key and eltype part of the `VarNameVector` type

* added more tests for VarNameVector

* formatting

* more testing for VarNameVector

* minor changes to some comments

* added a bunch more tests for VarNameVector + several bugfixes in the process

* formatting

* added `similar` implementation for `VarNameVector`

* formatting

* removed debug statement

* made VarInfo slighly more generic wrt. underlying metadata

* fixed incorrect behavior in `keys` for `Metadata`

* minor style changes to VarNameVector tests

* style

* added testing of `update!` with smaller sizes and fixed bug related to this

* formatting

* move functionality related to `push!` for `VarNameVector` into `push!`

* Update src/varnamevector.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* several fixes to make sampling with VarNameVector + initiall tests for
sampling with VarNameVector

* VarInfo + VarNameVector tests for all demo models

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* added docs on the design of `VarNameVector`

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* added note on `update!`

* further elaboration of the design of `VarInfo` and `VarNameVector`

* more writing improvements

* added docstring to `has_inactive_ranges` and `inactive_ranges_sweep!`

* moved docs on `VarInfo` design to a separate internals section

* writing improvements for internal docs

* further motivation of the design choices made in `VarNameVector`

* improved writing

* VarNameVector is now grown as much as needed

* updated `delete!`

* Significant changes to implementation of `VarNameVector`:
- "delete-by-mark" is now replaced by proper deletion.
- `inactive_ranges` replaced by `num_inactive`, which only keeps track
of the number of inactive entries for a given `VarName.
- `VarNameVector` is now a "grow-as-needed" structure where the
underlying also mimics the order that the user experiences.`

* added `copy` when constructing `VectorVarInfo` from `VarInfo`

* added missing `isempty` impl

* remove impl of `iterate` and instead implemented `pairs` and `values` iterators

* added missing `empty!` for `num_inactive`

* removed redundant `shift_left!` methd

* fixed `delete!` for `VarNameVector`

* added `is_contiguous` as an alterantive to `!has_inactive`

* updates to internal docs

* renamed `sweep_inactive_ranges!` to `contiguify!`

* improvements to internal docs

* more improvements to internal docs

* moved additional methods description in internals to earlier in the doc

* moved internals docs to a separate directory and split into files

* more improvements to internals doc

* formatting

* added tests for `delete!` and fixed reference to old method

* addition to `delete!` test

* added `values_as` impls for `VarNameVector`

* added docs for `replace_valus` and `values_as` for `VarNameVector`

* fixed doctest

* formatting

* temporarily disable doctests so we can build docs

* added missing compat entry for ForwardDiff in docs

* moved some shared code into methods to make things a bit cleaner

* added impl of `merge` for `VarNameVector`

* renamed a few variables in `merge` impl for `VarNameVector`

* forgot to include some changes in previous commit

* added impl of `subset` for `VarNameVector`

* fixed `pairs` impl for `VarNameVector`

* added missing impl of `subset` for `VectorVarInfo`

* added missing impl of `merge_metadata` for `VarNameVector`

* added a bunch of `from_vec_transform` and `tovec` impls to make
`VarNameVector` work with `Cholesky`, etc.

* make default args use `from_vec_transform` rather than `FromVec`

* fixed `values_as` fro `VarInfo` with `VarNameVector` as `metadata`

* fixed impl of `getindex_raw` when using integer index for `VarNameVector`

* added tests for `getindex` with `Int` index for `VarNameVector`

* fix for `setindex!` and `setindex_raw!` for `VarNameVector`

* introduction of `from_vec_transform` and `tovec` and its usage in `VarInfo`

* moved definition of `is_splat_symbol` to the file where it's used

* added `VarInfo` constructor with vector input for `VectorVarInfo`

* make `extract_priors` take the `rng` as an argument

* added `replace_values` for `Metadata`

* make link and invlink act on the `metadata` field for `VarInfo` +
implementations of these for `Metadata` and `VarNameVector`

* added temporary defs of `with_logabsdet_jacobian` and `inverse` for
`transpose` and `Bijectors.vec_to_triu`

* added invlink_with_logpdf overload for `ThreadSafeVarInfo`

* added `is_transformed` field to `VarNameVector`

* removed unnecessary defintions of `with_logabsdet_jacobian` and
`inverse` for `transpose`

* fixed issue where we were storing the wrong transformations in `VarNameVector`

* make sure `extract_priors` doesn't mutate the `varinfo`

* updated `similar` for `VarNameVector` and fixed `invlink` for `VarNameVector`

* added handling of `is_transformed` in `merge` for `VarNameVector`

* removed unnecesasry `deepcopy` from outer `link`

* updated `push!` to also `push!` on `is_transformed`

* skip tests for mutating linking when using VarNameVector

* use same projection for `Cholesky` in `VarNameVector` as in `VarInfo`

* fixed `settrans!!` for `VarInfo` with `VarNameVector`

* fixed bug in `set_flag!`

* fixed another typo

* fixed return values of `settrans!!`

* updated static transformation tests

* Update test/simple_varinfo.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* removed unnecessary impl of `extract_priors`

* make `short_varinfo_name` of `TypedVarInfo` a bit more informative

* moved impl of `has_varnamevector` for `ThreadSafeVarInfo`

* added back `extract_priors` impl as we do need it

* forgot to include tests for `VarNameVector` in `runtests.jl`

* fix for `relax_container_types` in `test/varnamevector.jl`

* fixed `need_transforms_relaxation`

* updated some tests to not refer directly to `FromVec`

* introduce `from_internal_transform` and its siblings

* remove `with_logabsdet_jacobian_and_reconstruct` in favour of
`with_logabsdet_jacobian` with `from_linked_internal_transform`, etc.

* added `internal_to_linked_internal_transform` + fixed a few bugs in
the linking as a resultt

* added `linked_internal_to_internal_transform` as a complement to `interanl_to_linked_interanl_transform`

* fixed bugs in `invlink` for `VarInfo` using `linked_internal_to_internal_transform`

* more work on removing calls to `reconstruct`

* removed redundant comment

* added `from_linked_vec_transform` specialization for `LKJ`

* more work on removing references to `reconstruct`

* added `copy` in `values_from_metadata` to preserve behavior and avoid
refs to internal representation

* remove `reconstruct_and_link` and `invlink_and_reconstruct`

* replaced references to `link_and_reconstruct` and `invlink_and_reconstruct`

* introduced `recombine` and replaced calls to `reconstruct` with `n` samples

* completely removed `reconstruct`

* renamed `maybe_reconstruct_and_link` to `to_maybe_linked_internal` and
`maybe_invlink_and_reconstruct` to `from_maybe_linked_internal`

* added impls of `from_*_internal_transform` for `ThreadSafeVarInfo`

* removed `reconstruct` from docs and from exports

* renamed `getval` to `getindex_internal` and made `dist` an optional
argument for all the transform-related methods

* updated docs + added description of how internals of transforms work

* added a bunch of illustrations for the transforms docs + dot files used to generated

* temporarily removed `VarNameVector` completely

* formatting

* Update docs/src/internals/transformations.md

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update docs/src/internals/transformations.md

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* removed refs to VectorVarInfo

* added impls of `from_internal_transform` for `ThreadSafeVarInfo`

* reverted accidental removal of old `VarInfo` constructor

* fixed incorrect `recombine` call

* removed undefined refs to `VarNameVector` stuff in `setup_varinfos`

* bump minior version because Turing breaks

* fix: was using `from_linked_internal_transform` in
`from_internal_transform` for `ThreadSafeVarInfo`

* removed `getindex_raw`

* removed redundant docstrings

* fixed tests

* fixed comparisons in tests

* try relative references for images in transformation docs

* another attempt at fixing asset-references

* fixed getindex diagrams in docs

* minor changes to comments

* remove Combinatorics as a test dep, as it's not needed for this PR

* reverted unnecessary change

* disable type-stability tests for models on older Julia versions

* removed seemingly completely unused impl of `setval!`

* Update test/model.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Markus Hauru <[email protected]>

* Type-stability tests are now correctly using `rand_prior_true` instead
of `rand`

* `getindex_internal` now calls `getindex` instead of `view`, as the
latter can result in type-instability since transformed variables
typically result in non-view even if input is a view

* Removed seemingly unnecessary definition of `getindex_internal`

* Fixed references to `newmetadata` which has been replaced by `replace_values`

* Made implementation of `recombine` more explicit

* Added docstrings for `untyped_varinfo` and `typed_varinfo`

* Added TODO comment about implementing `view` for `VarInfo`

* Fixed potential infinite recursion as suggested by @mhauru

* added docstring to `from_vec_trnasform_for_size

* Replaced references to `vectorize(dist, x)` with `tovec(x)`

* Fixed docstring

* Update src/extract_priors.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Bump minor version since this is a breaking change

* Apply suggestions from code review

Co-authored-by: Markus Hauru <[email protected]>

* Update src/varinfo.jl

Co-authored-by: Tor Erlend Fjelde <[email protected]>

* Apply suggestions from code review

* Apply suggestions from code review

* Update src/extract_priors.jl

Co-authored-by: Xianda Sun <[email protected]>

* Added fix for product distributions of targets with changing support + tests

* Addeed tests for product of distributions with dynamic support

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Empty commit to trigger CI

* Update test/model.jl

Co-authored-by: Markus Hauru <[email protected]>

* Increase HTML page size threshold for docs

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Markus Hauru <[email protected]>
Co-authored-by: Xianda Sun <[email protected]>
Co-authored-by: Markus Hauru <[email protected]>
  • Loading branch information
5 people authored Aug 21, 2024
1 parent cdd3407 commit 138bd40
Show file tree
Hide file tree
Showing 34 changed files with 1,399 additions and 473 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.28.4"
version = "0.29"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
2 changes: 2 additions & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Expand All @@ -14,6 +15,7 @@ DataStructures = "0.18"
Distributions = "0.25"
Documenter = "1"
FillArrays = "0.13, 1"
ForwardDiff = "0.10"
LogDensityProblems = "2"
MCMCChains = "5, 6"
StableRNGs = "1"
6 changes: 5 additions & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,18 @@ DocMeta.setdocmeta!(DynamicPPL, :DocTestSetup, :(using DynamicPPL); recursive=tr

makedocs(;
sitename="DynamicPPL",
format=Documenter.HTML(),
# The API index.html page is fairly large, and violates the default HTML page size
# threshold of 200KiB, so we double that.
format=Documenter.HTML(; size_threshold=2^10 * 400),
modules=[DynamicPPL],
pages=[
"Home" => "index.md",
"API" => "api.md",
"Tutorials" => ["tutorials/prob-interface.md"],
"Internals" => ["internals/transformations.md"],
],
checkdocs=:exports,
doctest=false,
)

deploydocs(; repo="github.com/TuringLang/DynamicPPL.jl.git", push_preview=true)
108 changes: 54 additions & 54 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -218,13 +218,62 @@ Please see the documentation of [AbstractPPL.jl](https://github.com/TuringLang/A

### Data Structures of Variables

DynamicPPL provides different data structures for samples from the model and their log density.
All of them are subtypes of [`AbstractVarInfo`](@ref).
DynamicPPL provides different data structures used in for storing samples and accumulation of the log-probabilities, all of which are subtypes of [`AbstractVarInfo`](@ref).

```@docs
AbstractVarInfo
```

But exactly how a [`AbstractVarInfo`](@ref) stores this information can vary.

#### `VarInfo`

```@docs
VarInfo
TypedVarInfo
```

One main characteristic of [`VarInfo`](@ref) is that samples are stored in a linearized form.

```@docs
link!
invlink!
```

```@docs
set_flag!
unset_flag!
is_flagged
```

For Gibbs sampling the following functions were added.

```@docs
setgid!
updategid!
```

The following functions were used for sequential Monte Carlo methods.

```@docs
get_num_produce
set_num_produce!
increment_num_produce!
reset_num_produce!
setorder!
set_retained_vns_del_by_spl!
```

```@docs
Base.empty!
```

#### `SimpleVarInfo`

```@docs
SimpleVarInfo
```

### Common API

#### Accumulation of log-probabilities
Expand All @@ -241,7 +290,7 @@ resetlogp!!
```@docs
keys
getindex
DynamicPPL.getindex_raw
DynamicPPL.getindex_internal
push!!
empty!!
isempty
Expand Down Expand Up @@ -269,8 +318,9 @@ DynamicPPL.invlink
DynamicPPL.link!!
DynamicPPL.invlink!!
DynamicPPL.default_transformation
DynamicPPL.link_transform
DynamicPPL.invlink_transform
DynamicPPL.maybe_invlink_before_eval!!
DynamicPPL.reconstruct
```

#### Utils
Expand All @@ -283,56 +333,6 @@ DynamicPPL.varname_leaves
DynamicPPL.varname_and_value_leaves
```

#### `SimpleVarInfo`

```@docs
SimpleVarInfo
```

#### `VarInfo`

Another data structure is [`VarInfo`](@ref).

```@docs
VarInfo
TypedVarInfo
```

One main characteristic of [`VarInfo`](@ref) is that samples are stored in a linearized form.

```@docs
link!
invlink!
```

```@docs
set_flag!
unset_flag!
is_flagged
```

For Gibbs sampling the following functions were added.

```@docs
setgid!
updategid!
```

The following functions were used for sequential Monte Carlo methods.

```@docs
get_num_produce
set_num_produce!
increment_num_produce!
reset_num_produce!
setorder!
set_retained_vns_del_by_spl!
```

```@docs
Base.empty!
```

### Evaluation Contexts

Internally, both sampling and evaluation of log densities are performed with [`AbstractPPL.evaluate!!`](@ref).
Expand Down
17 changes: 17 additions & 0 deletions docs/src/assets/images/transformations-assume-without-istrans.dot
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
digraph {
# `assume` block
subgraph cluster_assume {
label = "assume";
fontname = "Courier";

assume [shape=box, label=< assume(varinfo, <FONT COLOR="#3B6EA8">@varname</FONT>(x), Normal())>, fontname="Courier"];
without_linking_assume [shape=box, label="f = from_internal_transform(varinfo, varname, dist)", fontname="Courier"];
with_logabsdetjac [shape=box, label="x, logjac = with_logabsdet_jacobian(f, assume_internal(varinfo, varname, dist))", fontname="Courier"];
return_assume [shape=box, label=< <FONT COLOR="#3B6EA8">return</FONT> x, logpdf(dist, x) - logjac, varinfo >, style=dashed, fontname="Courier"];

assume -> without_linking_assume;
without_linking_assume -> with_logabsdetjac;
with_logabsdetjac -> return_assume;
}
}

Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
22 changes: 22 additions & 0 deletions docs/src/assets/images/transformations-assume.dot
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
digraph {
# `assume` block
subgraph cluster_assume {
label = "assume";
fontname = "Courier";

assume [shape=box, label=< assume(varinfo, <FONT COLOR="#3B6EA8">@varname</FONT>(x), Normal())>, fontname="Courier"];
iflinked_assume [label=< <FONT COLOR="#3B6EA8">if</FONT> istrans(varinfo, varname) >, fontname="Courier"];
without_linking_assume [shape=box, label="f = from_internal_transform(varinfo, varname, dist)", fontname="Courier"];
with_linking_assume [shape=box, label="f = from_linked_internal_transform(varinfo, varname, dist)", fontname="Courier"];
with_logabsdetjac [shape=box, label="x, logjac = with_logabsdet_jacobian(f, assume_internal(varinfo, varname, dist))", fontname="Courier"];
return_assume [shape=box, label=< <FONT COLOR="#3B6EA8">return</FONT> x, logpdf(dist, x) - logjac, varinfo >, style=dashed, fontname="Courier"];

assume -> iflinked_assume;
iflinked_assume -> without_linking_assume [label=< <FONT COLOR="#97365B">false</FONT>>, fontname="Courier"];
iflinked_assume -> with_linking_assume [label=< <FONT COLOR="#97365B">true</FONT>>, fontname="Courier"];
without_linking_assume -> with_logabsdetjac;
with_linking_assume -> with_logabsdetjac;
with_logabsdetjac -> return_assume;
}
}

Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
20 changes: 20 additions & 0 deletions docs/src/assets/images/transformations-getindex-with-dist.dot
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
digraph {
# `getindex` block
subgraph cluster_getindex {
label = "getindex";
fontname = "Courier";

getindex [shape=box, label=< x = getindex(varinfo, <FONT COLOR="#3B6EA8">@varname</FONT>(x), Normal()) >, fontname="Courier"];
iflinked_getindex [label=< <FONT COLOR="#3B6EA8">if</FONT> istrans(varinfo, varname) >, fontname="Courier"];
without_linking_getindex [shape=box, label="f = from_internal_transform(varinfo, varname, dist)", fontname="Courier"];
with_linking_getindex [shape=box, label="f = from_linked_internal_transform(varinfo, varname, dist)", fontname="Courier"];
return_getindex [shape=box, label=< <FONT COLOR="#3B6EA8">return</FONT> f(getindex_internal(varinfo, varname)) >, style=dashed, fontname="Courier"];

getindex -> iflinked_getindex;
iflinked_getindex -> without_linking_getindex [label=< <FONT COLOR="#97365B">false</FONT>>, fontname="Courier"];
iflinked_getindex -> with_linking_getindex [label=< <FONT COLOR="#97365B">true</FONT>>, fontname="Courier"];
without_linking_getindex -> return_getindex;
with_linking_getindex -> return_getindex;
}
}

Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
20 changes: 20 additions & 0 deletions docs/src/assets/images/transformations-getindex-without-dist.dot
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
digraph {
# `getindex` block
subgraph cluster_getindex {
label = "getindex";
fontname = "Courier";

getindex [shape=box, label=< x = getindex(varinfo, <FONT COLOR="#3B6EA8">@varname</FONT>(x)) >, fontname="Courier"];
iflinked_getindex [label=< <FONT COLOR="#3B6EA8">if</FONT> istrans(varinfo, varname) >, fontname="Courier"];
without_linking_getindex [shape=box, label="f = from_internal_transform(varinfo, varname)", fontname="Courier"];
with_linking_getindex [shape=box, label="f = from_linked_internal_transform(varinfo, varname)", fontname="Courier"];
return_getindex [shape=box, label=< <FONT COLOR="#3B6EA8">return</FONT> f(getindex_internal(varinfo, varname)) >, style=dashed, fontname="Courier"];

getindex -> iflinked_getindex;
iflinked_getindex -> without_linking_getindex [label=< <FONT COLOR="#97365B">false</FONT>>, fontname="Courier"];
iflinked_getindex -> with_linking_getindex [label=< <FONT COLOR="#97365B">true</FONT>>, fontname="Courier"];
without_linking_getindex -> return_getindex;
with_linking_getindex -> return_getindex;
}
}

Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
28 changes: 28 additions & 0 deletions docs/src/assets/images/transformations.dot
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
digraph {
# Nodes.
tilde_node [shape=box, label="x ~ Normal()", fontname="Courier"];
base_node [shape=box, label=< vn = <FONT COLOR="#3B6EA8">@varname</FONT>(x)<BR/>dist = Normal()<BR/>x, vi = ... >, fontname="Courier"];
assume [shape=box, label="assume(vn, dist, vi)", fontname="Courier"];

iflinked [label=< <FONT COLOR="#3B6EA8">if</FONT> istrans(vi, vn) >, fontname="Courier"];

without_linking [shape=box, label="f = from_internal_transform(vi, vn, dist)", styled=dashed, fontname="Courier"];
with_linking [shape=box, label="f = from_linked_internal_transform(vi, vn, dist)", styled=dashed, fontname="Courier"];

with_logabsdetjac [shape=box, label="x, logjac = with_logabsdet_jacobian(f, getindex_internal(vi, vn, dist))", styled=dashed, fontname="Courier"];
return [shape=box, label=< <FONT COLOR="#3B6EA8">return</FONT> x, logpdf(dist, x) - logjac, vi >, styled=dashed, fontname="Courier"];

# Edges.
tilde_node -> base_node [style=dashed, label=< <FONT COLOR="#3B6EA8">@model</FONT>>, fontname="Courier"]
base_node -> assume [style=dashed, label=" tilde-pipeline", fontname="Courier"];

assume -> iflinked;

iflinked -> without_linking [label=< <FONT COLOR="#97365B">false</FONT>>, fontname="Courier"];
iflinked -> with_linking [label=< <FONT COLOR="#97365B">true</FONT>>, fontname="Courier"];

without_linking -> with_logabsdetjac;
with_linking -> with_logabsdetjac;

with_logabsdetjac -> return;
}
Binary file added docs/src/assets/images/transformations.dot.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit 138bd40

Please sign in to comment.