Skip to content

Commit

Permalink
Merge pull request #1881 from borglab/feature/no_conditionals
Browse files Browse the repository at this point in the history
Significant speedup
  • Loading branch information
dellaert authored Oct 23, 2024
2 parents 5b318fc + cbb0a30 commit 366b514
Show file tree
Hide file tree
Showing 6 changed files with 292 additions and 116 deletions.
63 changes: 58 additions & 5 deletions gtsam/discrete/DecisionTree-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <optional>
#include <set>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

Expand Down Expand Up @@ -286,6 +287,10 @@ namespace gtsam {
return branches_;
}

std::vector<NodePtr>& branches() {
return branches_;
}

/** add a branch: TODO merge into constructor */
void push_back(NodePtr&& node) {
// allSame_ is restricted to leaf nodes in a decision tree
Expand Down Expand Up @@ -482,8 +487,8 @@ namespace gtsam {
/****************************************************************************/
// DecisionTree
/****************************************************************************/
template<typename L, typename Y>
DecisionTree<L, Y>::DecisionTree() {}
template <typename L, typename Y>
DecisionTree<L, Y>::DecisionTree() : root_(nullptr) {}

template<typename L, typename Y>
DecisionTree<L, Y>::DecisionTree(const NodePtr& root) :
Expand Down Expand Up @@ -554,6 +559,36 @@ namespace gtsam {
root_ = compose(functions.begin(), functions.end(), label);
}

/****************************************************************************/
template <typename L, typename Y>
DecisionTree<L, Y>::DecisionTree(const Unary& op,
DecisionTree&& other) noexcept
: root_(std::move(other.root_)) {
// Apply the unary operation directly to each leaf in the tree
if (root_) {
// Define a helper function to traverse and apply the operation
struct ApplyUnary {
const Unary& op;
void operator()(typename DecisionTree<L, Y>::NodePtr& node) const {
if (auto leaf = std::dynamic_pointer_cast<Leaf>(node)) {
// Apply the unary operation to the leaf's constant value
leaf->constant_ = op(leaf->constant_);
} else if (auto choice = std::dynamic_pointer_cast<Choice>(node)) {
// Recurse into the choice branches
for (NodePtr& branch : choice->branches()) {
(*this)(branch);
}
}
}
};

ApplyUnary applyUnary{op};
applyUnary(root_);
}
// Reset the other tree's root to nullptr to avoid dangling references
other.root_ = nullptr;
}

/****************************************************************************/
template <typename L, typename Y>
template <typename X, typename Func>
Expand Down Expand Up @@ -694,7 +729,7 @@ namespace gtsam {
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::create(
It begin, It end, ValueIt beginY, ValueIt endY) {
auto node = build(begin, end, beginY, endY);
if (auto choice = std::dynamic_pointer_cast<const Choice>(node)) {
if (auto choice = std::dynamic_pointer_cast<Choice>(node)) {
return Choice::Unique(choice);
} else {
return node;
Expand All @@ -710,7 +745,7 @@ namespace gtsam {

// If leaf, apply unary conversion "op" and create a unique leaf.
using LXLeaf = typename DecisionTree<L, X>::Leaf;
if (auto leaf = std::dynamic_pointer_cast<const LXLeaf>(f)) {
if (auto leaf = std::dynamic_pointer_cast<LXLeaf>(f)) {
return NodePtr(new Leaf(Y_of_X(leaf->constant())));
}

Expand Down Expand Up @@ -951,11 +986,16 @@ namespace gtsam {
return root_->equals(*other.root_);
}

/****************************************************************************/
template<typename L, typename Y>
const Y& DecisionTree<L, Y>::operator()(const Assignment<L>& x) const {
if (root_ == nullptr)
throw std::invalid_argument(
"DecisionTree::operator() called on empty tree");
return root_->operator ()(x);
}

/****************************************************************************/
template<typename L, typename Y>
DecisionTree<L, Y> DecisionTree<L, Y>::apply(const Unary& op) const {
// It is unclear what should happen if tree is empty:
Expand All @@ -966,6 +1006,7 @@ namespace gtsam {
return DecisionTree(root_->apply(op));
}

/****************************************************************************/
/// Apply unary operator with assignment
template <typename L, typename Y>
DecisionTree<L, Y> DecisionTree<L, Y>::apply(
Expand Down Expand Up @@ -1049,6 +1090,18 @@ namespace gtsam {
return ss.str();
}

/******************************************************************************/
/******************************************************************************/
template <typename L, typename Y>
template <typename A, typename B>
std::pair<DecisionTree<L, A>, DecisionTree<L, B>> DecisionTree<L, Y>::split(
std::function<std::pair<A, B>(const Y&)> AB_of_Y) const {
using AB = std::pair<A, B>;
const DecisionTree<L, AB> ab(*this, AB_of_Y);
const DecisionTree<L, A> a(ab, [](const AB& p) { return p.first; });
const DecisionTree<L, B> b(ab, [](const AB& p) { return p.second; });
return {a, b};
}

/******************************************************************************/

} // namespace gtsam
33 changes: 27 additions & 6 deletions gtsam/discrete/DecisionTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ namespace gtsam {

/** ------------------------ Node base class --------------------------- */
struct Node {
using Ptr = std::shared_ptr<const Node>;
using Ptr = std::shared_ptr<Node>;

#ifdef DT_DEBUG_MEMORY
static int nrNodes;
Expand Down Expand Up @@ -156,10 +156,10 @@ namespace gtsam {
template <typename It, typename ValueIt>
static NodePtr build(It begin, It end, ValueIt beginY, ValueIt endY);

/** Internal helper function to create from
* keys, cardinalities, and Y values.
* Calls `build` which builds thetree bottom-up,
* before we prune in a top-down fashion.
/**
* Internal helper function to create a tree from keys, cardinalities, and Y
* values. Calls `build` which builds the tree bottom-up, before we prune in
* a top-down fashion.
*/
template <typename It, typename ValueIt>
static NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY);
Expand Down Expand Up @@ -228,6 +228,15 @@ namespace gtsam {
DecisionTree(const L& label, const DecisionTree& f0,
const DecisionTree& f1);

/**
* @brief Move constructor for DecisionTree. Very efficient as does not
* allocate anything, just changes in-place. But `other` is consumed.
*
* @param op The unary operation to apply to the moved DecisionTree.
* @param other The DecisionTree to move from, will be empty afterwards.
*/
DecisionTree(const Unary& op, DecisionTree&& other) noexcept;

/**
* @brief Convert from a different value type.
*
Expand All @@ -239,7 +248,7 @@ namespace gtsam {
DecisionTree(const DecisionTree<L, X>& other, Func Y_of_X);

/**
* @brief Convert from a different value type X to value type Y, also transate
* @brief Convert from a different value type X to value type Y, also translate
* labels via map from type M to L.
*
* @tparam M Previous label type.
Expand Down Expand Up @@ -406,6 +415,18 @@ namespace gtsam {
const ValueFormatter& valueFormatter,
bool showZero = true) const;

/**
* @brief Convert into two trees with value types A and B.
*
* @tparam A First new value type.
* @tparam B Second new value type.
* @param AB_of_Y Functor to convert from type X to std::pair<A, B>.
* @return A pair of DecisionTrees with value types A and B respectively.
*/
template <typename A, typename B>
std::pair<DecisionTree<L, A>, DecisionTree<L, B>> split(
std::function<std::pair<A, B>(const Y&)> AB_of_Y) const;

/// @name Advanced Interface
/// @{

Expand Down
55 changes: 54 additions & 1 deletion gtsam/discrete/tests/testDecisionTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

/*
* @file testDecisionTree.cpp
* @brief Develop DecisionTree
* @brief DecisionTree unit tests
* @author Frank Dellaert
* @author Can Erdogan
* @date Jan 30, 2012
Expand Down Expand Up @@ -108,6 +108,7 @@ struct DT : public DecisionTree<string, int> {
std::cout << s;
Base::print("", keyFormatter, valueFormatter);
}

/// Equality method customized to int node type
bool equals(const Base& other, double tol = 1e-9) const {
auto compare = [](const int& v, const int& w) { return v == w; };
Expand Down Expand Up @@ -271,6 +272,58 @@ TEST(DecisionTree, Example) {
DOT(acnotb);
}

/* ************************************************************************** */
// Test that we can create two trees out of one, using a function that returns a pair.
TEST(DecisionTree, Split) {
// Create labels
string A("A"), B("B");

// Create a decision tree
DT original(A, DT(B, 1, 2), DT(B, 3, 4));

// Define a function that returns an int/bool pair
auto split_function = [](const int& value) -> std::pair<int, bool> {
return {value*3, value*3 % 2 == 0};
};

// Split the original tree into two new trees
auto [la,lb] = original.split<int,bool>(split_function);

// Check the first resulting tree
EXPECT_LONGS_EQUAL(3, la(Assignment<string>{{A, 0}, {B, 0}}));
EXPECT_LONGS_EQUAL(6, la(Assignment<string>{{A, 0}, {B, 1}}));
EXPECT_LONGS_EQUAL(9, la(Assignment<string>{{A, 1}, {B, 0}}));
EXPECT_LONGS_EQUAL(12, la(Assignment<string>{{A, 1}, {B, 1}}));

// Check the second resulting tree
EXPECT(!lb(Assignment<string>{{A, 0}, {B, 0}}));
EXPECT(lb(Assignment<string>{{A, 0}, {B, 1}}));
EXPECT(!lb(Assignment<string>{{A, 1}, {B, 0}}));
EXPECT(lb(Assignment<string>{{A, 1}, {B, 1}}));
}


/* ************************************************************************** */
// Test that we can create a tree by modifying an rvalue.
TEST(DecisionTree, Consume) {
// Create labels
string A("A"), B("B");

// Create a decision tree
DT original(A, DT(B, 1, 2), DT(B, 3, 4));

DT modified([](int i){return i*2;}, std::move(original));

// Check the first resulting tree
EXPECT_LONGS_EQUAL(2, modified(Assignment<string>{{A, 0}, {B, 0}}));
EXPECT_LONGS_EQUAL(4, modified(Assignment<string>{{A, 0}, {B, 1}}));
EXPECT_LONGS_EQUAL(6, modified(Assignment<string>{{A, 1}, {B, 0}}));
EXPECT_LONGS_EQUAL(8, modified(Assignment<string>{{A, 1}, {B, 1}}));

// Check original was moved
EXPECT(original.root_ == nullptr);
}

/* ************************************************************************** */
// test Conversion of values
bool bool_of_int(const int& y) { return y != 0; };
Expand Down
Loading

0 comments on commit 366b514

Please sign in to comment.