From 21b4c4c8d30ab75b4e8004c767b8cd411bbf35b2 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 8 Oct 2024 12:03:35 -0400 Subject: [PATCH] improve HybridGaussianProductFactor --- gtsam/hybrid/HybridGaussianProductFactor.cpp | 19 +++++++++++++++++-- gtsam/hybrid/HybridGaussianProductFactor.h | 7 +------ 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianProductFactor.cpp b/gtsam/hybrid/HybridGaussianProductFactor.cpp index 2e95ea8d11..375349f9b8 100644 --- a/gtsam/hybrid/HybridGaussianProductFactor.cpp +++ b/gtsam/hybrid/HybridGaussianProductFactor.cpp @@ -26,39 +26,46 @@ namespace gtsam { using Y = HybridGaussianProductFactor::Y; +/* *******************************************************************************/ static Y add(const Y& y1, const Y& y2) { GaussianFactorGraph result = y1.first; result.push_back(y2.first); return {result, y1.second + y2.second}; }; +/* *******************************************************************************/ HybridGaussianProductFactor operator+(const HybridGaussianProductFactor& a, const HybridGaussianProductFactor& b) { return a.empty() ? b : HybridGaussianProductFactor(a.apply(b, add)); } +/* *******************************************************************************/ HybridGaussianProductFactor HybridGaussianProductFactor::operator+( const HybridGaussianFactor& factor) const { return *this + factor.asProductFactor(); } +/* *******************************************************************************/ HybridGaussianProductFactor HybridGaussianProductFactor::operator+( const GaussianFactor::shared_ptr& factor) const { return *this + HybridGaussianProductFactor(factor); } +/* *******************************************************************************/ HybridGaussianProductFactor& HybridGaussianProductFactor::operator+=( const GaussianFactor::shared_ptr& factor) { *this = *this + factor; return *this; } +/* *******************************************************************************/ HybridGaussianProductFactor& HybridGaussianProductFactor::operator+=( const HybridGaussianFactor& factor) { *this = *this + factor; return *this; } +/* *******************************************************************************/ void HybridGaussianProductFactor::print(const std::string& s, const KeyFormatter& formatter) const { KeySet keys; @@ -76,11 +83,19 @@ void HybridGaussianProductFactor::print(const std::string& s, } } +/* *******************************************************************************/ +bool HybridGaussianProductFactor::equals( + const HybridGaussianProductFactor& other, double tol) const { + return Base::equals(other, [tol](const Y& a, const Y& b) { + return a.first.equals(b.first, tol) && std::abs(a.second - b.second) < tol; + }); +} + +/* *******************************************************************************/ HybridGaussianProductFactor HybridGaussianProductFactor::removeEmpty() const { auto emptyGaussian = [](const Y& y) { bool hasNull = - std::any_of(y.first.begin(), - y.first.end(), + std::any_of(y.first.begin(), y.first.end(), [](const GaussianFactor::shared_ptr& ptr) { return !ptr; }); return hasNull ? Y{GaussianFactorGraph(), 0.0} : y; }; diff --git a/gtsam/hybrid/HybridGaussianProductFactor.h b/gtsam/hybrid/HybridGaussianProductFactor.h index 9c2aee74ab..3658a09919 100644 --- a/gtsam/hybrid/HybridGaussianProductFactor.h +++ b/gtsam/hybrid/HybridGaussianProductFactor.h @@ -94,12 +94,7 @@ class HybridGaussianProductFactor * @return true if equal, false otherwise */ bool equals(const HybridGaussianProductFactor& other, - double tol = 1e-9) const { - return Base::equals(other, [tol](const Y& a, const Y& b) { - return a.first.equals(b.first, tol) && - std::abs(a.second - b.second) < tol; - }); - } + double tol = 1e-9) const; /// @}