Working with bounds instead of point observations in RxInfer #312
-
Hi everyone, I am trying to use RxInfer to implement inference, but I am running into a small problem. The observations I have are not point values but bounds, so the model would be something like this:

```julia
@model function distribution(y)
    # something along the lines of y[i] ~ cdf(Gamma)
end

bounds = [[0.1, 0.2], [0.2, 0.3], [0.1, 0.3]]

result = infer(
    model = distribution(),
    data  = (y = bounds, )
)
```

My question is thus: is this possible with RxInfer?
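Concretely, what each bound observation should contribute is the interval probability of the latent Gamma variable falling between the bounds. A quick standalone sanity check of that quantity in plain Julia (not from the thread; `rand_gamma` and `interval_prob` are illustrative names, and the sampler is the standard Marsaglia-Tsang scheme for shape >= 1):

```julia
# Marsaglia-Tsang sampler for Gamma(shape, rate), valid for shape >= 1.
function rand_gamma(shape::Float64, rate::Float64)
    d = shape - 1/3
    c = 1 / sqrt(9d)
    while true
        x = randn()
        v = (1 + c * x)^3
        v <= 0 && continue
        u = rand()
        if log(u) < 0.5 * x^2 + d - d * v + d * log(v)
            return d * v / rate
        end
    end
end

# Monte Carlo estimate of P(a <= z <= b) under Gamma(shape, rate)
function interval_prob(a, b, shape, rate; n = 200_000)
    hits = 0
    for _ in 1:n
        z = rand_gamma(shape, rate)
        hits += (a <= z <= b)
    end
    return hits / n
end

# Gamma(1, 1) is Exponential(1), so P(z <= 1) = 1 - exp(-1)
p = interval_prob(0.0, 1.0, 1.0, 1.0)
```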
Replies: 3 comments 7 replies
-
Hey @lionelkiel, thanks for checking out RxInfer! Okay, if I understand correctly, your model would look something like this:

```julia
@model function distribution(y)
    α ~ GammaShapeRate(1, 1) # Not a conjugate prior, but that is not the point of the discussion; we can make this work
    β ~ GammaShapeRate(1, 1)
    local z
    for i in eachindex(y)
        z[i] ~ GammaShapeRate(α, β)
        z[i] ~ Uniform(y[i, 1], y[i, 2])
    end
end
```

The main problem is the messages towards …
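Whatever form those messages take, the product of the Gamma prior on `z[i]` and the Uniform bound factor is just the Gamma density restricted (truncated) to the interval. A standalone numeric sketch of that product's normalizer (not from the thread; values are arbitrary, and shape 2 is chosen so the integral has a closed form):

```julia
# Unnormalized density of Gamma(shape = 2, rate = β) restricted to [a, b]:
# f(z) = β^2 * z * exp(-β z) on [a, b], zero elsewhere.
β, a, b = 1.5, 0.2, 0.9
f(z) = β^2 * z * exp(-β * z)

# Numeric normalizer via the trapezoid rule
zs = range(a, b; length = 10_001)
Z_num = sum((f(zs[i]) + f(zs[i + 1])) / 2 * step(zs) for i in 1:length(zs) - 1)

# Analytic normalizer: the antiderivative of β^2 z e^{-βz} is -(βz + 1) e^{-βz}
F(z) = -(β * z + 1) * exp(-β * z)
Z_an = F(b) - F(a)
```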
-
Can you explain the generative process in this model before we jump into the inference? Because your …
-
Hi Lionel, I have been playing around with your model and got some sensible results. However, I did use some dirty tricks that I'll try to explain here.

First of all, I used this model:

```julia
@model function lionels_model(lower_lims, upper_lims)
    α ~ GammaShapeRate(1.0, 1.0)
    β ~ GammaShapeRate(1.0, 1.0)
    local y
    for i in 1:length(lower_lims)
        y[i] ~ GammaShapeRate(α, β)
        y[i] ~ Uniform(lower_lims[i], upper_lims[i])
    end
end
```

Now, the problem with this model is that a Gamma distribution does not have a nice closed-form conjugate prior for its shape parameter, so the message towards this variable is a bit awkward to handle. What I did is approximate the product between this message and the Gamma prior with importance sampling followed by moment matching, so the result is again a Gamma distribution. I did the same for the product between the Uniform factor and the Gamma messages.

**Initialization**

```julia
using RxInfer
using FastGaussQuadrature
using Roots
using StableRNGs

rng = StableRNG(500)
```

**Importance sampling projection**

This is the dirty trick part; it's okay if you don't understand it:

```julia
function is_project(left, right::GammaDistributionsFamily)
    f = (x) -> exp(max(logpdf(left, x), -36.0) + logpdf(right, x) + x)
    x, w = gausslaguerre(31)
    Z = sum(w .* f.(x))
    normalized_f = (x) -> exp(max(logpdf(left, x), -36.0) + logpdf(right, x) + x - log(Z))
    expectation_x = sum(w .* normalized_f.(x) .* x)
    expectation_logx = sum(w .* normalized_f.(x) .* log.(x))
    gss = GammaSufficientStatistics(expectation_x, expectation_logx)
    return solve_logpartition_identity(gss, right)
end
```
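The moment-matching step recovers Gamma parameters from the expectations E[x] and E[log x]. A self-contained sketch of that identity in plain Julia (the digamma series and bisection stand in for the SpecialFunctions and Roots calls used above; all names here are illustrative):

```julia
# Digamma via recurrence plus asymptotic series (stand-in for a library digamma)
function digamma_approx(x::Float64)
    r = 0.0
    while x < 6.0
        r -= 1 / x
        x += 1.0
    end
    t = 1 / (x * x)
    return r + log(x) - 0.5 / x - t * (1/12 - t * (1/120 - t / 252))
end

# Given E[x] and E[log x] of a Gamma, solve for the shape α from
#   digamma(α) - log(α / E[x]) = E[log x]
# by bisection (the left side is increasing in α), then set rate β = α / E[x].
function gamma_from_moments(mx::Float64, mlogx::Float64)
    g(α) = digamma_approx(α) - log(α / mx) - mlogx
    lo, hi = 1e-3, 1e3
    for _ in 1:200
        mid = (lo + hi) / 2
        g(mid) < 0 ? (lo = mid) : (hi = mid)
    end
    α = (lo + hi) / 2
    return α, α / mx   # shape, rate
end

# For Gamma(shape = 3, rate = 0.5): E[x] = 6 and E[log x] = digamma(3) + log(2)
alpha_hat, beta_hat = gamma_from_moments(6.0, digamma_approx(3.0) + log(2.0))
```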
```julia
# Route both argument orders of the product to the importance-sampling projection
BayesBase.prod(::GenericProd, left::GammaDistributionsFamily, right::ContinuousUnivariateLogPdf) = BayesBase.prod(GenericProd(), right, left)
BayesBase.prod(::GenericProd, left::ContinuousUnivariateLogPdf, right::GammaDistributionsFamily) = is_project(left, right)
BayesBase.prod(::GenericProd, left::GammaDistributionsFamily, right::Uniform) = BayesBase.prod(GenericProd(), right, left)
BayesBase.prod(::GenericProd, left::Uniform, right::GammaDistributionsFamily) = is_project(left, right)

struct GammaSufficientStatistics{T}
    x::T
    logx::T
end
```
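The paired `BayesBase.prod` overloads follow a common multiple-dispatch pattern: the real work is implemented for one argument order, and the flipped order just forwards to it. A toy standalone version of the same pattern (types and names are mine, purely for illustration):

```julia
struct A end
struct B end

combine(::A, ::B) = "projected"      # the real work happens for this order
combine(b::B, a::A) = combine(a, b)  # the flipped order forwards to it
```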
```julia
function solve_logpartition_identity(statistics::GammaSufficientStatistics, initial_guess::GammaDistributionsFamily)
    f = let statistics = statistics
        (α) -> RxInfer.ReactiveMP.digamma(α) - log(α / statistics.x) - statistics.logx
    end
    α = find_zero(f, shape(initial_guess), Roots.Order0())
    β = α / statistics.x
    return GammaShapeScale(α, inv(β))
end
```

**ReactiveMP rules**

These rules were not built in, so we define them ourselves:

```julia
@rule GammaShapeRate(:β, Marginalisation) (q_out::GammaDistributionsFamily, q_α::GammaDistributionsFamily) = GammaShapeRate(1 + mean(q_α), mean(q_out))

@rule GammaShapeRate(:α, Marginalisation) (q_out::GammaDistributionsFamily, q_β::GammaDistributionsFamily) = begin
    return ContinuousUnivariateLogPdf(RxInfer.ReactiveMP.DomainSets.HalfLine(), (α) -> α * mean(log, q_β) + (α - 1) * mean(log, q_out) - RxInfer.ReactiveMP.loggamma(α))
end
```
```julia
@rule GammaShapeRate(:out, Marginalisation) (q_α::Any, q_β::Any) = GammaShapeRate(mean(q_α), mean(q_β))
```

**Model and inference constraints**

Your model only works when doing Variational Message Passing, so we need a set of inference constraints and an initial state for the iterative update procedure:

```julia
@model function lionels_model(lower_lims, upper_lims)
    α ~ GammaShapeRate(1.0, 1.0)
    β ~ GammaShapeRate(1.0, 1.0)
    local y
    for i in 1:length(lower_lims)
        y[i] ~ GammaShapeRate(α, β)
        y[i] ~ Uniform(lower_lims[i], upper_lims[i])
    end
end
```
```julia
constraints = @constraints begin
    q(α, β, y) = q(α)q(β)q(y)
end

initialization = @initialization begin
    q(α) = GammaShapeRate(1.0, 1.0)
    q(β) = GammaShapeRate(1.0, 1.0)
    q(y) = GammaShapeRate(1.0, 1.0)
end
```

**Data generation**

Now we're in a shape where we can generate some (demo) data. We draw a random α and β, sample from the Gamma distribution, and then generate some arbitrary upper and lower bounds:

```julia
α = 5 * rand(rng)
β = 5 * rand(rng)
n = 100
y = rand(rng, GammaShapeRate(α, β), n)
lower_lims = y .- rand(rng, n)
upper_lims = y .+ rand(rng, n)
```

**Inference**

The inference can now be done by:

```julia
result = infer(
    model = lionels_model(),
    data = (lower_lims = lower_lims, upper_lims = upper_lims),
    constraints = constraints,
    initialization = initialization,
    iterations = 100
)
```

**Results**

```julia
println("True α: $α, estimated α: $(mean(last(result.posteriors[:α])))")
```
```julia
println("True β: $β, estimated β: $(mean(last(result.posteriors[:β])))")
println("True distribution mean: $(α / β), estimated distribution mean: $(mean(last(result.posteriors[:α])) / mean(last(result.posteriors[:β])))")
```

Gives me the following results:

Which looks okay to me. Inference will never be perfect because of the dirty tricks we went through, but maybe @Nimrais would be able to do some magical projection which will increase the inference quality. Hope this solves your problem!
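As a side note, the custom `@rule GammaShapeRate(:β, ...)` above follows from the standard variational message passing update; a quick derivation (mine, not from the original post):

$$\log \mathcal{G}(y \mid \alpha, \beta) = \alpha \log \beta - \log \Gamma(\alpha) + (\alpha - 1)\log y - \beta y$$

Taking the expectation with respect to $q(y)\,q(\alpha)$ and keeping only the terms that depend on $\beta$ leaves

$$\mathbb{E}[\alpha] \log \beta - \beta\, \mathbb{E}[y],$$

so the message is proportional to $\beta^{\mathbb{E}[\alpha]} e^{-\beta \mathbb{E}[y]}$, a Gamma kernel with shape $\mathbb{E}[\alpha] + 1$ and rate $\mathbb{E}[y]$, which is exactly `GammaShapeRate(1 + mean(q_α), mean(q_out))`.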