Skip to content

Commit

Permalink
improve HybridGaussianProductFactor
Browse files Browse the repository at this point in the history
  • Loading branch information
varunagrawal committed Oct 8, 2024
1 parent 874ba67 commit 21b4c4c
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 8 deletions.
19 changes: 17 additions & 2 deletions gtsam/hybrid/HybridGaussianProductFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,39 +26,46 @@ namespace gtsam {

using Y = HybridGaussianProductFactor::Y;

/* *******************************************************************************/
static Y add(const Y& y1, const Y& y2) {
GaussianFactorGraph result = y1.first;
result.push_back(y2.first);
return {result, y1.second + y2.second};
};

/* *******************************************************************************/
HybridGaussianProductFactor operator+(const HybridGaussianProductFactor& a,
const HybridGaussianProductFactor& b) {
return a.empty() ? b : HybridGaussianProductFactor(a.apply(b, add));
}

/* *******************************************************************************/
HybridGaussianProductFactor HybridGaussianProductFactor::operator+(
const HybridGaussianFactor& factor) const {
return *this + factor.asProductFactor();
}

/* *******************************************************************************/
HybridGaussianProductFactor HybridGaussianProductFactor::operator+(
const GaussianFactor::shared_ptr& factor) const {
return *this + HybridGaussianProductFactor(factor);
}

/* *******************************************************************************/
HybridGaussianProductFactor& HybridGaussianProductFactor::operator+=(
const GaussianFactor::shared_ptr& factor) {
*this = *this + factor;
return *this;
}

/* *******************************************************************************/
HybridGaussianProductFactor& HybridGaussianProductFactor::operator+=(
const HybridGaussianFactor& factor) {
*this = *this + factor;
return *this;
}

/* *******************************************************************************/
void HybridGaussianProductFactor::print(const std::string& s,
const KeyFormatter& formatter) const {
KeySet keys;
Expand All @@ -76,11 +83,19 @@ void HybridGaussianProductFactor::print(const std::string& s,
}
}

/* *******************************************************************************/
bool HybridGaussianProductFactor::equals(
const HybridGaussianProductFactor& other, double tol) const {
return Base::equals(other, [tol](const Y& a, const Y& b) {
return a.first.equals(b.first, tol) && std::abs(a.second - b.second) < tol;
});
}

/* *******************************************************************************/
HybridGaussianProductFactor HybridGaussianProductFactor::removeEmpty() const {
auto emptyGaussian = [](const Y& y) {
bool hasNull =
std::any_of(y.first.begin(),
y.first.end(),
std::any_of(y.first.begin(), y.first.end(),
[](const GaussianFactor::shared_ptr& ptr) { return !ptr; });
return hasNull ? Y{GaussianFactorGraph(), 0.0} : y;
};
Expand Down
7 changes: 1 addition & 6 deletions gtsam/hybrid/HybridGaussianProductFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,7 @@ class HybridGaussianProductFactor
* @return true if equal, false otherwise
*/
bool equals(const HybridGaussianProductFactor& other,
double tol = 1e-9) const {
return Base::equals(other, [tol](const Y& a, const Y& b) {
return a.first.equals(b.first, tol) &&
std::abs(a.second - b.second) < tol;
});
}
double tol = 1e-9) const;

/// @}

Expand Down

0 comments on commit 21b4c4c

Please sign in to comment.