Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master' into sunxd/move_predict
Browse files Browse the repository at this point in the history
  • Loading branch information
penelopeysm committed Dec 15, 2024
2 parents fd1277b + 62a0f19 commit 86eab6b
Show file tree
Hide file tree
Showing 45 changed files with 1,502 additions and 1,267 deletions.
16 changes: 13 additions & 3 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,17 @@ permissions:
actions: write
contents: read

# Cancel existing tests on the same PR if a new commit is added to a pull request
concurrency:
group: ${{ github.workflow }}-${{ github.ref || github.run_id }}
cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }}

jobs:
test:
runs-on: ${{ matrix.runner.os }}
strategy:
fail-fast: false

matrix:
runner:
# Current stable version
Expand Down Expand Up @@ -58,6 +65,9 @@ jobs:
os: macos-latest
arch: aarch64
num_threads: 2
test_group:
- Group1
- Group2

steps:
- uses: actions/checkout@v4
Expand All @@ -73,14 +83,14 @@ jobs:

- uses: julia-actions/julia-runtest@v1
env:
GROUP: All
GROUP: ${{ matrix.test_group }}
JULIA_NUM_THREADS: ${{ matrix.runner.num_threads }}

- uses: julia-actions/julia-processcoverage@v1

- uses: codecov/codecov-action@v4
- uses: codecov/codecov-action@v5
with:
file: lcov.info
files: lcov.info
token: ${{ secrets.CODECOV_TOKEN }}
fail_ci_if_error: true

Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/CompatHelper.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@ jobs:
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
COMPATHELPER_PRIV: ${{ secrets.DOCUMENTER_KEY }}
run: julia -e 'using CompatHelper; CompatHelper.main(; subdirs = ["", "docs", "test", "test/turing"])'
run: julia -e 'using CompatHelper; CompatHelper.main(; subdirs = ["", "docs", "test"])'
2 changes: 0 additions & 2 deletions .github/workflows/JuliaPre.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,3 @@ jobs:
- uses: julia-actions/cache@v2
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-runtest@v1
env:
GROUP: DynamicPPL
19 changes: 7 additions & 12 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.31.0"
version = "0.32.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -29,16 +29,18 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[extensions]
DynamicPPLChainRulesCoreExt = ["ChainRulesCore"]
DynamicPPLEnzymeCoreExt = ["EnzymeCore"]
DynamicPPLForwardDiffExt = ["ForwardDiff"]
DynamicPPLJETExt = ["JET"]
DynamicPPLMCMCChainsExt = ["MCMCChains"]
DynamicPPLReverseDiffExt = ["ReverseDiff"]
DynamicPPLMooncakeExt = ["Mooncake"]
DynamicPPLZygoteRulesExt = ["ZygoteRules"]

[compat]
Expand All @@ -55,23 +57,16 @@ Distributions = "0.25"
DocStringExtensions = "0.9"
EnzymeCore = "0.6 - 0.8"
ForwardDiff = "0.10"
JET = "0.9"
LinearAlgebra = "1.6"
LogDensityProblems = "2"
LogDensityProblemsAD = "1.7.0"
MCMCChains = "6"
MacroTools = "0.5.6"
Mooncake = "0.4.59"
OrderedCollections = "1"
Random = "1.6"
Requires = "1"
ReverseDiff = "1"
Test = "1.6"
ZygoteRules = "0.2"
julia = "1.10"

[extras]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
2 changes: 2 additions & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
DocumenterMermaid = "a078cd44-4d9c-4618-b545-3ab9d77f9177"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Expand All @@ -18,6 +19,7 @@ Documenter = "1"
DocumenterMermaid = "0.1"
FillArrays = "0.13, 1"
ForwardDiff = "0.10"
JET = "0.9"
LogDensityProblems = "2"
MCMCChains = "5, 6"
StableRNGs = "1"
67 changes: 53 additions & 14 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,6 @@ These statements are rewritten by `@model` as calls of [internal functions](@ref
@model
```

One can nest models and call another model inside the model function with [`@submodel`](@ref).

```@docs
@submodel
```

### Type

A [`Model`](@ref) can be created by calling the model function, as defined by [`@model`](@ref).
Expand Down Expand Up @@ -110,6 +104,34 @@ Similarly, we can [`unfix`](@ref) variables, i.e. return them to their original
unfix
```

## Models within models

One can include models and call another model inside the model function with `left ~ to_submodel(model)`.

```@docs
to_submodel
```

Note that a `[to_submodel](@ref)` is only sampleable; one cannot compute `logpdf` for its realizations.

In the past, one would instead embed sub-models using [`@submodel`](@ref), which has been deprecated since the introduction of [`to_submodel(model)`](@ref)

```@docs
@submodel
```

In the context of including models within models, it's also useful to prefix the variables in sub-models to avoid variable names clashing:

```@docs
prefix
```

Under the hood, [`to_submodel`](@ref) makes use of the following method to indicate that the model it's wrapping is a model over its return-values rather than something else

```@docs
returned(::Model)
```

## Utilities

It is possible to manually increase (or decrease) the accumulated log density from within a model function.
Expand All @@ -118,10 +140,10 @@ It is possible to manually increase (or decrease) the accumulated log density fr
@addlogprob!
```

Return values of the model function for a collection of samples can be obtained with [`generated_quantities`](@ref).
Return values of the model function for a collection of samples can be obtained with [`returned(model, chain)`](@ref).

```@docs
generated_quantities
returned(::DynamicPPL.Model, ::NamedTuple)
```

For a chain of samples, one can compute the pointwise log-likelihoods of each observed random variable with [`pointwise_loglikelihoods`](@ref). Similarly, the log-densities of the priors using
Expand Down Expand Up @@ -243,20 +265,24 @@ AbstractVarInfo

But exactly how a [`AbstractVarInfo`](@ref) stores this information can vary.

#### `VarInfo`
For constructing the "default" typed and untyped varinfo types used in DynamicPPL (see [the section on varinfo design](@ref "Design of `VarInfo`") for more on this), we have the following two methods:

```@docs
VarInfo
TypedVarInfo
DynamicPPL.untyped_varinfo
DynamicPPL.typed_varinfo
```

One main characteristic of [`VarInfo`](@ref) is that samples are stored in a linearized form.
#### `VarInfo`

```@docs
link!
invlink!
VarInfo
TypedVarInfo
```

One main characteristic of [`VarInfo`](@ref) is that samples are transformed to unconstrained Euclidean space and stored in a linearized form, as described in the [transformation page](internals/transformations.md).
The [Transformations section below](#Transformations) describes the methods used for this.
In the specific case of `VarInfo`, it keeps track of whether samples have been transformed by setting flags on them, using the following functions.

```@docs
set_flag!
unset_flag!
Expand Down Expand Up @@ -403,6 +429,19 @@ DynamicPPL.loadstate
DynamicPPL.initialsampler
```

Finally, to specify which varinfo type a [`Sampler`](@ref) should use for a given [`Model`](@ref), this is specified by [`DynamicPPL.default_varinfo`](@ref) and can thus be overloaded for each `model`-`sampler` combination. This can be useful in cases where one has explicit knowledge that one type of varinfo will be more performant for the given `model` and `sampler`.

```@docs
DynamicPPL.default_varinfo
```

There is also the _experimental_ [`DynamicPPL.Experimental.determine_suitable_varinfo`](@ref), which uses static checking via [JET.jl](https://github.com/aviatesk/JET.jl) to determine whether one should use [`DynamicPPL.typed_varinfo`](@ref) or [`DynamicPPL.untyped_varinfo`](@ref), depending on which supports the model:

```@docs
DynamicPPL.Experimental.determine_suitable_varinfo
DynamicPPL.Experimental.is_suitable_varinfo
```

### [Model-Internal Functions](@id model_internal)

```@docs
Expand Down
4 changes: 1 addition & 3 deletions docs/src/internals/varinfo.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,7 @@ For example, with the model above we have

```@example varinfo-design
# Type-unstable `VarInfo`
varinfo_untyped = DynamicPPL.untyped_varinfo(
demo(), SampleFromPrior(), DefaultContext(), DynamicPPL.Metadata()
)
varinfo_untyped = DynamicPPL.untyped_varinfo(demo())
typeof(varinfo_untyped.metadata)
```

Expand Down
53 changes: 53 additions & 0 deletions ext/DynamicPPLJETExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
module DynamicPPLJETExt

using DynamicPPL: DynamicPPL
using JET: JET

function DynamicPPL.Experimental.is_suitable_varinfo(
model::DynamicPPL.Model,
context::DynamicPPL.AbstractContext,
varinfo::DynamicPPL.AbstractVarInfo;
only_ddpl::Bool=true,
)
# Let's make sure that both evaluation and sampling doesn't result in type errors.
f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(
model, varinfo, context
)
# If specified, we only check errors originating somewhere in the DynamicPPL.jl.
# This way we don't just fall back to untyped if the user's code is the issue.
result = if only_ddpl
JET.report_call(f, argtypes; target_modules=(JET.AnyFrameModule(DynamicPPL),))
else
JET.report_call(f, argtypes)
end
return length(JET.get_reports(result)) == 0, result
end

function DynamicPPL.Experimental._determine_varinfo_jet(
model::DynamicPPL.Model, context::DynamicPPL.AbstractContext; only_ddpl::Bool=true
)
# First we try with the typed varinfo.
varinfo = DynamicPPL.typed_varinfo(model, context)

# Let's make sure that both evaluation and sampling doesn't result in type errors.
issuccess, result = DynamicPPL.Experimental.is_suitable_varinfo(
model, context, varinfo; only_ddpl
)

if !issuccess
# Useful information for debugging.
@debug "Evaluaton with typed varinfo failed with the following issues:"
@debug result
end

# If we didn't fail anywhere, we return the type stable one.
return if issuccess
varinfo
else
# Warn the user that we can't use the type stable one.
@warn "Model seems incompatible with typed varinfo. Falling back to untyped varinfo."
DynamicPPL.untyped_varinfo(model, context)
end
end

end
12 changes: 5 additions & 7 deletions ext/DynamicPPLMCMCChainsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ function _predictive_samples_to_chains(predictive_samples)
end

"""
generated_quantities(model::Model, chain::MCMCChains.Chains)
returned(model::Model, chain::MCMCChains.Chains)
Execute `model` for each of the samples in `chain` and return an array of the values
returned by the `model` for each sample.
Expand All @@ -213,12 +213,12 @@ m = demo(data)
chain = sample(m, alg, n)
# To inspect the `interesting_quantity(θ, x)` where `θ` is replaced by samples
# from the posterior/`chain`:
generated_quantities(m, chain) # <= results in a `Vector` of returned values
returned(m, chain) # <= results in a `Vector` of returned values
# from `interesting_quantity(θ, x)`
```
## Concrete (and simple)
```julia
julia> using DynamicPPL, Turing
julia> using Turing
julia> @model function demo(xs)
s ~ InverseGamma(2, 3)
Expand All @@ -237,7 +237,7 @@ julia> model = demo(randn(10));
julia> chain = sample(model, MH(), 10);
julia> generated_quantities(model, chain)
julia> returned(model, chain)
10×1 Array{Tuple{Float64},2}:
(2.1964758025119338,)
(2.1964758025119338,)
Expand All @@ -251,9 +251,7 @@ julia> generated_quantities(model, chain)
(-0.16489786710222099,)
```
"""
function DynamicPPL.generated_quantities(
model::DynamicPPL.Model, chain_full::MCMCChains.Chains
)
function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Chains)
chain = MCMCChains.get_sections(chain_full, :parameters)
varinfo = DynamicPPL.VarInfo(model)
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
Expand Down
9 changes: 9 additions & 0 deletions ext/DynamicPPLMooncakeExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
module DynamicPPLMooncakeExt

using DynamicPPL: DynamicPPL, istrans
using Mooncake: Mooncake

# This is purely an optimisation.
Mooncake.@zero_adjoint Mooncake.DefaultCtx Tuple{typeof(istrans),Vararg}

end # module
26 changes: 0 additions & 26 deletions ext/DynamicPPLReverseDiffExt.jl

This file was deleted.

Loading

0 comments on commit 86eab6b

Please sign in to comment.