From 715dd985ac331f5df26338e6b30ac686015610c4 Mon Sep 17 00:00:00 2001 From: FredericWantiez Date: Thu, 3 Feb 2022 22:18:43 +0000 Subject: [PATCH] TracedModel - CTask (#1770) * Add taped-ctask * Add prefix for libtask APIs. Co-authored-by: Hong Ge --- src/essential/container.jl | 8 ++++++++ src/inference/AdvancedSMC.jl | 2 +- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/src/essential/container.jl b/src/essential/container.jl index 74fe845fe..43bd619bf 100644 --- a/src/essential/container.jl +++ b/src/essential/container.jl @@ -70,3 +70,11 @@ function AdvancedPS.reset_logprob!(f::TracedModel) DynamicPPL.resetlogp!!(f.varinfo) return end + +function Libtask.CTask(model::TracedModel) + return Libtask.CTask(model.evaluator[1], model.evaluator[2:end]...) +end + +function Libtask.CTask(model::TracedModel, ::Random.AbstractRNG) + return Libtask.CTask(model) +end \ No newline at end of file diff --git a/src/inference/AdvancedSMC.jl b/src/inference/AdvancedSMC.jl index ba8bef99b..eee476d19 100644 --- a/src/inference/AdvancedSMC.jl +++ b/src/inference/AdvancedSMC.jl @@ -363,7 +363,7 @@ function DynamicPPL.assume( end function DynamicPPL.observe(spl::Sampler{<:Union{PG,SMC}}, dist::Distribution, value, vi) - produce(logpdf(dist, value)) + Libtask.produce(logpdf(dist, value)) return 0, vi end