Skip to content

Commit

Permalink
Merge pull request #439 from ReactiveBayes/generic-transition
Browse files Browse the repository at this point in the history
Generic implementation of Transition node
  • Loading branch information
bvdmitri authored Jan 22, 2025
2 parents 1d28b19 + e15f559 commit 58a213a
Show file tree
Hide file tree
Showing 19 changed files with 861 additions and 3 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ DiffResults = "1.1.0"
Distributions = "0.24, 0.25"
DomainIntegrals = "0.3.2, 0.4"
DomainSets = "0.5.2, 0.6, 0.7"
ExponentialFamily = "1.6.0"
ExponentialFamily = "1.7.0"
ExponentialFamilyProjection = "1.2"
FastCholesky = "1.3.0"
FastGaussQuadrature = "0.4, 0.5"
Expand Down
1 change: 1 addition & 0 deletions src/nodes/predefined.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ include("predefined/gamma_shape_rate.jl")
include("predefined/beta.jl")
include("predefined/categorical.jl")
include("predefined/matrix_dirichlet.jl")
include("predefined/tensor_dirichlet.jl")
include("predefined/dirichlet.jl")
include("predefined/bernoulli.jl")
include("predefined/gcv.jl")
Expand Down
10 changes: 10 additions & 0 deletions src/nodes/predefined/tensor_dirichlet.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import SpecialFunctions: loggamma
import Base.Broadcast: BroadcastFunction

@node TensorDirichlet Stochastic [out, a]

# Average energy of the TensorDirichlet node when the concentration `a` is a point mass.
# The three terms are computed per slice along the first dimension and then summed.
@average_energy TensorDirichlet (q_out::TensorDirichlet, q_a::PointMass) = begin
    concentration = mean(q_a)
    expected_log_out = mean(BroadcastFunction(log), q_out)
    log_norm = loggamma.(sum(concentration, dims = 1))
    sum_loggamma = sum(loggamma.(concentration), dims = 1)
    weighted_logs = sum((concentration .- 1.0) .* expected_log_out, dims = 1)
    return sum(-log_norm .+ sum_loggamma .- weighted_logs)
end
53 changes: 51 additions & 2 deletions src/nodes/predefined/transition.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,30 @@ import Base.Broadcast: BroadcastFunction

# Tag type used to dispatch node traits, rules and energies for the Transition node.
struct Transition end

# NOTE(review): in this diff view the `@node` declaration below appears next to the manual
# trait definitions that follow it — confirm that only one of the two remains in the file.
@node Transition Stochastic [out, in, a]
# `Transition` is a stochastic node.
ReactiveMP.sdtype(::Type{Transition}) = ReactiveMP.Stochastic()
# `Transition` is treated as a predefined node functional form.
ReactiveMP.is_predefined_node(::Type{Transition}) = ReactiveMP.PredefinedNodeFunctionalForm()

# Build the node interfaces for a `Transition` node: every `(name, variable)` pair is wrapped
# in a `NodeInterface` whose name is resolved from its position via `alias_interface`.
function ReactiveMP.prepare_interfaces_generic(fform::Type{Transition}, interfaces::AbstractVector)
    return [
        ReactiveMP.NodeInterface(ReactiveMP.alias_interface(fform, position, label), variable) for
        (position, (label, variable)) in enumerate(interfaces)
    ]
end

# Map a positional interface of a `Transition` node to its alias:
#   (:out, 1) -> :out, (:in, 2) -> :in, (:in, 3) -> :a, (:in, k >= 4) -> :T<k - 3>.
# Any other combination yields `nothing`.
function ReactiveMP.alias_interface(::Type{Transition}, index, name)
    name === :out && index === 1 && return :out
    name === :in || return nothing
    index === 2 && return :in
    index === 3 && return :a
    index >= 4 && return Symbol(:T, index - 3)
    return nothing
end

# A tuple factorisation for a `Transition` node is passed through unchanged.
ReactiveMP.collect_factorisation(::Type{Transition}, t::Tuple) = t

@average_energy Transition (q_out::Any, q_in::Any, q_a::MatrixDirichlet) = begin
return -probvec(q_out)' * mean(BroadcastFunction(log), q_a) * probvec(q_in)
Expand All @@ -19,9 +42,35 @@ end
# The reason is that we don't want to take log of zeros in the matrix `q_a` (if there are any)
# The trick here is that if RHS matrix has zero inputs, than the corresponding entries of the `contingency_matrix` matrix
# should also be zeros (see corresponding @marginalrule), so at the end `log(tiny) * 0` should not influence the result.
return -ReactiveMP.mul_trace(components(q_out_in)', mean(BroadcastFunction(clamplog), q_a))
result = -ReactiveMP.mul_trace(components(q_out_in)', mean(BroadcastFunction(clamplog), q_a))
return result
end

# Average energy for factorised `q_out`, `q_in` and a point-mass transition matrix.
# `clamplog` is used instead of `log` so that zero entries of the matrix do not produce `-Inf`.
@average_energy Transition (q_out::Any, q_in::Any, q_a::PointMass) = begin
    clamped_log_a = mean(BroadcastFunction(clamplog), q_a)
    return -probvec(q_out)' * clamped_log_a * probvec(q_in)
end

# Average energy of a `Transition` node for a joint `Contingency` marginal paired with a
# `TensorDirichlet` marginal: minus the sum of E[log a] weighted by the contingency components.
function score(::AverageEnergy, ::Type{<:Transition}, ::Val{mnames}, marginals::Tuple{<:Marginal{<:Contingency}, <:Marginal{<:TensorDirichlet}}, ::Nothing) where {mnames}
    joint, q_a = map(getdata, marginals)
    expected_log_a = mean(BroadcastFunction(log), q_a)
    return -sum(expected_log_a .* components(joint))
end

# Compute the outgoing `Categorical` message on the tensor axis `interface_index`.
#
# The tensor `exp(E[log A])` (clamped from below by `tiny`) is multiplied along every other
# axis by the probability vector of the corresponding incoming message, all axes except
# `interface_index` are summed out, and the remaining vector is normalised.
#
# Arguments:
# - `messages`: incoming messages; `probvec` is taken of each one. Presumably ordered by
#   interface with the outgoing interface left out — confirm against the callers.
# - `q_A`: marginal over the transition tensor; only `mean(BroadcastFunction(log), q_A)` is used.
# - `interface_index`: tensor axis that the outgoing message lives on.
function __reduce_td_from_messages(messages, q_A, interface_index)
    vmp = clamp.(exp.(mean(BroadcastFunction(log), q_A)), tiny, Inf)
    probvecs = probvec.(messages)
    for (i, vector) in enumerate(probvecs)
        # The i-th message maps to axis `i`, shifted by one at or past the outgoing axis,
        # because the outgoing axis has no incoming message of its own.
        axis = i >= interface_index ? i + 1 : i
        v = view(vector, :)
        localdims = ntuple(x -> x == axis ? length(v) : 1, ndims(vmp))
        vmp .*= reshape(v, localdims)
    end
    # Every axis except `interface_index` is marginalised out.
    dims = ntuple(x -> x >= interface_index ? x + 1 : x, ndims(vmp) - 1)
    vmp = sum(vmp, dims = dims)
    msg = reshape(vmp, :)
    msg ./= sum(msg)
    return Categorical(msg)
end
4 changes: 4 additions & 0 deletions src/rules/predefined.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ include("transition/marginals.jl")
include("transition/out.jl")
include("transition/in.jl")
include("transition/a.jl")
include("transition/t.jl")

include("continuous_transition/y.jl")
include("continuous_transition/x.jl")
Expand Down Expand Up @@ -184,3 +185,6 @@ include("delta/cvi/marginals.jl")
include("half_normal/out.jl")

include("binomial_polya/beta.jl")

include("tensor_dirichlet/out.jl")
include("tensor_dirichlet/marginals.jl")
4 changes: 4 additions & 0 deletions src/rules/tensor_dirichlet/marginals.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@

# Joint marginal over (out, a) when `a` is a point mass: `out` is the closed-form product of
# the TensorDirichlet built from `mean(m_a)` with the incoming `m_out`; `a` is passed through.
@marginalrule TensorDirichlet(:out_a) (m_out::TensorDirichlet, m_a::PointMass) = begin
    prior = TensorDirichlet(mean(m_a))
    posterior = prod(ClosedProd(), prior, m_out)
    return convert_paramfloattype((out = posterior, a = m_a))
end
4 changes: 4 additions & 0 deletions src/rules/tensor_dirichlet/out.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@

# Outgoing message on `out` given a point-mass input named `m_a`: a TensorDirichlet
# parameterised by the mean of that point mass.
@rule TensorDirichlet(:out, Marginalisation) (m_a::PointMass,) = TensorDirichlet(mean(m_a))

# Same result when the point mass is supplied under the name `q_a`.
@rule TensorDirichlet(:out, Marginalisation) (q_a::PointMass,) = TensorDirichlet(mean(q_a))
13 changes: 13 additions & 0 deletions src/rules/transition/a.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,16 @@ end
# Message towards `a` from the joint `Contingency` over (out, in): a MatrixDirichlet whose
# parameters are the contingency components plus one.
@rule Transition(:a, Marginalisation) (q_out_in::Contingency,) = begin
return MatrixDirichlet(components(q_out_in) .+ 1)
end

# Generic message towards `a` for a `Transition` node with no incoming messages:
# the components of the first marginal plus one become the parameters of a TensorDirichlet.
# The addons are returned untouched as the second element of the result.
function ReactiveMP.rule(
    fform::Type{<:Transition},
    on::Val{:a},
    vconstraint::Marginalisation,
    messages_names::Nothing,
    messages::Nothing,
    marginals_names::Val{m_names} where {m_names},
    marginals::Tuple,
    meta::Any,
    addons::Any,
    ::Any
)
    joint = getdata(first(marginals))
    return TensorDirichlet(components(joint) .+ 1), addons
end
15 changes: 15 additions & 0 deletions src/rules/transition/in.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,18 @@ end
normalize!(p, 1)
return Categorical(p)
end

# Generic message towards `in` (tensor axis 2) when the only marginal is the one over `a`.
# The computation is delegated to `__reduce_td_from_messages`; addons are returned untouched.
function ReactiveMP.rule(
    fform::Type{<:Transition},
    on::Val{:in},
    vconstraint::Marginalisation,
    messages_names::Val{m_names},
    messages::Tuple,
    marginals_names::Val{(:a,)},
    marginals::Tuple,
    meta::Any,
    addons::Any,
    ::Any
) where {m_names}
    q_a = first(marginals)
    outbound = __reduce_td_from_messages(messages, q_a, 2)
    return outbound, addons
end
13 changes: 13 additions & 0 deletions src/rules/transition/marginals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,16 @@ end
m_in_2 = @call_rule Transition(:in, Marginalisation) (m_out = m_out, m_a = m_a, meta = meta)
return convert_paramfloattype((out = m_out, in = prod(ClosedProd(), m_in_2, m_in), a = m_a))
end

# Joint marginal over (out, in) for an observed (point-mass) `out` and a point-mass `a`:
# `in` is the closed-form product of the incoming `m_in` with the message towards `in`
# computed by the corresponding `:in` rule; `out` is passed through.
@marginalrule Transition(:out_in) (m_out::PointMass, m_in::Categorical, q_a::PointMass) = begin
    backward = @call_rule Transition(:in, Marginalisation) (m_out = m_out, q_a = q_a)
    posterior_in = prod(ClosedProd(), m_in, backward)
    return convert_paramfloattype((out = m_out, in = posterior_in))
end

# Outer product of a collection of vectors: entry (i, j, ...) is vs[1][i] * vs[2][j] * ...
function outer_product(vs)
    return map(prod, Iterators.product(vs...))
end

# Generic joint marginal for a `Transition` node: the outer product of the message
# probability vectors, multiplied element-wise by `exp(E[log a])` of the first marginal
# (clamped to `[tiny, huge]`), wrapped in a `Contingency`.
function marginalrule(
    ::Type{<:Transition}, ::Val{marginal_symbol}, ::Val{message_names}, messages::Tuple, ::Val{marginal_names}, marginals::Tuple, ::Any, ::Any
) where {marginal_symbol, message_names, marginal_names}
    joint_messages = outer_product(probvec.(messages))
    expected_a = clamp.(exp.(mean(BroadcastFunction(log), first(marginals))), tiny, huge)
    return Contingency(joint_messages .* expected_a)
end
15 changes: 15 additions & 0 deletions src/rules/transition/out.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,18 @@ end
@logscale 0
return @call_rule Transition(:out, Marginalisation) (m_in = m_in, m_a = q_a, meta = meta, addons = getaddons())
end

# Generic message towards `out` (tensor axis 1) when the only marginal is the one over `a`.
# The computation is delegated to `__reduce_td_from_messages`; addons are returned untouched.
function ReactiveMP.rule(
    fform::Type{<:Transition},
    on::Val{:out},
    vconstraint::Marginalisation,
    messages_names::Val{m_names},
    messages::Tuple,
    marginals_names::Val{(:a,)},
    marginals::Tuple,
    meta::Any,
    addons::Any,
    ::Any
) where {m_names}
    q_a = first(marginals)
    outbound = __reduce_td_from_messages(messages, q_a, 1)
    return outbound, addons
end
17 changes: 17 additions & 0 deletions src/rules/transition/t.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import Base.Broadcast: BroadcastFunction

# Generic message towards an extra interface of a `Transition` node.
# The interface symbol `S` is read as one leading character followed by an integer `k`
# (the aliases produced for these interfaces have the form `:T<k>`); the tensor axis is `k + 2`.
function ReactiveMP.rule(
    fform::Type{<:Transition},
    on::Val{S},
    vconstraint::Marginalisation,
    messages_names::Val{m_names},
    messages::Tuple,
    marginals_names::Val{(:a,)},
    marginals::Tuple,
    meta::Any,
    addons::Any,
    ::Any
) where {S, m_names}
    suffix = String(S)[2:end]
    interface_index = parse(Int, suffix) + 2
    outbound = __reduce_td_from_messages(messages, first(marginals), interface_index)
    return outbound, addons
end
56 changes: 56 additions & 0 deletions test/nodes/predefined/tensor_dirichlet.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@

@testitem "TensorDirichletNode" begin
    using ReactiveMP, Random, BayesBase, ExponentialFamily, Distributions, StableRNGs

    @testset "AverageEnergy" begin
        # Rank 2: the TensorDirichlet energy must agree with the MatrixDirichlet energy.
        begin
            rng = StableRNG(123456)
            for i in 1:100
                α = rand(rng, 2, 2)
                a = rand(rng, 2, 2)
                q_out = TensorDirichlet(α)
                q_a = PointMass(a)

                marginals = (Marginal(q_out, false, false, nothing), Marginal(q_a, false, false, nothing))
                avg_energy = score(AverageEnergy(), TensorDirichlet, Val{(:out, :a)}(), marginals, nothing)

                q_out = MatrixDirichlet(α)
                q_a = PointMass(a)

                marginals = (Marginal(q_out, false, false, nothing), Marginal(q_a, false, false, nothing))
                avg_energy_matrix = score(AverageEnergy(), MatrixDirichlet, Val{(:out, :a)}(), marginals, nothing)

                # Floating-point results are compared approximately.
                @test avg_energy ≈ avg_energy_matrix
            end
        end

        # Higher ranks: the energy must equal the sum of Dirichlet energies over the slices
        # taken along every dimension except the first. `rng` is the one created above.
        begin
            for rank in 3:5
                for dim in 2:5
                    for i in 1:100
                        dims = ntuple(d -> dim, rank)
                        α = rand(rng, dims...)
                        a = rand(rng, dims...)

                        q_out = TensorDirichlet(α)
                        q_a = PointMass(a)

                        marginals = (Marginal(q_out, false, false, nothing), Marginal(q_a, false, false, nothing))
                        avg_energy = score(AverageEnergy(), TensorDirichlet, Val{(:out, :a)}(), marginals, nothing)

                        q_out = Dirichlet.(eachslice(α, dims = ntuple(d -> d + 1, rank - 1)))
                        q_a = PointMass.(eachslice(a, dims = ntuple(d -> d + 1, rank - 1)))

                        avg_energy_matrix = 0.0
                        for (dir, a) in zip(q_out, q_a)
                            marginals = (Marginal(dir, false, false, nothing), Marginal(a, false, false, nothing))
                            avg_energy_matrix += score(AverageEnergy(), Dirichlet, Val{(:out, :a)}(), marginals, nothing)
                        end

                        @test avg_energy ≈ avg_energy_matrix
                    end
                end
            end
        end
    end
end
86 changes: 86 additions & 0 deletions test/nodes/predefined/transition_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
@testitem "TransitionNode" begin
    using Test, ReactiveMP, Random, Distributions, BayesBase, ExponentialFamily

    @testset "Transition node properties" begin
        @test ReactiveMP.sdtype(Transition) == Stochastic()
        @test ReactiveMP.alias_interface(Transition, 1, :out) == :out
        @test ReactiveMP.alias_interface(Transition, 2, :in) == :in
        @test ReactiveMP.alias_interface(Transition, 3, :in) == :a
        @test ReactiveMP.alias_interface(Transition, 4, :in) == :T1

        @test ReactiveMP.collect_factorisation(Transition, ()) == ()
    end
    @testset "AverageEnergy(q_out_in::Contingency, q_a::MatrixDirichlet)" begin end

    @testset "AverageEnergy(q_out_in::Contingency, q_a::PointMass)" begin
        contingency_matrix = [0.2 0.3; 0.4 0.1]
        a_matrix = [0.7 0.3; 0.2 0.8]

        q_out_in = Contingency(contingency_matrix)
        q_a = PointMass(a_matrix)

        marginals = (Marginal(q_out_in, false, false, nothing), Marginal(q_a, false, false, nothing))

        # Expected value calculated by hand
        expected = -sum(contingency_matrix .* log.(clamp.(a_matrix, tiny, Inf)))

        @test score(AverageEnergy(), Transition, Val{(:out_in, :a)}(), marginals, nothing) ≈ expected

        # Transition matrix with exact zeros: the clamped logarithm keeps the result finite.
        contingency_matrix = [0.2 0.3; 0.4 0.1]
        a_matrix = [1.0 0.0; 0.0 1.0]

        q_out_in = Contingency(contingency_matrix)
        q_a = PointMass(a_matrix)

        marginals = (Marginal(q_out_in, false, false, nothing), Marginal(q_a, false, false, nothing))

        expected = -sum(contingency_matrix .* log.(clamp.(a_matrix, tiny, Inf)))

        @test score(AverageEnergy(), Transition, Val{(:out_in, :a)}(), marginals, nothing) ≈ expected

        contingency_matrix = prod.(Iterators.product([0, 1, 0], [0.1, 0.4, 0.5]))
        a_matrix = diageye(3)

        q_out_in = Contingency(contingency_matrix)
        q_a = PointMass(a_matrix)

        marginals = (Marginal(q_out_in, false, false, nothing), Marginal(q_a, false, false, nothing))

        expected = -sum(contingency_matrix .* log.(clamp.(a_matrix, tiny, Inf)))

        @test score(AverageEnergy(), Transition, Val{(:out_in, :a)}(), marginals, nothing) ≈ expected

        # `a_matrix` is still `diageye(3)` from the previous case, matching `q_a` below.
        contingency_matrix = [0.0 0.0 0.0; 0.0 1.0 0.0; 0.0 0.0 0.0]
        q_out_in = Contingency(contingency_matrix)
        q_a = PointMass(diageye(3))

        marginals = (Marginal(q_out_in, false, false, nothing), Marginal(q_a, false, false, nothing))

        expected = -sum(contingency_matrix .* log.(clamp.(a_matrix, tiny, Inf)))
        @test score(AverageEnergy(), Transition, Val{(:out_in, :a)}(), marginals, nothing) ≈ expected
    end

    @testset "AverageEnergy(q_out::Any, q_in::Any, q_a::PointMass)" begin
        q_out = Categorical([0.3, 0.7])
        q_in = Categorical([0.8, 0.2])
        q_a = PointMass([0.7 0.3; 0.2 0.8])

        marginals = (Marginal(q_out, false, false, nothing), Marginal(q_in, false, false, nothing), Marginal(q_a, false, false, nothing))

        contingency = probvec(q_out) * probvec(q_in)'
        expected = -sum(contingency .* log.(clamp.(mean(q_a), tiny, Inf)))

        @test score(AverageEnergy(), Transition, Val{(:out, :in, :a)}(), marginals, nothing) ≈ expected

        q_out = Categorical([0.0, 1.0])
        q_in = Categorical([0.0, 1.0])
        q_a = PointMass([1.0 0.0; 1.0 0.0])

        marginals = (Marginal(q_out, false, false, nothing), Marginal(q_in, false, false, nothing), Marginal(q_a, false, false, nothing))

        contingency = probvec(q_out) * probvec(q_in)'

        expected = -sum(contingency .* log.(clamp.(mean(q_a), tiny, Inf)))
        @test score(AverageEnergy(), Transition, Val{(:out, :in, :a)}(), marginals, nothing) ≈ expected
    end
end
12 changes: 12 additions & 0 deletions test/rules/transition/a_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,16 @@
input = (q_out_in = Contingency(diageye(3)),), output = MatrixDirichlet([1.333333333333333 1 1; 1 1.3333333333333 1; 1 1 1.33333333333333333])
)]
end

# Rank-3 joint marginal: the `:a` rule is expected to return a TensorDirichlet.
# The expected parameters are `1 + 1/27`, i.e. each contingency component is taken to be 1/27
# (presumably `Contingency` normalises its input — confirm).
@testset "Variational Bayes: (q_out_in_t1::Contingency)" begin
@test_rules [check_type_promotion = false] Transition(:a, Marginalisation) [(
input = (q_out_in_t1 = Contingency(ones(3, 3, 3)),), output = TensorDirichlet(ones(3, 3, 3) .+ (1 / 27))
)]
end

# Rank-4 joint marginal: same check with 81 equal components, hence `1 + 1/81`.
@testset "Variational Bayes: (q_out_in_t1_t2::Contingency)" begin
@test_rules [check_type_promotion = false] Transition(:a, Marginalisation) [(
input = (q_out_in_t1_t2 = Contingency(ones(3, 3, 3, 3)),), output = TensorDirichlet(ones(3, 3, 3, 3) .+ (1 / 81))
)]
end
end
Loading

0 comments on commit 58a213a

Please sign in to comment.