Skip to content

Commit

Permalink
Merge pull request #1844 from borglab/feature/timeHybrid
Browse files Browse the repository at this point in the history
  • Loading branch information
varunagrawal authored Sep 24, 2024
2 parents 6c97e4b + 82c25d8 commit e4ec8d3
Show file tree
Hide file tree
Showing 6 changed files with 129 additions and 136 deletions.
5 changes: 3 additions & 2 deletions gtsam/discrete/DecisionTree-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ namespace gtsam {
void dot(std::ostream& os, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter,
bool showZero) const override {
std::string value = valueFormatter(constant_);
const std::string value = valueFormatter(constant_);
if (showZero || value.compare("0"))
os << "\"" << this->id() << "\" [label=\"" << value
<< "\", shape=box, rank=sink, height=0.35, fixedsize=true]\n";
Expand Down Expand Up @@ -306,7 +306,8 @@ namespace gtsam {
void dot(std::ostream& os, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter,
bool showZero) const override {
os << "\"" << this->id() << "\" [shape=circle, label=\"" << label_
const std::string label = labelFormatter(label_);
os << "\"" << this->id() << "\" [shape=circle, label=\"" << label
<< "\"]\n";
size_t B = branches_.size();
for (size_t i = 0; i < B; i++) {
Expand Down
30 changes: 14 additions & 16 deletions gtsam/discrete/DecisionTreeFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,14 +147,14 @@ namespace gtsam {
size_t i;
ADT result(*this);
for (i = 0; i < nrFrontals; i++) {
Key j = keys()[i];
Key j = keys_[i];
result = result.combine(j, cardinality(j), op);
}

// create new factor, note we start keys after nrFrontals
// Create new factor, note we start with keys after nrFrontals:
DiscreteKeys dkeys;
for (; i < keys().size(); i++) {
Key j = keys()[i];
for (; i < keys_.size(); i++) {
Key j = keys_[i];
dkeys.push_back(DiscreteKey(j, cardinality(j)));
}
return std::make_shared<DecisionTreeFactor>(dkeys, result);
Expand All @@ -179,24 +179,22 @@ namespace gtsam {
result = result.combine(j, cardinality(j), op);
}

// create new factor, note we collect keys that are not in frontalKeys
/*
Due to branch merging, the labels in `result` may be missing some keys
Create new factor, note we collect keys that are not in frontalKeys.
Due to branch merging, the labels in `result` may be missing some keys.
E.g. After branch merging, we may get a ADT like:
Leaf [2] 1.0204082
This is missing the key values used for branching.
Hence, code below traverses the original keys and omits those in
frontalKeys. We loop over cardinalities, which is O(n) even for a map, and
then "contains" is a binary search on a small vector.
*/
KeyVector difference, frontalKeys_(frontalKeys), keys_(keys());
// Get the difference of the frontalKeys and the factor keys using set_difference
std::sort(keys_.begin(), keys_.end());
std::sort(frontalKeys_.begin(), frontalKeys_.end());
std::set_difference(keys_.begin(), keys_.end(), frontalKeys_.begin(),
frontalKeys_.end(), back_inserter(difference));

DiscreteKeys dkeys;
for (Key key : difference) {
dkeys.push_back(DiscreteKey(key, cardinality(key)));
for (auto&& [key, cardinality] : cardinalities_) {
if (!frontalKeys.contains(key)) {
dkeys.push_back(DiscreteKey(key, cardinality));
}
}
return std::make_shared<DecisionTreeFactor>(dkeys, result);
}
Expand Down
121 changes: 28 additions & 93 deletions gtsam/discrete/tests/testAlgebraicDecisionTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,9 @@
#include <gtsam/discrete/DiscreteKey.h> // make sure we have traits
#include <gtsam/discrete/DiscreteValues.h>
// headers first to make sure no missing headers
#include <CppUnitLite/TestHarness.h>
#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/discrete/DecisionTree-inl.h> // for convert only
#define DISABLE_TIMING

#include <CppUnitLite/TestHarness.h>
#include <gtsam/base/timing.h>
#include <gtsam/discrete/Signature.h>

using namespace std;
Expand Down Expand Up @@ -71,16 +68,14 @@ void dot(const T& f, const string& filename) {
// instrumented operators
/* ************************************************************************** */
size_t muls = 0, adds = 0;
double elapsed;
void resetCounts() {
muls = 0;
adds = 0;
}
void printCounts(const string& s) {
#ifndef DISABLE_TIMING
cout << s << ": " << std::setw(3) << muls << " muls, " <<
std::setw(3) << adds << " adds, " << 1000 * elapsed << " ms."
<< endl;
cout << s << ": " << std::setw(3) << muls << " muls, " << std::setw(3) << adds
<< " adds" << endl;
#endif
resetCounts();
}
Expand Down Expand Up @@ -131,37 +126,35 @@ ADT create(const Signature& signature) {
static size_t count = 0;
const DiscreteKey& key = signature.key();
std::stringstream ss;
ss << "CPT-" << std::setw(3) << std::setfill('0') << ++count << "-" << key.first;
ss << "CPT-" << std::setw(3) << std::setfill('0') << ++count << "-"
<< key.first;
string DOTfile = ss.str();
dot(p, DOTfile);
return p;
}

/* ************************************************************************* */
namespace asiaCPTs {
DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2),
D(7, 2);

ADT pA = create(A % "99/1");
ADT pS = create(S % "50/50");
ADT pT = create(T | A = "99/1 95/5");
ADT pL = create(L | S = "99/1 90/10");
ADT pB = create(B | S = "70/30 40/60");
ADT pE = create((E | T, L) = "F T T T");
ADT pX = create(X | E = "95/5 2/98");
ADT pD = create((D | E, B) = "9/1 2/8 3/7 1/9");
} // namespace asiaCPTs

/* ************************************************************************* */
// test Asia Joint
TEST(ADT, joint) {
DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2),
D(7, 2);

resetCounts();
gttic_(asiaCPTs);
ADT pA = create(A % "99/1");
ADT pS = create(S % "50/50");
ADT pT = create(T | A = "99/1 95/5");
ADT pL = create(L | S = "99/1 90/10");
ADT pB = create(B | S = "70/30 40/60");
ADT pE = create((E | T, L) = "F T T T");
ADT pX = create(X | E = "95/5 2/98");
ADT pD = create((D | E, B) = "9/1 2/8 3/7 1/9");
gttoc_(asiaCPTs);
tictoc_getNode(asiaCPTsNode, asiaCPTs);
elapsed = asiaCPTsNode->secs() + asiaCPTsNode->wall();
tictoc_reset_();
printCounts("Asia CPTs");
using namespace asiaCPTs;

// Create joint
resetCounts();
gttic_(asiaJoint);
ADT joint = pA;
dot(joint, "Asia-A");
joint = apply(joint, pS, &mul);
Expand All @@ -183,11 +176,12 @@ TEST(ADT, joint) {
#else
EXPECT_LONGS_EQUAL(508, muls);
#endif
gttoc_(asiaJoint);
tictoc_getNode(asiaJointNode, asiaJoint);
elapsed = asiaJointNode->secs() + asiaJointNode->wall();
tictoc_reset_();
printCounts("Asia joint");
}

/* ************************************************************************* */
TEST(ADT, combine) {
using namespace asiaCPTs;

// Form P(A,S,T,L) = P(A) P(S) P(T|A) P(L|S)
ADT pASTL = pA;
Expand All @@ -203,13 +197,11 @@ TEST(ADT, joint) {
}

/* ************************************************************************* */
// test Inference with joint
// test Inference with joint, created using different ordering
TEST(ADT, inference) {
DiscreteKey A(0, 2), D(1, 2), //
B(2, 2), L(3, 2), E(4, 2), S(5, 2), T(6, 2), X(7, 2);

resetCounts();
gttic_(infCPTs);
ADT pA = create(A % "99/1");
ADT pS = create(S % "50/50");
ADT pT = create(T | A = "99/1 95/5");
Expand All @@ -218,15 +210,9 @@ TEST(ADT, inference) {
ADT pE = create((E | T, L) = "F T T T");
ADT pX = create(X | E = "95/5 2/98");
ADT pD = create((D | E, B) = "9/1 2/8 3/7 1/9");
gttoc_(infCPTs);
tictoc_getNode(infCPTsNode, infCPTs);
elapsed = infCPTsNode->secs() + infCPTsNode->wall();
tictoc_reset_();
// printCounts("Inference CPTs");

// Create joint
// Create joint, note different ordering than above: different tree!
resetCounts();
gttic_(asiaProd);
ADT joint = pA;
dot(joint, "Joint-Product-A");
joint = apply(joint, pS, &mul);
Expand All @@ -248,14 +234,9 @@ TEST(ADT, inference) {
#else
EXPECT_LONGS_EQUAL(508, (long)muls); // different ordering
#endif
gttoc_(asiaProd);
tictoc_getNode(asiaProdNode, asiaProd);
elapsed = asiaProdNode->secs() + asiaProdNode->wall();
tictoc_reset_();
printCounts("Asia product");

resetCounts();
gttic_(asiaSum);
ADT marginal = joint;
marginal = marginal.combine(X, &add_);
dot(marginal, "Joint-Sum-ADBLEST");
Expand All @@ -270,35 +251,23 @@ TEST(ADT, inference) {
#else
EXPECT_LONGS_EQUAL(240, (long)adds);
#endif
gttoc_(asiaSum);
tictoc_getNode(asiaSumNode, asiaSum);
elapsed = asiaSumNode->secs() + asiaSumNode->wall();
tictoc_reset_();
printCounts("Asia sum");
}

/* ************************************************************************* */
TEST(ADT, factor_graph) {
DiscreteKey B(0, 2), L(1, 2), E(2, 2), S(3, 2), T(4, 2), X(5, 2);

resetCounts();
gttic_(createCPTs);
ADT pS = create(S % "50/50");
ADT pT = create(T % "95/5");
ADT pL = create(L | S = "99/1 90/10");
ADT pE = create((E | T, L) = "F T T T");
ADT pX = create(X | E = "95/5 2/98");
ADT pD = create(B | E = "1/8 7/9");
ADT pB = create(B | S = "70/30 40/60");
gttoc_(createCPTs);
tictoc_getNode(createCPTsNode, createCPTs);
elapsed = createCPTsNode->secs() + createCPTsNode->wall();
tictoc_reset_();
// printCounts("Create CPTs");

// Create joint
resetCounts();
gttic_(asiaFG);
ADT fg = pS;
fg = apply(fg, pT, &mul);
fg = apply(fg, pL, &mul);
Expand All @@ -312,14 +281,9 @@ TEST(ADT, factor_graph) {
#else
EXPECT_LONGS_EQUAL(188, (long)muls);
#endif
gttoc_(asiaFG);
tictoc_getNode(asiaFGNode, asiaFG);
elapsed = asiaFGNode->secs() + asiaFGNode->wall();
tictoc_reset_();
printCounts("Asia FG");

resetCounts();
gttic_(marg);
fg = fg.combine(X, &add_);
dot(fg, "Marginalized-6X");
fg = fg.combine(T, &add_);
Expand All @@ -335,83 +299,54 @@ TEST(ADT, factor_graph) {
#else
LONGS_EQUAL(62, adds);
#endif
gttoc_(marg);
tictoc_getNode(margNode, marg);
elapsed = margNode->secs() + margNode->wall();
tictoc_reset_();
printCounts("marginalize");

// BLESTX

// Eliminate X
resetCounts();
gttic_(elimX);
ADT fE = pX;
dot(fE, "Eliminate-01-fEX");
fE = fE.combine(X, &add_);
dot(fE, "Eliminate-02-fE");
gttoc_(elimX);
tictoc_getNode(elimXNode, elimX);
elapsed = elimXNode->secs() + elimXNode->wall();
tictoc_reset_();
printCounts("Eliminate X");

// Eliminate T
resetCounts();
gttic_(elimT);
ADT fLE = pT;
fLE = apply(fLE, pE, &mul);
dot(fLE, "Eliminate-03-fLET");
fLE = fLE.combine(T, &add_);
dot(fLE, "Eliminate-04-fLE");
gttoc_(elimT);
tictoc_getNode(elimTNode, elimT);
elapsed = elimTNode->secs() + elimTNode->wall();
tictoc_reset_();
printCounts("Eliminate T");

// Eliminate S
resetCounts();
gttic_(elimS);
ADT fBL = pS;
fBL = apply(fBL, pL, &mul);
fBL = apply(fBL, pB, &mul);
dot(fBL, "Eliminate-05-fBLS");
fBL = fBL.combine(S, &add_);
dot(fBL, "Eliminate-06-fBL");
gttoc_(elimS);
tictoc_getNode(elimSNode, elimS);
elapsed = elimSNode->secs() + elimSNode->wall();
tictoc_reset_();
printCounts("Eliminate S");

// Eliminate E
resetCounts();
gttic_(elimE);
ADT fBL2 = fE;
fBL2 = apply(fBL2, fLE, &mul);
fBL2 = apply(fBL2, pD, &mul);
dot(fBL2, "Eliminate-07-fBLE");
fBL2 = fBL2.combine(E, &add_);
dot(fBL2, "Eliminate-08-fBL2");
gttoc_(elimE);
tictoc_getNode(elimENode, elimE);
elapsed = elimENode->secs() + elimENode->wall();
tictoc_reset_();
printCounts("Eliminate E");

// Eliminate L
resetCounts();
gttic_(elimL);
ADT fB = fBL;
fB = apply(fB, fBL2, &mul);
dot(fB, "Eliminate-09-fBL");
fB = fB.combine(L, &add_);
dot(fB, "Eliminate-10-fB");
gttoc_(elimL);
tictoc_getNode(elimLNode, elimL);
elapsed = elimLNode->secs() + elimLNode->wall();
tictoc_reset_();
printCounts("Eliminate L");
}

Expand Down
Loading

0 comments on commit e4ec8d3

Please sign in to comment.