From 934ee9535db3abc66f93393ed6daf2edbe4efdb0 Mon Sep 17 00:00:00 2001 From: albertpod Date: Wed, 19 Aug 2020 17:24:10 +0200 Subject: [PATCH 1/2] Extend ruleSPMultiplicationAGPN --- src/engines/julia/update_rules/multiplication.jl | 6 ++++-- test/factor_nodes/test_multiplication.jl | 1 + 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/engines/julia/update_rules/multiplication.jl b/src/engines/julia/update_rules/multiplication.jl index dd5d8835..f33feacd 100644 --- a/src/engines/julia/update_rules/multiplication.jl +++ b/src/engines/julia/update_rules/multiplication.jl @@ -91,7 +91,7 @@ ruleSPMultiplicationIn1PNP(msg_out::Message{PointMass, Multivariate}, msg_in1::N # Namely, Ax = y, where A ∈ R^{nx1}, x ∈ R^1, and y ∈ R^n. In this case, the matrix A # can be represented by a n-dimensional vector, and x by a scalar. Before computation, # quantities are converted to their proper dimensions (see situational sketch below). -# +# # | a ~ Multivariate -> R^{nx1} # v out ~ Multivariate -> R^n # -->[x]--> @@ -117,4 +117,6 @@ function ruleSPMultiplicationIn1GNP(msg_out::Message{F, Multivariate}, (dims(msg_in1_mult.dist) == 1) || error("Implicit conversion to Univariate failed for $(msg_in1_mult.dist)") return Message(Univariate, GaussianWeightedMeanPrecision, xi=msg_in1_mult.dist.params[:xi][1], w=msg_in1_mult.dist.params[:w][1,1]) -end \ No newline at end of file +end + +ruleSPMultiplicationAGPN(msg_out::Message{F, Multivariate}, msg_in1::Message{PointMass, Multivariate}, msg_a::Nothing) where F<:Gaussian = ruleSPMultiplicationIn1GNP(msg_out, nothing, msg_in1) diff --git a/test/factor_nodes/test_multiplication.jl b/test/factor_nodes/test_multiplication.jl index e8f7ab0a..afb92b08 100644 --- a/test/factor_nodes/test_multiplication.jl +++ b/test/factor_nodes/test_multiplication.jl @@ -77,6 +77,7 @@ end @test isApplicable(SPMultiplicationAGPN, [Message{Gaussian}, Message{PointMass}, Nothing]) @test ruleSPMultiplicationAGPN(Message(Univariate, GaussianWeightedMeanPrecision, xi=1.0, w=3.0), Message(Univariate, PointMass, m=2.0), nothing) == Message(Univariate, GaussianWeightedMeanPrecision, xi=2.0, w=12.0) + @test ruleSPMultiplicationAGPN(Message(Multivariate, GaussianWeightedMeanPrecision, xi=[1.0], w=[3.0]), Message(Multivariate, PointMass, m=[2.0]), nothing) == Message(Univariate, GaussianWeightedMeanPrecision, xi=2.0, w=12.0) end @testset "SPMultiplicationAPPN" begin From 006de772bcb9336cfc775886eef6fd0e5a79e7ae Mon Sep 17 00:00:00 2001 From: albertpod Date: Tue, 1 Sep 2020 17:34:10 +0200 Subject: [PATCH 2/2] Fix multiplication test --- test/factor_nodes/test_multiplication.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/factor_nodes/test_multiplication.jl b/test/factor_nodes/test_multiplication.jl index afb92b08..17077996 100644 --- a/test/factor_nodes/test_multiplication.jl +++ b/test/factor_nodes/test_multiplication.jl @@ -77,7 +77,7 @@ end @test isApplicable(SPMultiplicationAGPN, [Message{Gaussian}, Message{PointMass}, Nothing]) @test ruleSPMultiplicationAGPN(Message(Univariate, GaussianWeightedMeanPrecision, xi=1.0, w=3.0), Message(Univariate, PointMass, m=2.0), nothing) == Message(Univariate, GaussianWeightedMeanPrecision, xi=2.0, w=12.0) - @test ruleSPMultiplicationAGPN(Message(Multivariate, GaussianWeightedMeanPrecision, xi=[1.0], w=[3.0]), Message(Multivariate, PointMass, m=[2.0]), nothing) == Message(Univariate, GaussianWeightedMeanPrecision, xi=2.0, w=12.0) + @test ruleSPMultiplicationAGPN(Message(Multivariate, GaussianWeightedMeanPrecision, xi=[1.0], w=[3.0]), Message(Multivariate, PointMass, m=[2.0]), nothing) == Message(Univariate, GaussianWeightedMeanPrecision, xi=2.0, w=12.0 + tiny) end @testset "SPMultiplicationAPPN" begin