Skip to content

Commit

Permalink
Merge pull request #1865 from borglab/feature/no_hiding-2
Browse files Browse the repository at this point in the history
Updates to `No Hiding` PR
  • Loading branch information
dellaert authored Oct 9, 2024
2 parents 9f7ccbb + 436524a commit 59f97d6
Show file tree
Hide file tree
Showing 14 changed files with 141 additions and 126 deletions.
3 changes: 1 addition & 2 deletions gtsam/hybrid/HybridBayesTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,13 @@
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridBayesTree.h>
#include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/inference/BayesTree-inst.h>
#include <gtsam/inference/BayesTreeCliqueBase-inst.h>
#include <gtsam/linear/GaussianJunctionTree.h>

#include <memory>

#include "gtsam/hybrid/HybridConditional.h"

namespace gtsam {

// Instantiate base class
Expand Down
1 change: 1 addition & 0 deletions gtsam/hybrid/HybridConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
* @file HybridConditional.cpp
* @date Mar 11, 2022
* @author Fan Jiang
* @author Varun Agrawal
*/

#include <gtsam/hybrid/HybridConditional.h>
Expand Down
1 change: 1 addition & 0 deletions gtsam/hybrid/HybridConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
* @file HybridConditional.h
* @date Mar 11, 2022
* @author Fan Jiang
* @author Varun Agrawal
*/

#pragma once
Expand Down
18 changes: 10 additions & 8 deletions gtsam/hybrid/HybridGaussianConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ struct HybridGaussianConditional::Helper {
explicit Helper(const Conditionals &conditionals)
: conditionals(conditionals),
minNegLogConstant(std::numeric_limits<double>::infinity()) {
auto func = [this](const GC::shared_ptr& gc) -> GaussianFactorValuePair {
auto func = [this](const GC::shared_ptr &gc) -> GaussianFactorValuePair {
if (!gc) return {nullptr, std::numeric_limits<double>::infinity()};
if (!nrFrontals) nrFrontals = gc->nrFrontals();
double value = gc->negLogConstant();
Expand All @@ -97,10 +97,10 @@ struct HybridGaussianConditional::Helper {

/* *******************************************************************************/
HybridGaussianConditional::HybridGaussianConditional(
const DiscreteKeys& discreteParents, const Helper& helper)
const DiscreteKeys &discreteParents, const Helper &helper)
: BaseFactor(discreteParents,
FactorValuePairs(helper.pairs,
[&](const GaussianFactorValuePair&
[&](const GaussianFactorValuePair &
pair) { // subtract minNegLogConstant
return GaussianFactorValuePair{
pair.first,
Expand Down Expand Up @@ -183,10 +183,12 @@ bool HybridGaussianConditional::equals(const HybridFactor &lf,

// Check the base and the factors:
return BaseFactor::equals(*e, tol) &&
conditionals_.equals(
e->conditionals_, [tol](const auto &f1, const auto &f2) {
return (!f1 && !f2) || (f1 && f2 && f1->equals(*f2, tol));
});
conditionals_.equals(e->conditionals_,
[tol](const GaussianConditional::shared_ptr &f1,
const GaussianConditional::shared_ptr &f2) {
return (!f1 && !f2) ||
(f1 && f2 && f1->equals(*f2, tol));
});
}

/* *******************************************************************************/
Expand Down Expand Up @@ -225,7 +227,7 @@ KeyVector HybridGaussianConditional::continuousParents() const {
// remove that key from continuousParentKeys:
continuousParentKeys.erase(std::remove(continuousParentKeys.begin(),
continuousParentKeys.end(), key),
continuousParentKeys.end());
continuousParentKeys.end());
}
return continuousParentKeys;
}
Expand Down
2 changes: 1 addition & 1 deletion gtsam/hybrid/HybridGaussianConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@
#include <gtsam/discrete/DecisionTree.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/hybrid/HybridGaussianProductFactor.h>
#include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/hybrid/HybridGaussianFactor.h>
#include <gtsam/hybrid/HybridGaussianProductFactor.h>
#include <gtsam/inference/Conditional.h>
#include <gtsam/linear/GaussianConditional.h>

Expand Down
58 changes: 32 additions & 26 deletions gtsam/hybrid/HybridGaussianFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ struct HybridGaussianFactor::ConstructorHelper {
// Build the FactorValuePairs DecisionTree
pairs = FactorValuePairs(
DecisionTree<Key, GaussianFactor::shared_ptr>(discreteKeys, factors),
[](const auto& f) {
[](const sharedFactor& f) {
return std::pair{f,
f ? 0.0 : std::numeric_limits<double>::infinity()};
});
Expand All @@ -63,7 +63,7 @@ struct HybridGaussianFactor::ConstructorHelper {
const std::vector<GaussianFactorValuePair>& factorPairs)
: discreteKeys({discreteKey}) {
// Extract continuous keys from the first non-null factor
for (const auto& pair : factorPairs) {
for (const GaussianFactorValuePair& pair : factorPairs) {
if (pair.first && continuousKeys.empty()) {
continuousKeys = pair.first->keys();
break;
Expand Down Expand Up @@ -93,35 +93,36 @@ struct HybridGaussianFactor::ConstructorHelper {
};

/* *******************************************************************************/
HybridGaussianFactor::HybridGaussianFactor(const ConstructorHelper &helper)
HybridGaussianFactor::HybridGaussianFactor(const ConstructorHelper& helper)
: Base(helper.continuousKeys, helper.discreteKeys),
factors_(helper.pairs) {}

HybridGaussianFactor::HybridGaussianFactor(
const DiscreteKey &discreteKey,
const std::vector<GaussianFactor::shared_ptr> &factors)
const DiscreteKey& discreteKey,
const std::vector<GaussianFactor::shared_ptr>& factors)
: HybridGaussianFactor(ConstructorHelper(discreteKey, factors)) {}

HybridGaussianFactor::HybridGaussianFactor(
const DiscreteKey &discreteKey,
const std::vector<GaussianFactorValuePair> &factorPairs)
const DiscreteKey& discreteKey,
const std::vector<GaussianFactorValuePair>& factorPairs)
: HybridGaussianFactor(ConstructorHelper(discreteKey, factorPairs)) {}

HybridGaussianFactor::HybridGaussianFactor(const DiscreteKeys &discreteKeys,
const FactorValuePairs &factors)
HybridGaussianFactor::HybridGaussianFactor(const DiscreteKeys& discreteKeys,
const FactorValuePairs& factors)
: HybridGaussianFactor(ConstructorHelper(discreteKeys, factors)) {}

/* *******************************************************************************/
bool HybridGaussianFactor::equals(const HybridFactor &lf, double tol) const {
const This *e = dynamic_cast<const This *>(&lf);
bool HybridGaussianFactor::equals(const HybridFactor& lf, double tol) const {
const This* e = dynamic_cast<const This*>(&lf);
if (e == nullptr) return false;

// This will return false if either factors_ is empty or e->factors_ is
// empty, but not if both are empty or both are not empty:
if (factors_.empty() ^ e->factors_.empty()) return false;

// Check the base and the factors:
auto compareFunc = [tol](const auto& pair1, const auto& pair2) {
auto compareFunc = [tol](const GaussianFactorValuePair& pair1,
const GaussianFactorValuePair& pair2) {
auto f1 = pair1.first, f2 = pair2.first;
bool match = (!f1 && !f2) || (f1 && f2 && f1->equals(*f2, tol));
return match && gtsam::equal(pair1.second, pair2.second, tol);
Expand All @@ -130,18 +131,17 @@ bool HybridGaussianFactor::equals(const HybridFactor &lf, double tol) const {
}

/* *******************************************************************************/
void HybridGaussianFactor::print(const std::string &s,
const KeyFormatter &formatter) const {
void HybridGaussianFactor::print(const std::string& s,
const KeyFormatter& formatter) const {
std::cout << (s.empty() ? "" : s + "\n");
HybridFactor::print("", formatter);
std::cout << "{\n";
if (factors_.empty()) {
std::cout << " empty" << std::endl;
} else {
factors_.print(
"",
[&](Key k) { return formatter(k); },
[&](const auto& pair) -> std::string {
"", [&](Key k) { return formatter(k); },
[&](const GaussianFactorValuePair& pair) -> std::string {
RedirectCout rd;
std::cout << ":\n";
if (pair.first) {
Expand All @@ -158,7 +158,7 @@ void HybridGaussianFactor::print(const std::string &s,

/* *******************************************************************************/
GaussianFactorValuePair HybridGaussianFactor::operator()(
const DiscreteValues &assignment) const {
const DiscreteValues& assignment) const {
return factors_(assignment);
}

Expand All @@ -169,18 +169,25 @@ HybridGaussianProductFactor HybridGaussianFactor::asProductFactor() const {
// - Each leaf converted to a GaussianFactorGraph with just the factor and its
// scalar.
return {{factors_,
[](const auto& pair) -> std::pair<GaussianFactorGraph, double> {
[](const GaussianFactorValuePair& pair)
-> std::pair<GaussianFactorGraph, double> {
return {GaussianFactorGraph{pair.first}, pair.second};
}}};
}

/* *******************************************************************************/
inline static double PotentiallyPrunedComponentError(
const GaussianFactorValuePair& pair, const VectorValues& continuousValues) {
return pair.first ? pair.first->error(continuousValues) + pair.second
: std::numeric_limits<double>::infinity();
}

/* *******************************************************************************/
AlgebraicDecisionTree<Key> HybridGaussianFactor::errorTree(
const VectorValues &continuousValues) const {
const VectorValues& continuousValues) const {
// functor to convert from sharedFactor to double error value.
auto errorFunc = [&continuousValues](const auto& pair) {
return pair.first ? pair.first->error(continuousValues) + pair.second
: std::numeric_limits<double>::infinity();
auto errorFunc = [&continuousValues](const GaussianFactorValuePair& pair) {
return PotentiallyPrunedComponentError(pair, continuousValues);
};
DecisionTree<Key, double> error_tree(factors_, errorFunc);
return error_tree;
Expand All @@ -189,9 +196,8 @@ AlgebraicDecisionTree<Key> HybridGaussianFactor::errorTree(
/* *******************************************************************************/
double HybridGaussianFactor::error(const HybridValues& values) const {
// Directly index to get the component, no need to build the whole tree.
const auto pair = factors_(values.discrete());
return pair.first ? pair.first->error(values.continuous()) + pair.second
: std::numeric_limits<double>::infinity();
const GaussianFactorValuePair pair = factors_(values.discrete());
return PotentiallyPrunedComponentError(pair, values.continuous());
}

} // namespace gtsam
17 changes: 8 additions & 9 deletions gtsam/hybrid/HybridGaussianFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ using GaussianFactorValuePair = std::pair<GaussianFactor::shared_ptr, double>;
* @ingroup hybrid
*/
class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
public:
public:
using Base = HybridFactor;
using This = HybridGaussianFactor;
using shared_ptr = std::shared_ptr<This>;
Expand All @@ -68,11 +68,11 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
/// typedef for Decision Tree of Gaussian factors and arbitrary value.
using FactorValuePairs = DecisionTree<Key, GaussianFactorValuePair>;

private:
private:
/// Decision tree of Gaussian factors indexed by discrete keys.
FactorValuePairs factors_;

public:
public:
/// @name Constructors
/// @{

Expand Down Expand Up @@ -120,9 +120,8 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {

bool equals(const HybridFactor &lf, double tol = 1e-9) const override;

void
print(const std::string &s = "",
const KeyFormatter &formatter = DefaultKeyFormatter) const override;
void print(const std::string &s = "", const KeyFormatter &formatter =
DefaultKeyFormatter) const override;

/// @}
/// @name Standard API
Expand All @@ -138,8 +137,8 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
* @return AlgebraicDecisionTree<Key> A decision tree with the same keys
* as the factors involved, and leaf values as the error.
*/
AlgebraicDecisionTree<Key>
errorTree(const VectorValues &continuousValues) const override;
AlgebraicDecisionTree<Key> errorTree(
const VectorValues &continuousValues) const override;

/**
* @brief Compute the log-likelihood, including the log-normalizing constant.
Expand All @@ -159,7 +158,7 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {

/// @}

private:
private:
/**
* @brief Helper function to augment the [A|b] matrices in the factor
* components with the additional scalar values. This is done by storing the
Expand Down
Loading

0 comments on commit 59f97d6

Please sign in to comment.