From 003d34463e5350770abfdb4eebcc1716eaa13e20 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 21 Nov 2023 14:29:49 +0000 Subject: [PATCH 1/4] add hooks for acclogp!! depending on whether it's from an `assume` or `observe` statement --- src/context_implementations.jl | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 494fb0e47..7d3c796cb 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -14,6 +14,15 @@ alg_str(spl::Sampler) = string(nameof(typeof(spl.alg))) require_gradient(spl::Sampler) = false require_particles(spl::Sampler) = false +# Allows samplers, etc. to hook into the final logp accumulation in the tilde-pipeline. +function acclogp_assume!!(context::AbstractContext, vi::AbstractVarInfo, logp) + return acclogp!!(context, vi, logp) +end + +function acclogp_observe!!(context::AbstractContext, vi::AbstractVarInfo, logp) + return acclogp!!(context, vi, logp) +end + # assume """ tilde_assume(context::SamplingContext, right, vn, vi) @@ -115,7 +124,7 @@ probability of `vi` with the returned value. """ function tilde_assume!!(context, right, vn, vi) value, logp, vi = tilde_assume(context, right, vn, vi) - return value, acclogp!!(context, vi, logp) + return value, acclogp_assume!!(context, vi, logp) end # observe @@ -181,7 +190,7 @@ probability of `vi` with the returned value. """ function tilde_observe!!(context, right, left, vi) logp, vi = tilde_observe(context, right, left, vi) - return left, acclogp!!(context, vi, logp) + return left, acclogp_observe!!(context, vi, logp) end function assume(rng, spl::Sampler, dist) @@ -383,7 +392,7 @@ Falls back to `dot_tilde_assume(context, right, left, vn, vi)`. """ function dot_tilde_assume!!(context, right, left, vn, vi) value, logp, vi = dot_tilde_assume(context, right, left, vn, vi) - return value, acclogp!!(context, vi, logp), vi + return value, acclogp_assume!!(context, vi, logp), vi end # `dot_assume` @@ -539,6 +548,7 @@ function get_and_set_val!( if istrans(vi) push!!.((vi,), vns, reconstruct_and_link.((vi,), vns, dists, r), dists, (spl,)) # NOTE: Need to add the correction. + # FIXME: This is not great. acclogp!!(vi, sum(logabsdetjac.(bijector.(dists), r))) # `push!!` sets the trans-flag to `false` by default. settrans!!.((vi,), true, vns) @@ -634,7 +644,7 @@ Falls back to `dot_tilde_observe(context, right, left, vi)`. """ function dot_tilde_observe!!(context, right, left, vi) logp, vi = dot_tilde_observe(context, right, left, vi) - return left, acclogp!!(context, vi, logp) + return left, acclogp_observe!!(context, vi, logp) end # Falls back to non-sampler definition. From 178cdf1b7a05b6f575630b9f8e9d5c1a0fdf7216 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 21 Nov 2023 14:30:31 +0000 Subject: [PATCH 2/4] bump patch version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 4221d64fe..316dfb90a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.24.2" +version = "0.24.3" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" From cdd392153333ee5812249511751f7d1840c5be77 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 21 Nov 2023 15:38:14 +0000 Subject: [PATCH 3/4] Update src/context_implementations.jl --- src/context_implementations.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 7d3c796cb..2b28b44a9 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -549,7 +549,7 @@ function get_and_set_val!( push!!.((vi,), vns, reconstruct_and_link.((vi,), vns, dists, r), dists, (spl,)) # NOTE: Need to add the correction. # FIXME: This is not great. - acclogp!!(vi, sum(logabsdetjac.(bijector.(dists), r))) + acclogp_assume!!(vi, sum(logabsdetjac.(bijector.(dists), r))) # `push!!` sets the trans-flag to `false` by default. settrans!!.((vi,), true, vns) else From 5591002b7a19327abb1eb28a5533d8a3a5f51412 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 21 Nov 2023 22:56:06 +0000 Subject: [PATCH 4/4] Update Project.toml Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 316dfb90a..985189724 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.24.3" +version = "0.24.4" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"