diff --git a/gtsam/hybrid/HybridGaussianConditional.cpp b/gtsam/hybrid/HybridGaussianConditional.cpp
index df59637aa5..0b1dc53377 100644
--- a/gtsam/hybrid/HybridGaussianConditional.cpp
+++ b/gtsam/hybrid/HybridGaussianConditional.cpp
@@ -28,23 +28,37 @@
 #include
 
 namespace gtsam {
+HybridGaussianFactor::FactorValuePairs GetFactorValuePairs(
+    const HybridGaussianConditional::Conditionals &conditionals) {
+  auto func = [](const GaussianConditional::shared_ptr &conditional)
+      -> GaussianFactorValuePair {
+    double value = 0.0;
+    // Check if conditional is pruned
+    if (conditional) {
+      // Assign log(\sqrt(|2πΣ|)) = -log(1 / sqrt(|2πΣ|))
+      value = -conditional->logNormalizationConstant();
+    }
+    return {std::dynamic_pointer_cast<GaussianFactor>(conditional), value};
+  };
+  return HybridGaussianFactor::FactorValuePairs(conditionals, func);
+}
 
 HybridGaussianConditional::HybridGaussianConditional(
     const KeyVector &continuousFrontals, const KeyVector &continuousParents,
     const DiscreteKeys &discreteParents,
     const HybridGaussianConditional::Conditionals &conditionals)
     : BaseFactor(CollectKeys(continuousFrontals, continuousParents),
-                 discreteParents),
+                 discreteParents, GetFactorValuePairs(conditionals)),
       BaseConditional(continuousFrontals.size()),
       conditionals_(conditionals) {
-  // Calculate logConstant_ as the maximum of the log constants of the
+  // Calculate logConstant_ as the minimum of the log normalizers of the
   // conditionals, by visiting the decision tree:
-  logConstant_ = -std::numeric_limits<double>::infinity();
+  logConstant_ = std::numeric_limits<double>::infinity();
   conditionals_.visit(
       [this](const GaussianConditional::shared_ptr &conditional) {
         if (conditional) {
-          this->logConstant_ = std::max(
-              this->logConstant_, conditional->logNormalizationConstant());
+          this->logConstant_ = std::min(
+              this->logConstant_, -conditional->logNormalizationConstant());
         }
       });
 }
@@ -64,21 +78,6 @@ HybridGaussianConditional::HybridGaussianConditional(
                                 DiscreteKeys{discreteParent},
                                 Conditionals({discreteParent}, conditionals)) {}
 
-/* *******************************************************************************/
-// TODO(dellaert): This is copy/paste: HybridGaussianConditional should be
-// derived from HybridGaussianFactor, no?
-GaussianFactorGraphTree HybridGaussianConditional::add(
-    const GaussianFactorGraphTree &sum) const {
-  using Y = GaussianFactorGraph;
-  auto add = [](const Y &graph1, const Y &graph2) {
-    auto result = graph1;
-    result.push_back(graph2);
-    return result;
-  };
-  const auto tree = asGaussianFactorGraphTree();
-  return sum.empty() ? tree : sum.apply(tree, add);
-}
-
 /* *******************************************************************************/
 GaussianFactorGraphTree HybridGaussianConditional::asGaussianFactorGraphTree()
     const {
@@ -86,7 +85,7 @@ GaussianFactorGraphTree HybridGaussianConditional::asGaussianFactorGraphTree()
     // First check if conditional has not been pruned
     if (gc) {
       const double Cgm_Kgcm =
-          this->logConstant_ - gc->logNormalizationConstant();
+          -this->logConstant_ - gc->logNormalizationConstant();
       // If there is a difference in the covariances, we need to account for
       // that since the error is dependent on the mode.
       if (Cgm_Kgcm > 0.0) {
@@ -157,7 +156,8 @@ void HybridGaussianConditional::print(const std::string &s,
     std::cout << "(" << formatter(dk.first) << ", " << dk.second << "), ";
   }
   std::cout << std::endl
-            << " logNormalizationConstant: " << logConstant_ << std::endl
+            << " logNormalizationConstant: " << logNormalizationConstant()
+            << std::endl
             << std::endl;
   conditionals_.print(
       "", [&](Key k) { return formatter(k); },
@@ -216,7 +216,7 @@ std::shared_ptr<HybridGaussianFactor> HybridGaussianConditional::likelihood(
       -> GaussianFactorValuePair {
         const auto likelihood_m = conditional->likelihood(given);
         const double Cgm_Kgcm =
-            logConstant_ - conditional->logNormalizationConstant();
+            -logConstant_ - conditional->logNormalizationConstant();
         if (Cgm_Kgcm == 0.0) {
           return {likelihood_m, 0.0};
         } else {
@@ -330,7 +330,7 @@ double HybridGaussianConditional::conditionalError(
   // Check if valid pointer
   if (conditional) {
     return conditional->error(continuousValues) +  //
-           logConstant_ - conditional->logNormalizationConstant();
+           -logConstant_ - conditional->logNormalizationConstant();
   } else {
     // If not valid, pointer, it means this conditional was pruned,
     // so we return maximum error.
diff --git a/gtsam/hybrid/HybridGaussianConditional.h b/gtsam/hybrid/HybridGaussianConditional.h
index eb2bbb9378..72a9994729 100644
--- a/gtsam/hybrid/HybridGaussianConditional.h
+++ b/gtsam/hybrid/HybridGaussianConditional.h
@@ -51,20 +51,22 @@ class HybridValues;
  * @ingroup hybrid
  */
 class GTSAM_EXPORT HybridGaussianConditional
-    : public HybridFactor,
-      public Conditional<HybridFactor, HybridGaussianConditional> {
+    : public HybridGaussianFactor,
+      public Conditional<HybridGaussianFactor, HybridGaussianConditional> {
  public:
   using This = HybridGaussianConditional;
-  using shared_ptr = std::shared_ptr<HybridGaussianConditional>;
-  using BaseFactor = HybridFactor;
-  using BaseConditional = Conditional<HybridFactor, HybridGaussianConditional>;
+  using shared_ptr = std::shared_ptr<This>;
+  using BaseFactor = HybridGaussianFactor;
+  using BaseConditional = Conditional<BaseFactor, HybridGaussianConditional>;
 
   /// typedef for Decision Tree of Gaussian Conditionals
   using Conditionals = DecisionTree<Key, GaussianConditional::shared_ptr>;
 
  private:
   Conditionals conditionals_;  ///< a decision tree of Gaussian conditionals.
-  double logConstant_;         ///< log of the normalization constant.
+  ///< Negative-log of the normalization constant (log(\sqrt(|2πΣ|))).
+  ///< Take advantage of the neg-log space so everything is a minimization
+  double logConstant_;
 
   /**
    * @brief Convert a HybridGaussianConditional of conditionals into
@@ -107,8 +109,9 @@ class GTSAM_EXPORT HybridGaussianConditional
                             const Conditionals &conditionals);
 
   /**
-   * @brief Make a Hybrid Gaussian Conditional from a vector of Gaussian conditionals.
-   * The DecisionTree-based constructor is preferred over this one.
+   * @brief Make a Hybrid Gaussian Conditional from a vector of Gaussian
+   * conditionals. The DecisionTree-based constructor is preferred over this
+   * one.
    *
    * @param continuousFrontals The continuous frontal variables
   * @param continuousParents The continuous parent variables
@@ -149,7 +152,7 @@ class GTSAM_EXPORT HybridGaussianConditional
 
   /// The log normalization constant is max of the the individual
   /// log-normalization constants.
-  double logNormalizationConstant() const override { return logConstant_; }
+  double logNormalizationConstant() const override { return -logConstant_; }
 
   /**
    * Create a likelihood factor for a hybrid Gaussian conditional,
@@ -232,14 +235,6 @@ class GTSAM_EXPORT HybridGaussianConditional
    */
   void prune(const DecisionTreeFactor &discreteProbs);
 
-  /**
-   * @brief Merge the Gaussian Factor Graphs in `this` and `sum` while
-   * maintaining the decision tree structure.
-   *
-   * @param sum Decision Tree of Gaussian Factor Graphs
-   * @return GaussianFactorGraphTree
-   */
-  GaussianFactorGraphTree add(const GaussianFactorGraphTree &sum) const;
   /// @}
 
  private:
diff --git a/gtsam/hybrid/tests/testHybridGaussianConditional.cpp b/gtsam/hybrid/tests/testHybridGaussianConditional.cpp
index 406203db84..dde493685b 100644
--- a/gtsam/hybrid/tests/testHybridGaussianConditional.cpp
+++ b/gtsam/hybrid/tests/testHybridGaussianConditional.cpp
@@ -100,7 +100,7 @@ TEST(HybridGaussianConditional, Error) {
   auto actual = hybrid_conditional.errorTree(vv);
 
   // Check result.
-  std::vector<DiscreteKey> discrete_keys = {mode};
+  DiscreteKeys discrete_keys{mode};
   std::vector<double> leaves = {conditionals[0]->error(vv),
                                 conditionals[1]->error(vv)};
   AlgebraicDecisionTree<Key> expected(discrete_keys, leaves);
@@ -172,6 +172,37 @@ TEST(HybridGaussianConditional, ContinuousParents) {
   EXPECT(continuousParentKeys[0] == X(0));
 }
 
+/* ************************************************************************* */
+/// Check error with mode dependent constants.
+TEST(HybridGaussianConditional, Error2) {
+  using namespace mode_dependent_constants;
+  auto actual = hybrid_conditional.errorTree(vv);
+
+  // Check result.
+  DiscreteKeys discrete_keys{mode};
+  double logNormalizer0 = -conditionals[0]->logNormalizationConstant();
+  double logNormalizer1 = -conditionals[1]->logNormalizationConstant();
+  double minLogNormalizer = std::min(logNormalizer0, logNormalizer1);
+
+  // Expected error is e(X) + log(|2πΣ|).
+  // We normalize log(|2πΣ|) with min(logNormalizers) so it is non-negative.
+  std::vector<double> leaves = {
+      conditionals[0]->error(vv) + logNormalizer0 - minLogNormalizer,
+      conditionals[1]->error(vv) + logNormalizer1 - minLogNormalizer};
+  AlgebraicDecisionTree<Key> expected(discrete_keys, leaves);
+
+  EXPECT(assert_equal(expected, actual, 1e-6));
+
+  // Check for non-tree version.
+  for (size_t mode : {0, 1}) {
+    const HybridValues hv{vv, {{M(0), mode}}};
+    EXPECT_DOUBLES_EQUAL(conditionals[mode]->error(vv) -
+                             conditionals[mode]->logNormalizationConstant() -
+                             minLogNormalizer,
+                         hybrid_conditional.error(hv), 1e-8);
+  }
+}
+
 /* ************************************************************************* */
 /// Check that the likelihood is proportional to the conditional density given
 /// the measurements.