Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Common errorTree method and its use in HybridGaussianFactorGraph #1837

Merged
merged 11 commits into from
Sep 22, 2024
Prev Previous commit
Next Next commit
check for potentiall pruning
varunagrawal committed Sep 21, 2024
commit d81cd82b9a6e27e51104faf00c9915656a670799
21 changes: 18 additions & 3 deletions gtsam/hybrid/HybridGaussianFactor.cpp
Original file line number Diff line number Diff line change
@@ -151,21 +151,36 @@ GaussianFactorGraphTree HybridGaussianFactor::asGaussianFactorGraphTree()
return {factors_, wrap};
}

/* *******************************************************************************/
double HybridGaussianFactor::potentiallyPrunedComponentError(
const sharedFactor &gf, const VectorValues &values) const {
// Check if valid pointer
if (gf) {
return gf->error(values);
} else {
// If not valid, pointer, it means this component was pruned,
// so we return maximum error.
// This way the negative exponential will give
// a probability value close to 0.0.
return std::numeric_limits<double>::max();
}
}
/* *******************************************************************************/
AlgebraicDecisionTree<Key> HybridGaussianFactor::errorTree(
const VectorValues &continuousValues) const {
// functor to convert from sharedFactor to double error value.
auto errorFunc = [&continuousValues](const sharedFactor &gf) {
return gf->error(continuousValues);
auto errorFunc = [this, &continuousValues](const sharedFactor &gf) {
return this->potentiallyPrunedComponentError(gf, continuousValues);
};
DecisionTree<Key, double> error_tree(factors_, errorFunc);
return error_tree;
}

/* *******************************************************************************/
double HybridGaussianFactor::error(const HybridValues &values) const {
// Directly index to get the component, no need to build the whole tree.
const sharedFactor gf = factors_(values.discrete());
return gf->error(values.continuous());
return potentiallyPrunedComponentError(gf, values.continuous());
}

} // namespace gtsam
4 changes: 4 additions & 0 deletions gtsam/hybrid/HybridGaussianFactor.h
Original file line number Diff line number Diff line change
@@ -169,6 +169,10 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
/// @}

private:
/// Helper method to compute the error of a component.
double potentiallyPrunedComponentError(
const sharedFactor &gf, const VectorValues &continuousValues) const;

#ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
/** Serialization function */
friend class boost::serialization::access;