Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix HybridFactorGraph::continuousKeys #1852

Merged
merged 2 commits into from
Sep 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions gtsam/hybrid/HybridFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
*/

#include <gtsam/hybrid/HybridFactorGraph.h>
#include <gtsam/nonlinear/NonlinearFactor.h>

namespace gtsam {

Expand Down Expand Up @@ -58,6 +59,8 @@ const KeySet HybridFactorGraph::continuousKeySet() const {
}
} else if (auto p = std::dynamic_pointer_cast<GaussianFactor>(factor)) {
keys.insert(p->keys().begin(), p->keys().end());
} else if (auto p = std::dynamic_pointer_cast<NonlinearFactor>(factor)) {
keys.insert(p->keys().begin(), p->keys().end());
}
}
return keys;
Expand Down
9 changes: 2 additions & 7 deletions gtsam/hybrid/HybridGaussianConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,14 @@ HybridGaussianConditional::HybridGaussianConditional(
conditionals_(conditionals),
negLogConstant_(helper.minNegLogConstant) {}

/* *******************************************************************************/
HybridGaussianConditional::HybridGaussianConditional(
const DiscreteKeys &discreteParents,
const HybridGaussianConditional::Conditionals &conditionals)
: HybridGaussianConditional(discreteParents, conditionals,
ConstructorHelper(conditionals)) {}

/* *******************************************************************************/
HybridGaussianConditional::HybridGaussianConditional(
const DiscreteKey &discreteParent,
const std::vector<GaussianConditional::shared_ptr> &conditionals)
Expand Down Expand Up @@ -242,13 +244,6 @@ std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &discreteKeys) {
}

/* ************************************************************************* */
/**
* @brief Helper function to get the pruner functional.
*
* @param discreteProbs The probabilities of only discrete keys.
* @return std::function<GaussianConditional::shared_ptr(
* const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
*/
std::function<GaussianConditional::shared_ptr(
const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
HybridGaussianConditional::prunerFunc(const DecisionTreeFactor &discreteProbs) {
Expand Down
8 changes: 7 additions & 1 deletion gtsam/hybrid/HybridGaussianConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,13 @@ class GTSAM_EXPORT HybridGaussianConditional
/// Convert to a DecisionTree of Gaussian factor graphs.
GaussianFactorGraphTree asGaussianFactorGraphTree() const;

//// Get the pruner functor from pruned discrete probabilities.
/**
* @brief Get the pruner function from discrete probabilities.
*
* @param discreteProbs The probabilities of only discrete keys.
* @return std::function<GaussianConditional::shared_ptr(
* const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
*/
std::function<GaussianConditional::shared_ptr(
const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
prunerFunc(const DecisionTreeFactor &prunedProbabilities);
Expand Down
4 changes: 4 additions & 0 deletions gtsam/hybrid/HybridGaussianFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,16 +126,19 @@ HybridGaussianFactor::HybridGaussianFactor(const ConstructorHelper &helper)
factors_(helper.factorsTree.empty() ? augment(helper.pairs)
: helper.factorsTree) {}

/* *******************************************************************************/
HybridGaussianFactor::HybridGaussianFactor(
const DiscreteKey &discreteKey,
const std::vector<GaussianFactor::shared_ptr> &factors)
: HybridGaussianFactor(ConstructorHelper(discreteKey, factors)) {}

/* *******************************************************************************/
HybridGaussianFactor::HybridGaussianFactor(
const DiscreteKey &discreteKey,
const std::vector<GaussianFactorValuePair> &factorPairs)
: HybridGaussianFactor(ConstructorHelper(discreteKey, factorPairs)) {}

/* *******************************************************************************/
HybridGaussianFactor::HybridGaussianFactor(const DiscreteKeys &discreteKeys,
const FactorValuePairs &factors)
: HybridGaussianFactor(ConstructorHelper(discreteKeys, factors)) {}
Expand Down Expand Up @@ -223,6 +226,7 @@ double HybridGaussianFactor::potentiallyPrunedComponentError(
return std::numeric_limits<double>::max();
}
}

/* *******************************************************************************/
AlgebraicDecisionTree<Key> HybridGaussianFactor::errorTree(
const VectorValues &continuousValues) const {
Expand Down
Loading