Skip to content

Commit

Permalink
Merge pull request #1991 from borglab/fix/refactor_marginals
Browse files Browse the repository at this point in the history
Refactor jointBayesNet
  • Loading branch information
dellaert authored Jan 24, 2025
2 parents 7b56d66 + 21cb31e commit cfb9ea7
Show file tree
Hide file tree
Showing 8 changed files with 199 additions and 144 deletions.
11 changes: 10 additions & 1 deletion gtsam/discrete/discrete.i
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,10 @@ class DiscreteBayesTreeClique {

class DiscreteBayesTree {
DiscreteBayesTree();
void insertRoot(const gtsam::DiscreteBayesTreeClique* subtree);
void addClique(const gtsam::DiscreteBayesTreeClique* clique);
void addClique(const gtsam::DiscreteBayesTreeClique* clique, const gtsam::DiscreteBayesTreeClique* parent_clique);

void print(string s = "DiscreteBayesTree\n",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
Expand All @@ -276,6 +280,12 @@ class DiscreteBayesTree {
size_t size() const;
bool empty() const;
const DiscreteBayesTreeClique* operator[](size_t j) const;
const DiscreteBayesTreeClique* clique(size_t j) const;
size_t numCachedSeparatorMarginals() const;

gtsam::DiscreteConditional marginalFactor(size_t key) const;
gtsam::DiscreteFactorGraph* joint(size_t j1, size_t j2) const;
gtsam::DiscreteBayesNet* jointBayesNet(size_t j1, size_t j2) const;

double evaluate(const gtsam::DiscreteValues& values) const;
double operator()(const gtsam::DiscreteValues& values) const;
Expand All @@ -285,7 +295,6 @@ class DiscreteBayesTree {
void saveGraph(string s,
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
double operator()(const gtsam::DiscreteValues& values) const;

string markdown(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
Expand Down
33 changes: 20 additions & 13 deletions gtsam/discrete/tests/testDiscreteBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,25 @@ static const DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2),

using ADT = AlgebraicDecisionTree<Key>;

// Function to construct the Asia example
DiscreteBayesNet constructAsiaExample() {
DiscreteBayesNet asia;

asia.add(Asia, "99/1");
asia.add(Smoking % "50/50"); // Signature version

asia.add(Tuberculosis | Asia = "99/1 95/5");
asia.add(LungCancer | Smoking = "99/1 90/10");
asia.add(Bronchitis | Smoking = "70/30 40/60");

asia.add((Either | Tuberculosis, LungCancer) = "F T T T");

asia.add(XRay | Either = "95/5 2/98");
asia.add((Dyspnea | Either, Bronchitis) = "9/1 2/8 3/7 1/9");

return asia;
}

/* ************************************************************************* */
TEST(DiscreteBayesNet, bayesNet) {
DiscreteBayesNet bayesNet;
Expand Down Expand Up @@ -67,19 +86,7 @@ TEST(DiscreteBayesNet, bayesNet) {

/* ************************************************************************* */
TEST(DiscreteBayesNet, Asia) {
DiscreteBayesNet asia;

asia.add(Asia, "99/1");
asia.add(Smoking % "50/50"); // Signature version

asia.add(Tuberculosis | Asia = "99/1 95/5");
asia.add(LungCancer | Smoking = "99/1 90/10");
asia.add(Bronchitis | Smoking = "70/30 40/60");

asia.add((Either | Tuberculosis, LungCancer) = "F T T T");

asia.add(XRay | Either = "95/5 2/98");
asia.add((Dyspnea | Either, Bronchitis) = "9/1 2/8 3/7 1/9");
DiscreteBayesNet asia = constructAsiaExample();

// Convert to factor graph
DiscreteFactorGraph fg(asia);
Expand Down
48 changes: 43 additions & 5 deletions gtsam/discrete/tests/testDiscreteBayesTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -186,11 +186,11 @@ TEST(DiscreteBayesTree, Shortcuts) {
shortcut = clique->shortcut(R, EliminateDiscrete);
DOUBLES_EQUAL(joint_8_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9);

// calculate all shortcuts to root
DiscreteBayesTree::Nodes cliques = self.bayesTree->nodes();
for (auto clique : cliques) {
DiscreteBayesNet shortcut = clique.second->shortcut(R, EliminateDiscrete);
if (debug) {
if (debug) {
// print all shortcuts to root
DiscreteBayesTree::Nodes cliques = self.bayesTree->nodes();
for (auto clique : cliques) {
DiscreteBayesNet shortcut = clique.second->shortcut(R, EliminateDiscrete);
clique.second->conditional_->printSignature();
shortcut.print("shortcut:");
}
Expand All @@ -202,6 +202,7 @@ TEST(DiscreteBayesTree, Shortcuts) {
TEST(DiscreteBayesTree, MarginalFactors) {
TestFixture self;

// Caclulate marginals with brute force enumeration.
Vector marginals = Vector::Zero(15);
for (size_t i = 0; i < self.assignments.size(); ++i) {
DiscreteValues& x = self.assignments[i];
Expand Down Expand Up @@ -287,6 +288,8 @@ TEST(DiscreteBayesTree, Joints) {
TEST(DiscreteBayesTree, Dot) {
TestFixture self;
std::string actual = self.bayesTree->dot();
// print actual:
if (debug) std::cout << actual << std::endl;
EXPECT(actual ==
"digraph G{\n"
"0[label=\"13, 11, 6, 7\"];\n"
Expand Down Expand Up @@ -369,6 +372,41 @@ TEST(DiscreteBayesTree, Lookup) {
EXPECT_DOUBLES_EQUAL(1.0, (*lookup_a2_x3)({{X(2),2},{A(2),1},{X(3),2}}), 1e-9);
}

/* ************************************************************************* */
// Test creating a Bayes tree directly from cliques
TEST(DiscreteBayesTree, DirectFromCliques) {
// Create a BayesNet
DiscreteBayesNet bayesNet;
DiscreteKey A(0, 2), B(1, 2), C(2, 2);
bayesNet.add(A % "1/3");
bayesNet.add(B | A = "1/3 3/1");
bayesNet.add(C | B = "3/1 3/1");

// Create cliques directly
auto clique2 = std::make_shared<DiscreteBayesTree::Clique>(
std::make_shared<DiscreteConditional>(C | B = "3/1 3/1"));
auto clique1 = std::make_shared<DiscreteBayesTree::Clique>(
std::make_shared<DiscreteConditional>(B | A = "1/3 3/1"));
auto clique0 = std::make_shared<DiscreteBayesTree::Clique>(
std::make_shared<DiscreteConditional>(A % "1/3"));

// Create a BayesTree
DiscreteBayesTree bayesTree;
bayesTree.insertRoot(clique2);
bayesTree.addClique(clique1, clique2);
bayesTree.addClique(clique0, clique1);

// Check that the BayesTree is correct
DiscreteValues values;
values[A.first] = 1;
values[B.first] = 1;
values[C.first] = 1;

// Regression
double expected = .046875;
DOUBLES_EQUAL(expected, bayesTree.evaluate(values), 1e-9);
}

/* ************************************************************************* */
int main() {
TestResult tr;
Expand Down
155 changes: 65 additions & 90 deletions gtsam/inference/BayesTree-inst.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
#include <fstream>
#include <queue>
#include <cassert>
#include <unordered_set>

namespace gtsam {

/* ************************************************************************* */
Expand Down Expand Up @@ -335,112 +337,85 @@ namespace gtsam {
}

/* ************************************************************************* */
template<class CLIQUE>
typename BayesTree<CLIQUE>::sharedBayesNet
BayesTree<CLIQUE>::jointBayesNet(Key j1, Key j2, const Eliminate& function) const
{
// Find the lowest common ancestor of two cliques
template <class CLIQUE>
static std::shared_ptr<CLIQUE> findLowestCommonAncestor(
const std::shared_ptr<CLIQUE>& C1, const std::shared_ptr<CLIQUE>& C2) {
// Collect all ancestors of C1
std::unordered_set<std::shared_ptr<CLIQUE>> ancestors;
for (auto p = C1; p; p = p->parent()) {
ancestors.insert(p);
}

// Find the first common ancestor in C2's lineage
std::shared_ptr<CLIQUE> B;
for (auto p = C2; p; p = p->parent()) {
if (ancestors.count(p)) {
return p; // Return the common ancestor when found
}
}

return nullptr; // Return nullptr if no common ancestor is found
}

/* ************************************************************************* */
// Given the clique P(F:S) and the ancestor clique B
// Return the Bayes tree P(S\B | S \cap B)
template <class CLIQUE>
static auto factorInto(
const std::shared_ptr<CLIQUE>& p_F_S, const std::shared_ptr<CLIQUE>& B,
const typename CLIQUE::FactorGraphType::Eliminate& eliminate) {
gttic(Full_root_factoring);

// Get the shortcut P(S|B)
auto p_S_B = p_F_S->shortcut(B, eliminate);

// Compute S\B
KeyVector S_setminus_B = p_F_S->separator_setminus_B(B);

// Factor P(S|B) into P(S\B|S \cap B) and P(S \cap B)
auto [bayesTree, fg] =
typename CLIQUE::FactorGraphType(p_S_B).eliminatePartialMultifrontal(
Ordering(S_setminus_B), eliminate);
return bayesTree;
}

/* ************************************************************************* */
template <class CLIQUE>
typename BayesTree<CLIQUE>::sharedBayesNet BayesTree<CLIQUE>::jointBayesNet(
Key j1, Key j2, const Eliminate& eliminate) const {
gttic(BayesTree_jointBayesNet);
// get clique C1 and C2
sharedClique C1 = (*this)[j1], C2 = (*this)[j2];

gttic(Lowest_common_ancestor);
// Find lowest common ancestor clique
sharedClique B; {
// Build two paths to the root
FastList<sharedClique> path1, path2; {
sharedClique p = C1;
while(p) {
path1.push_front(p);
p = p->parent();
}
} {
sharedClique p = C2;
while(p) {
path2.push_front(p);
p = p->parent();
}
}
// Find the path intersection
typename FastList<sharedClique>::const_iterator p1 = path1.begin(), p2 = path2.begin();
if(*p1 == *p2)
B = *p1;
while(p1 != path1.end() && p2 != path2.end() && *p1 == *p2) {
B = *p1;
++p1;
++p2;
}
}
gttoc(Lowest_common_ancestor);
// Find the lowest common ancestor clique
auto B = findLowestCommonAncestor(C1, C2);

// Build joint on all involved variables
FactorGraphType p_BC1C2;

if(B)
{
if (B) {
// Compute marginal on lowest common ancestor clique
gttic(LCA_marginal);
FactorGraphType p_B = B->marginal2(function);
gttoc(LCA_marginal);

// Compute shortcuts of the requested cliques given the lowest common ancestor
gttic(Clique_shortcuts);
BayesNetType p_C1_Bred = C1->shortcut(B, function);
BayesNetType p_C2_Bred = C2->shortcut(B, function);
gttoc(Clique_shortcuts);

// Factor the shortcuts to be conditioned on the full root
// Get the set of variables to eliminate, which is C1\B.
gttic(Full_root_factoring);
std::shared_ptr<typename EliminationTraitsType::BayesTreeType> p_C1_B; {
KeyVector C1_minus_B; {
KeySet C1_minus_B_set(C1->conditional()->beginParents(), C1->conditional()->endParents());
for(const Key j: *B->conditional()) {
C1_minus_B_set.erase(j); }
C1_minus_B.assign(C1_minus_B_set.begin(), C1_minus_B_set.end());
}
// Factor into C1\B | B.
p_C1_B =
FactorGraphType(p_C1_Bred)
.eliminatePartialMultifrontal(Ordering(C1_minus_B), function)
.first;
}
std::shared_ptr<typename EliminationTraitsType::BayesTreeType> p_C2_B; {
KeyVector C2_minus_B; {
KeySet C2_minus_B_set(C2->conditional()->beginParents(), C2->conditional()->endParents());
for(const Key j: *B->conditional()) {
C2_minus_B_set.erase(j); }
C2_minus_B.assign(C2_minus_B_set.begin(), C2_minus_B_set.end());
}
// Factor into C2\B | B.
p_C2_B =
FactorGraphType(p_C2_Bred)
.eliminatePartialMultifrontal(Ordering(C2_minus_B), function)
.first;
}
gttoc(Full_root_factoring);
FactorGraphType p_B = B->marginal2(eliminate);

// Factor the shortcuts to be conditioned on lowest common ancestor
auto p_C1_B = factorInto(C1, B, eliminate);
auto p_C2_B = factorInto(C2, B, eliminate);

gttic(Variable_joint);
p_BC1C2.push_back(p_B);
p_BC1C2.push_back(*p_C1_B);
p_BC1C2.push_back(*p_C2_B);
if(C1 != B)
p_BC1C2.push_back(C1->conditional());
if(C2 != B)
p_BC1C2.push_back(C2->conditional());
gttoc(Variable_joint);
}
else
{
// The nodes have no common ancestor, they're in different trees, so they're joint is just the
// product of their marginals.
gttic(Disjoint_marginals);
p_BC1C2.push_back(C1->marginal2(function));
p_BC1C2.push_back(C2->marginal2(function));
gttoc(Disjoint_marginals);
if (C1 != B) p_BC1C2.push_back(C1->conditional());
if (C2 != B) p_BC1C2.push_back(C2->conditional());
} else {
// The nodes have no common ancestor, they're in different trees, so
// they're joint is just the product of their marginals.
p_BC1C2.push_back(C1->marginal2(eliminate));
p_BC1C2.push_back(C2->marginal2(eliminate));
}

// now, marginalize out everything that is not variable j1 or j2
return p_BC1C2.marginalMultifrontalBayesNet(Ordering{j1, j2}, function);
return p_BC1C2.marginalMultifrontalBayesNet(Ordering{j1, j2}, eliminate);
}

/* ************************************************************************* */
Expand Down
22 changes: 12 additions & 10 deletions gtsam/inference/BayesTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,13 +119,14 @@ namespace gtsam {
/** Assignment operator */
This& operator=(const This& other);

public:

/// @name Testable
/// @{

/** check equality */
bool equals(const This& other, double tol = 1e-9) const;

public:
/** print */
void print(const std::string& s = "",
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
Expand Down Expand Up @@ -185,18 +186,19 @@ namespace gtsam {
*/
sharedBayesNet jointBayesNet(Key j1, Key j2, const Eliminate& function = EliminationTraitsType::DefaultEliminate) const;

/// @name Graph Display
/// @{
/// @}
/// @name Graph Display
/// @{

/// Output to graphviz format, stream version.
void dot(std::ostream& os, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
/// Output to graphviz format, stream version.
void dot(std::ostream& os, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;

/// Output to graphviz format string.
std::string dot(
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
/// Output to graphviz format string.
std::string dot(
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;

/// output to file with graphviz format.
void saveGraph(const std::string& filename,
/// output to file with graphviz format.
void saveGraph(const std::string& filename,
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;

/// @}
Expand Down
Loading

0 comments on commit cfb9ea7

Please sign in to comment.