From cd3c590f32b2047b56f6d843859c39caed3625ad Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 19 Sep 2024 21:16:56 -0400 Subject: [PATCH 01/11] common errorTree method and its use in HybridGaussianFactorGraph --- gtsam/hybrid/HybridConditional.cpp | 16 +++++++++++ gtsam/hybrid/HybridConditional.h | 10 +++++++ gtsam/hybrid/HybridFactor.h | 4 +++ gtsam/hybrid/HybridGaussianConditional.h | 10 +++---- gtsam/hybrid/HybridGaussianFactor.h | 4 +-- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 31 ++++------------------ 6 files changed, 42 insertions(+), 33 deletions(-) diff --git a/gtsam/hybrid/HybridConditional.cpp b/gtsam/hybrid/HybridConditional.cpp index 074534b8d5..3cb3bba65d 100644 --- a/gtsam/hybrid/HybridConditional.cpp +++ b/gtsam/hybrid/HybridConditional.cpp @@ -129,6 +129,22 @@ double HybridConditional::error(const HybridValues &values) const { "HybridConditional::error: conditional type not handled"); } +/* ************************************************************************ */ +AlgebraicDecisionTree HybridConditional::errorTree( + const VectorValues &values) const { + if (auto gc = asGaussian()) { + return AlgebraicDecisionTree(gc->error(values)); + } + if (auto gm = asHybrid()) { + return gm->errorTree(values); + } + if (auto dc = asDiscrete()) { + return AlgebraicDecisionTree(0.0); + } + throw std::runtime_error( + "HybridConditional::error: conditional type not handled"); +} + /* ************************************************************************ */ double HybridConditional::logProbability(const HybridValues &values) const { if (auto gc = asGaussian()) { diff --git a/gtsam/hybrid/HybridConditional.h b/gtsam/hybrid/HybridConditional.h index f44ee2bf99..0009d6bd88 100644 --- a/gtsam/hybrid/HybridConditional.h +++ b/gtsam/hybrid/HybridConditional.h @@ -179,6 +179,16 @@ class GTSAM_EXPORT HybridConditional /// Return the error of the underlying conditional. double error(const HybridValues& values) const override; + /** + * @brief Compute error of the HybridConditional as a tree. + * + * @param continuousValues The continuous VectorValues. + * @return AlgebraicDecisionTree A decision tree with the same keys + * as the conditionals involved, and leaf values as the error. + */ + virtual AlgebraicDecisionTree errorTree( + const VectorValues& values) const override; + /// Return the log-probability (or density) of the underlying conditional. double logProbability(const HybridValues& values) const override; diff --git a/gtsam/hybrid/HybridFactor.h b/gtsam/hybrid/HybridFactor.h index ad29dfdca9..fc91e08389 100644 --- a/gtsam/hybrid/HybridFactor.h +++ b/gtsam/hybrid/HybridFactor.h @@ -136,6 +136,10 @@ class GTSAM_EXPORT HybridFactor : public Factor { /// Return only the continuous keys for this factor. const KeyVector &continuousKeys() const { return continuousKeys_; } + /// Virtual class to compute tree of linear errors. + virtual AlgebraicDecisionTree errorTree( + const VectorValues &values) const = 0; + /// @} private: diff --git a/gtsam/hybrid/HybridGaussianConditional.h b/gtsam/hybrid/HybridGaussianConditional.h index 72a9994729..5e585acefe 100644 --- a/gtsam/hybrid/HybridGaussianConditional.h +++ b/gtsam/hybrid/HybridGaussianConditional.h @@ -109,9 +109,9 @@ class GTSAM_EXPORT HybridGaussianConditional const Conditionals &conditionals); /** - * @brief Make a Hybrid Gaussian Conditional from a vector of Gaussian - * conditionals. The DecisionTree-based constructor is preferred over this - * one. + * @brief Make a Hybrid Gaussian Conditional from + * a vector of Gaussian conditionals. + * The DecisionTree-based constructor is preferred over this one. * * @param continuousFrontals The continuous frontal variables * @param continuousParents The continuous parent variables @@ -208,8 +208,8 @@ class GTSAM_EXPORT HybridGaussianConditional * @return AlgebraicDecisionTree A decision tree on the discrete keys * only, with the leaf values as the error for each assignment. */ - AlgebraicDecisionTree errorTree( - const VectorValues &continuousValues) const; + virtual AlgebraicDecisionTree errorTree( + const VectorValues &continuousValues) const override; /** * @brief Compute the logProbability of this hybrid Gaussian conditional. diff --git a/gtsam/hybrid/HybridGaussianFactor.h b/gtsam/hybrid/HybridGaussianFactor.h index a86714863c..8d57ad7da1 100644 --- a/gtsam/hybrid/HybridGaussianFactor.h +++ b/gtsam/hybrid/HybridGaussianFactor.h @@ -148,8 +148,8 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor { * @return AlgebraicDecisionTree A decision tree with the same keys * as the factors involved, and leaf values as the error. */ - AlgebraicDecisionTree errorTree( - const VectorValues &continuousValues) const; + virtual AlgebraicDecisionTree errorTree( + const VectorValues &continuousValues) const override; /** * @brief Compute the log-likelihood, including the log-normalizing constant. diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 28a0c446fd..0d4e534e17 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -539,36 +539,15 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors, AlgebraicDecisionTree HybridGaussianFactorGraph::errorTree( const VectorValues &continuousValues) const { AlgebraicDecisionTree error_tree(0.0); - // Iterate over each factor. for (auto &factor : factors_) { - // TODO(dellaert): just use a virtual method defined in HybridFactor. - AlgebraicDecisionTree factor_error; - - auto f = factor; - if (auto hc = dynamic_pointer_cast(factor)) { - f = hc->inner(); - } - - if (auto hybridGaussianCond = - dynamic_pointer_cast(f)) { - // Compute factor error and add it. - error_tree = error_tree + hybridGaussianCond->errorTree(continuousValues); - } else if (auto gaussian = dynamic_pointer_cast(f)) { - // If continuous only, get the (double) error - // and add it to the error_tree - double error = gaussian->error(continuousValues); - // Add the gaussian factor error to every leaf of the error tree. - error_tree = error_tree.apply( - [error](double leaf_value) { return leaf_value + error; }); - } else if (dynamic_pointer_cast(f)) { - // If factor at `idx` is discrete-only, we skip. - continue; - } else { - throwRuntimeError("HybridGaussianFactorGraph::error(VV)", f); + if (auto f = std::dynamic_pointer_cast(factor)) { + error_tree = error_tree + f->errorTree(continuousValues); + } else if (auto f = std::dynamic_pointer_cast(factor)) { + error_tree = + error_tree + AlgebraicDecisionTree(f->error(continuousValues)); } } - return error_tree; } From 9c3d7b0f3b95d6ab38e5e462c004413862939d7f Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 20 Sep 2024 11:07:02 -0400 Subject: [PATCH 02/11] implement errorTree for HybridNonlinearFactor --- gtsam/hybrid/HybridNonlinearFactor.h | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/gtsam/hybrid/HybridNonlinearFactor.h b/gtsam/hybrid/HybridNonlinearFactor.h index 6da846abe5..9852602de4 100644 --- a/gtsam/hybrid/HybridNonlinearFactor.h +++ b/gtsam/hybrid/HybridNonlinearFactor.h @@ -74,6 +74,13 @@ class GTSAM_EXPORT HybridNonlinearFactor : public HybridFactor { /// Decision tree of Gaussian factors indexed by discrete keys. Factors factors_; + /// HybridFactor method implementation. Should not be used. + AlgebraicDecisionTree errorTree( + const VectorValues& continuousValues) const override { + throw std::runtime_error( + "HybridNonlinearFactor::error does not take VectorValues."); + } + public: HybridNonlinearFactor() = default; From 8231de2a92997a79d050a8673be4a30ea6a33d32 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 20 Sep 2024 12:55:21 -0400 Subject: [PATCH 03/11] rename tests to match file --- .../tests/testHybridNonlinearFactorGraph.cpp | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp b/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp index 621c8708ed..a324585bfb 100644 --- a/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp @@ -51,7 +51,7 @@ using symbol_shorthand::X; * Test that any linearizedFactorGraph gaussian factors are appended to the * existing gaussian factor graph in the hybrid factor graph. */ -TEST(HybridFactorGraph, GaussianFactorGraph) { +TEST(HybridNonlinearFactorGraph, GaussianFactorGraph) { HybridNonlinearFactorGraph fg; // Add a simple prior factor to the nonlinear factor graph @@ -181,7 +181,7 @@ TEST(HybridGaussianFactorGraph, HybridNonlinearFactor) { /***************************************************************************** * Test push_back on HFG makes the correct distinction. */ -TEST(HybridFactorGraph, PushBack) { +TEST(HybridNonlinearFactorGraph, PushBack) { HybridNonlinearFactorGraph fg; auto nonlinearFactor = std::make_shared>(); @@ -240,7 +240,7 @@ TEST(HybridFactorGraph, PushBack) { /**************************************************************************** * Test construction of switching-like hybrid factor graph. */ -TEST(HybridFactorGraph, Switching) { +TEST(HybridNonlinearFactorGraph, Switching) { Switching self(3); EXPECT_LONGS_EQUAL(7, self.nonlinearFactorGraph.size()); @@ -250,7 +250,7 @@ TEST(HybridFactorGraph, Switching) { /**************************************************************************** * Test linearization on a switching-like hybrid factor graph. */ -TEST(HybridFactorGraph, Linearization) { +TEST(HybridNonlinearFactorGraph, Linearization) { Switching self(3); // Linearize here: @@ -263,7 +263,7 @@ TEST(HybridFactorGraph, Linearization) { /**************************************************************************** * Test elimination tree construction */ -TEST(HybridFactorGraph, EliminationTree) { +TEST(HybridNonlinearFactorGraph, EliminationTree) { Switching self(3); // Create ordering. @@ -372,7 +372,7 @@ TEST(HybridGaussianElimination, EliminateHybrid_2_Variable) { /**************************************************************************** * Test partial elimination */ -TEST(HybridFactorGraph, Partial_Elimination) { +TEST(HybridNonlinearFactorGraph, Partial_Elimination) { Switching self(3); auto linearizedFactorGraph = self.linearizedFactorGraph; @@ -401,7 +401,8 @@ TEST(HybridFactorGraph, Partial_Elimination) { EXPECT(remainingFactorGraph->at(2)->keys() == KeyVector({M(0), M(1)})); } -TEST(HybridFactorGraph, PrintErrors) { +/* ****************************************************************************/ +TEST(HybridNonlinearFactorGraph, PrintErrors) { Switching self(3); // Get nonlinear factor graph and add linear factors to be holistic @@ -424,7 +425,7 @@ TEST(HybridFactorGraph, PrintErrors) { /**************************************************************************** * Test full elimination */ -TEST(HybridFactorGraph, Full_Elimination) { +TEST(HybridNonlinearFactorGraph, Full_Elimination) { Switching self(3); auto linearizedFactorGraph = self.linearizedFactorGraph; @@ -492,7 +493,7 @@ TEST(HybridFactorGraph, Full_Elimination) { /**************************************************************************** * Test printing */ -TEST(HybridFactorGraph, Printing) { +TEST(HybridNonlinearFactorGraph, Printing) { Switching self(3); auto linearizedFactorGraph = self.linearizedFactorGraph; @@ -784,7 +785,7 @@ conditional 2: Hybrid P( x2 | m0 m1) * The issue arises if we eliminate a landmark variable first since it is not * connected to a HybridFactor. */ -TEST(HybridFactorGraph, DefaultDecisionTree) { +TEST(HybridNonlinearFactorGraph, DefaultDecisionTree) { HybridNonlinearFactorGraph fg; // Add a prior on pose x0 at the origin. From 9cbc7540d6bc83a476c1c20b4d3d2e4cb6ad19ee Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 20 Sep 2024 15:28:19 -0400 Subject: [PATCH 04/11] add error(HybridValues) to HybridNonlinearFactorGraph --- gtsam/hybrid/HybridNonlinearFactorGraph.h | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/gtsam/hybrid/HybridNonlinearFactorGraph.h b/gtsam/hybrid/HybridNonlinearFactorGraph.h index 54dc9e93fc..5a09f18d45 100644 --- a/gtsam/hybrid/HybridNonlinearFactorGraph.h +++ b/gtsam/hybrid/HybridNonlinearFactorGraph.h @@ -86,6 +86,10 @@ class GTSAM_EXPORT HybridNonlinearFactorGraph : public HybridFactorGraph { */ std::shared_ptr linearize( const Values& continuousValues) const; + + /// Expose error(const HybridValues&) method. + using Base::error; + /// @} }; From c71f0336e2c1a4cdfa873cb08d966eddf273cecf Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 20 Sep 2024 15:30:06 -0400 Subject: [PATCH 05/11] add HybridNonlinearFactorGraph::error test --- .../tests/testHybridNonlinearFactorGraph.cpp | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp b/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp index a324585bfb..347cc5f1fe 100644 --- a/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp @@ -401,6 +401,37 @@ TEST(HybridNonlinearFactorGraph, Partial_Elimination) { EXPECT(remainingFactorGraph->at(2)->keys() == KeyVector({M(0), M(1)})); } +/* ****************************************************************************/ +TEST(HybridNonlinearFactorGraph, Error) { + Switching self(3); + HybridNonlinearFactorGraph fg = self.nonlinearFactorGraph; + + { + HybridValues values(VectorValues(), DiscreteValues{{M(0), 0}, {M(1), 0}}, + self.linearizationPoint); + // regression + EXPECT_DOUBLES_EQUAL(152.791759469, fg.error(values), 1e-9); + } + { + HybridValues values(VectorValues(), DiscreteValues{{M(0), 0}, {M(1), 1}}, + self.linearizationPoint); + // regression + EXPECT_DOUBLES_EQUAL(151.598612289, fg.error(values), 1e-9); + } + { + HybridValues values(VectorValues(), DiscreteValues{{M(0), 1}, {M(1), 0}}, + self.linearizationPoint); + // regression + EXPECT_DOUBLES_EQUAL(151.703972804, fg.error(values), 1e-9); + } + { + HybridValues values(VectorValues(), DiscreteValues{{M(0), 1}, {M(1), 1}}, + self.linearizationPoint); + // regression + EXPECT_DOUBLES_EQUAL(151.609437912, fg.error(values), 1e-9); + } +} + /* ****************************************************************************/ TEST(HybridNonlinearFactorGraph, PrintErrors) { Switching self(3); From d145872916868e6432ea1d574c9167356ab39ab0 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 20 Sep 2024 15:30:24 -0400 Subject: [PATCH 06/11] comment update --- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 0d4e534e17..01ecfe5ac8 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -329,8 +329,8 @@ static std::shared_ptr createDiscreteFactor( // Logspace version of: // exp(-factor->error(kEmpty)) / conditional->normalizationConstant(); - // We take negative of the logNormalizationConstant `log(1/k)` - // to get `log(k)`. + // We take negative of the logNormalizationConstant `log(k)` + // to get `log(1/k) = log(\sqrt{|2πΣ|})`. return -factor->error(kEmpty) - conditional->logNormalizationConstant(); }; From 939fdcc7201729a2ef686087d6a178ba5cf14ef0 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 20 Sep 2024 16:20:49 -0400 Subject: [PATCH 07/11] HybridGaussianFactorGraph::errorTree is better encompassing --- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 01ecfe5ac8..a6fe955eb3 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -542,10 +542,15 @@ AlgebraicDecisionTree HybridGaussianFactorGraph::errorTree( // Iterate over each factor. for (auto &factor : factors_) { if (auto f = std::dynamic_pointer_cast(factor)) { + // Check for HybridFactor, and call errorTree error_tree = error_tree + f->errorTree(continuousValues); - } else if (auto f = std::dynamic_pointer_cast(factor)) { - error_tree = - error_tree + AlgebraicDecisionTree(f->error(continuousValues)); + } else if (auto f = std::dynamic_pointer_cast(factor)) { + // Skip discrete factors + continue; + } else { + // Everything else is a continuous only factor + HybridValues hv(continuousValues, DiscreteValues()); + error_tree = error_tree + AlgebraicDecisionTree(factor->error(hv)); } } return error_tree; From 4c8224800424acda8e673cb018fb8a5a400d2d5a Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sat, 21 Sep 2024 03:09:00 -0400 Subject: [PATCH 08/11] remove virtual --- gtsam/hybrid/HybridConditional.h | 2 +- gtsam/hybrid/HybridGaussianFactor.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/gtsam/hybrid/HybridConditional.h b/gtsam/hybrid/HybridConditional.h index 0009d6bd88..51eeeb5bb6 100644 --- a/gtsam/hybrid/HybridConditional.h +++ b/gtsam/hybrid/HybridConditional.h @@ -186,7 +186,7 @@ class GTSAM_EXPORT HybridConditional * @return AlgebraicDecisionTree A decision tree with the same keys * as the conditionals involved, and leaf values as the error. */ - virtual AlgebraicDecisionTree errorTree( + AlgebraicDecisionTree errorTree( const VectorValues& values) const override; /// Return the log-probability (or density) of the underlying conditional. diff --git a/gtsam/hybrid/HybridGaussianFactor.h b/gtsam/hybrid/HybridGaussianFactor.h index 8d57ad7da1..3bf9f3dfd8 100644 --- a/gtsam/hybrid/HybridGaussianFactor.h +++ b/gtsam/hybrid/HybridGaussianFactor.h @@ -148,7 +148,7 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor { * @return AlgebraicDecisionTree A decision tree with the same keys * as the factors involved, and leaf values as the error. */ - virtual AlgebraicDecisionTree errorTree( + AlgebraicDecisionTree errorTree( const VectorValues &continuousValues) const override; /** From d81cd82b9a6e27e51104faf00c9915656a670799 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sat, 21 Sep 2024 03:24:01 -0400 Subject: [PATCH 09/11] check for potentiall pruning --- gtsam/hybrid/HybridGaussianFactor.cpp | 21 ++++++++++++++++++--- gtsam/hybrid/HybridGaussianFactor.h | 4 ++++ 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianFactor.cpp b/gtsam/hybrid/HybridGaussianFactor.cpp index 2fbd4bd888..a4e0bf8742 100644 --- a/gtsam/hybrid/HybridGaussianFactor.cpp +++ b/gtsam/hybrid/HybridGaussianFactor.cpp @@ -151,12 +151,26 @@ GaussianFactorGraphTree HybridGaussianFactor::asGaussianFactorGraphTree() return {factors_, wrap}; } +/* *******************************************************************************/ +double HybridGaussianFactor::potentiallyPrunedComponentError( + const sharedFactor &gf, const VectorValues &values) const { + // Check if valid pointer + if (gf) { + return gf->error(values); + } else { + // If not valid, pointer, it means this component was pruned, + // so we return maximum error. + // This way the negative exponential will give + // a probability value close to 0.0. + return std::numeric_limits::max(); + } +} /* *******************************************************************************/ AlgebraicDecisionTree HybridGaussianFactor::errorTree( const VectorValues &continuousValues) const { // functor to convert from sharedFactor to double error value. - auto errorFunc = [&continuousValues](const sharedFactor &gf) { - return gf->error(continuousValues); + auto errorFunc = [this, &continuousValues](const sharedFactor &gf) { + return this->potentiallyPrunedComponentError(gf, continuousValues); }; DecisionTree error_tree(factors_, errorFunc); return error_tree; @@ -164,8 +178,9 @@ AlgebraicDecisionTree HybridGaussianFactor::errorTree( /* *******************************************************************************/ double HybridGaussianFactor::error(const HybridValues &values) const { + // Directly index to get the component, no need to build the whole tree. const sharedFactor gf = factors_(values.discrete()); - return gf->error(values.continuous()); + return potentiallyPrunedComponentError(gf, values.continuous()); } } // namespace gtsam diff --git a/gtsam/hybrid/HybridGaussianFactor.h b/gtsam/hybrid/HybridGaussianFactor.h index 3bf9f3dfd8..b1b93dc320 100644 --- a/gtsam/hybrid/HybridGaussianFactor.h +++ b/gtsam/hybrid/HybridGaussianFactor.h @@ -169,6 +169,10 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor { /// @} private: + /// Helper method to compute the error of a component. + double potentiallyPrunedComponentError( + const sharedFactor &gf, const VectorValues &continuousValues) const; + #ifdef GTSAM_ENABLE_BOOST_SERIALIZATION /** Serialization function */ friend class boost::serialization::access; From 821b22f6f86292d0b1039285d9ae941f803b3b54 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sat, 21 Sep 2024 03:24:29 -0400 Subject: [PATCH 10/11] remove unnecessary code in child class --- gtsam/hybrid/HybridGaussianConditional.cpp | 34 ------------------ gtsam/hybrid/HybridGaussianConditional.h | 41 ---------------------- 2 files changed, 75 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianConditional.cpp b/gtsam/hybrid/HybridGaussianConditional.cpp index 0b1dc53377..fb943366cb 100644 --- a/gtsam/hybrid/HybridGaussianConditional.cpp +++ b/gtsam/hybrid/HybridGaussianConditional.cpp @@ -323,40 +323,6 @@ AlgebraicDecisionTree HybridGaussianConditional::logProbability( return DecisionTree(conditionals_, probFunc); } -/* ************************************************************************* */ -double HybridGaussianConditional::conditionalError( - const GaussianConditional::shared_ptr &conditional, - const VectorValues &continuousValues) const { - // Check if valid pointer - if (conditional) { - return conditional->error(continuousValues) + // - -logConstant_ - conditional->logNormalizationConstant(); - } else { - // If not valid, pointer, it means this conditional was pruned, - // so we return maximum error. - // This way the negative exponential will give - // a probability value close to 0.0. - return std::numeric_limits::max(); - } -} - -/* *******************************************************************************/ -AlgebraicDecisionTree HybridGaussianConditional::errorTree( - const VectorValues &continuousValues) const { - auto errorFunc = [&](const GaussianConditional::shared_ptr &conditional) { - return conditionalError(conditional, continuousValues); - }; - DecisionTree error_tree(conditionals_, errorFunc); - return error_tree; -} - -/* *******************************************************************************/ -double HybridGaussianConditional::error(const HybridValues &values) const { - // Directly index to get the conditional, no need to build the whole tree. - auto conditional = conditionals_(values.discrete()); - return conditionalError(conditional, values.continuous()); -} - /* *******************************************************************************/ double HybridGaussianConditional::logProbability( const HybridValues &values) const { diff --git a/gtsam/hybrid/HybridGaussianConditional.h b/gtsam/hybrid/HybridGaussianConditional.h index 5e585acefe..4a5fdcc89e 100644 --- a/gtsam/hybrid/HybridGaussianConditional.h +++ b/gtsam/hybrid/HybridGaussianConditional.h @@ -174,43 +174,6 @@ class GTSAM_EXPORT HybridGaussianConditional AlgebraicDecisionTree logProbability( const VectorValues &continuousValues) const; - /** - * @brief Compute the error of this hybrid Gaussian conditional. - * - * This requires some care, as different components may have - * different normalization constants. Let's consider p(x|y,m), where m is - * discrete. We need the error to satisfy the invariant: - * - * error(x;y,m) = K - log(probability(x;y,m)) - * - * For all x,y,m. But note that K, the (log) normalization constant defined - * in Conditional.h, should not depend on x, y, or m, only on the parameters - * of the density. Hence, we delegate to the underlying Gaussian - * conditionals, indexed by m, which do satisfy: - * - * log(probability_m(x;y)) = K_m - error_m(x;y) - * - * We resolve by having K == max(K_m) and - * - * error(x;y,m) = error_m(x;y) + K - K_m - * - * which also makes error(x;y,m) >= 0 for all x,y,m. - * - * @param values Continuous values and discrete assignment. - * @return double - */ - double error(const HybridValues &values) const override; - - /** - * @brief Compute error of the HybridGaussianConditional as a tree. - * - * @param continuousValues The continuous VectorValues. - * @return AlgebraicDecisionTree A decision tree on the discrete keys - * only, with the leaf values as the error for each assignment. - */ - virtual AlgebraicDecisionTree errorTree( - const VectorValues &continuousValues) const override; - /** * @brief Compute the logProbability of this hybrid Gaussian conditional. * @@ -241,10 +204,6 @@ class GTSAM_EXPORT HybridGaussianConditional /// Check whether `given` has values for all frontal keys. bool allFrontalsGiven(const VectorValues &given) const; - /// Helper method to compute the error of a conditional. - double conditionalError(const GaussianConditional::shared_ptr &conditional, - const VectorValues &continuousValues) const; - #ifdef GTSAM_ENABLE_BOOST_SERIALIZATION /** Serialization function */ friend class boost::serialization::access; From 9f9032f9040934dcf2df699d64a1c9c3d705f5f6 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 22 Sep 2024 11:06:56 -0400 Subject: [PATCH 11/11] add test for errorTree in incremental scenario --- .../tests/testHybridGaussianFactorGraph.cpp | 49 +++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp index 1530069aa1..f19fd29b17 100644 --- a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp @@ -706,6 +706,55 @@ TEST(HybridGaussianFactorGraph, ErrorTreeWithConditional) { EXPECT(assert_equal(expected, errorTree, 1e-9)); } +/* ****************************************************************************/ +// Test hybrid gaussian factor graph errorTree during +// incremental operation +TEST(HybridGaussianFactorGraph, IncrementalErrorTree) { + Switching s(4); + + HybridGaussianFactorGraph graph; + graph.push_back(s.linearizedFactorGraph.at(0)); // f(X0) + graph.push_back(s.linearizedFactorGraph.at(1)); // f(X0, X1, M0) + graph.push_back(s.linearizedFactorGraph.at(2)); // f(X1, X2, M1) + graph.push_back(s.linearizedFactorGraph.at(4)); // f(X1) + graph.push_back(s.linearizedFactorGraph.at(5)); // f(X2) + graph.push_back(s.linearizedFactorGraph.at(7)); // f(M0) + graph.push_back(s.linearizedFactorGraph.at(8)); // f(M0, M1) + + HybridBayesNet::shared_ptr hybridBayesNet = graph.eliminateSequential(); + EXPECT_LONGS_EQUAL(5, hybridBayesNet->size()); + + HybridValues delta = hybridBayesNet->optimize(); + auto error_tree = graph.errorTree(delta.continuous()); + + std::vector discrete_keys = {{M(0), 2}, {M(1), 2}}; + std::vector leaves = {0.99985581, 0.4902432, 0.51936941, + 0.0097568009}; + AlgebraicDecisionTree expected_error(discrete_keys, leaves); + + // regression + EXPECT(assert_equal(expected_error, error_tree, 1e-7)); + + graph = HybridGaussianFactorGraph(); + graph.push_back(*hybridBayesNet); + graph.push_back(s.linearizedFactorGraph.at(3)); // f(X2, X3, M2) + graph.push_back(s.linearizedFactorGraph.at(6)); // f(X3) + + hybridBayesNet = graph.eliminateSequential(); + EXPECT_LONGS_EQUAL(7, hybridBayesNet->size()); + + delta = hybridBayesNet->optimize(); + auto error_tree2 = graph.errorTree(delta.continuous()); + + discrete_keys = {{M(0), 2}, {M(1), 2}, {M(2), 2}}; + leaves = {0.50985198, 0.0097577296, 0.50009425, 0, + 0.52922138, 0.029127133, 0.50985105, 0.0097567964}; + AlgebraicDecisionTree expected_error2(discrete_keys, leaves); + + // regression + EXPECT(assert_equal(expected_error, error_tree, 1e-7)); +} + /* ****************************************************************************/ // Check that assembleGraphTree assembles Gaussian factor graphs for each // assignment.