From 8b968c14018caab9daa5838a03d0c57829c6eb21 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 6 Nov 2024 09:08:16 -0500 Subject: [PATCH 1/4] use visitWith to not create a new tree --- gtsam/discrete/DecisionTreeFactor.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index 45574f6414..4e7a7342ef 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -239,7 +239,7 @@ namespace gtsam { }; // Go through the tree - this->apply(op); + this->visitWith(op); return probs; } From d21f191219a945c91546e8d77bb0badc3f877446 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 6 Nov 2024 15:06:46 -0500 Subject: [PATCH 2/4] use a fixed size min-heap to find the pruning threshold --- gtsam/discrete/DecisionTreeFactor.cpp | 57 ++++++++++++++++++++++----- 1 file changed, 48 insertions(+), 9 deletions(-) diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index 4e7a7342ef..9b541bbf02 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -353,18 +353,57 @@ namespace gtsam { DecisionTreeFactor DecisionTreeFactor::prune(size_t maxNrAssignments) const { const size_t N = maxNrAssignments; - // Get the probabilities in the decision tree so we can threshold. - std::vector probabilities = this->probabilities(); + // Set of all keys + std::set allKeys(keys().begin(), keys().end()); + std::vector min_heap; - // The number of probabilities can be lower than max_leaves - if (probabilities.size() <= N) { - return *this; - } + auto op = [&](const Assignment& a, double p) { + // Get all the keys in the current assignment + std::set assignment_keys; + for (auto&& [k, _] : a) { + assignment_keys.insert(k); + } - std::sort(probabilities.begin(), probabilities.end(), - std::greater{}); + // Find the keys missing in the assignment + std::vector diff; + std::set_difference(allKeys.begin(), allKeys.end(), + assignment_keys.begin(), assignment_keys.end(), + std::back_inserter(diff)); + + // Compute the total number of assignments in the (pruned) subtree + size_t nrAssignments = 1; + for (auto&& k : diff) { + nrAssignments *= cardinalities_.at(k); + } + + if (min_heap.empty()) { + for (size_t i = 0; i < std::min(nrAssignments, N); ++i) { + min_heap.push_back(p); + } + std::make_heap(min_heap.begin(), min_heap.end(), + std::greater{}); + + } else { + // If p is larger than the smallest element, + // then we insert into the max heap. + if (p > min_heap.at(0)) { + for (size_t i = 0; i < std::min(nrAssignments, N); ++i) { + if (min_heap.size() == N) { + std::pop_heap(min_heap.begin(), min_heap.end(), + std::greater{}); + min_heap.pop_back(); + } + min_heap.push_back(p); + std::make_heap(min_heap.begin(), min_heap.end(), + std::greater{}); + } + } + } + return p; + }; + this->visitWith(op); - double threshold = probabilities[N - 1]; + double threshold = min_heap.at(0); // Now threshold the decision tree size_t total = 0; From 9666725473c9dfbf7b16f3447c400306ace9f783 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 6 Nov 2024 16:37:04 -0500 Subject: [PATCH 3/4] implement a min-heap to record the top N probabilities for pruning --- gtsam/discrete/DecisionTreeFactor.cpp | 74 ++++++++++++++++++++++----- 1 file changed, 60 insertions(+), 14 deletions(-) diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index 9b541bbf02..c8efc5fa5a 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -349,13 +349,67 @@ namespace gtsam { : DiscreteFactor(keys.indices(), keys.cardinalities()), AlgebraicDecisionTree(keys, table) {} + /** + * @brief Min-Heap class to help with pruning. + * The `top` element is always the smallest value. + */ + class MinHeap { + std::vector v_; + + public: + /// Default constructor + MinHeap() {} + + /// Push value onto the heap + void push(double x) { + v_.push_back(x); + std::make_heap(v_.begin(), v_.end(), std::greater{}); + } + + /// Push value `x`, `n` number of times. + void push(double x, size_t n) { + v_.insert(v_.end(), n, x); + std::make_heap(v_.begin(), v_.end(), std::greater{}); + } + + /// Pop the top value of the heap. + double pop() { + std::pop_heap(v_.begin(), v_.end(), std::greater{}); + double x = v_.back(); + v_.pop_back(); + return x; + } + + /// Return the top value of the heap without popping it. + double top() { return v_.at(0); } + + /** + * @brief Print the heap as a sequence. + * + * @param s A string to prologue the output. + */ + void print(const std::string& s = "") { + std::cout << (s.empty() ? "" : s + " "); + for (size_t i = 0; i < v_.size() - 1; i++) { + std::cout << v_.at(i) << ","; + } + std::cout << v_.at(v_.size() - 1) << std::endl; + } + + /// Return true if heap is empty. + bool empty() const { return v_.empty(); } + + /// Return the size of the heap. + size_t size() const { return v_.size(); } + }; + /* ************************************************************************ */ DecisionTreeFactor DecisionTreeFactor::prune(size_t maxNrAssignments) const { const size_t N = maxNrAssignments; // Set of all keys std::set allKeys(keys().begin(), keys().end()); - std::vector min_heap; + MinHeap min_heap; auto op = [&](const Assignment& a, double p) { // Get all the keys in the current assignment @@ -377,25 +431,17 @@ namespace gtsam { } if (min_heap.empty()) { - for (size_t i = 0; i < std::min(nrAssignments, N); ++i) { - min_heap.push_back(p); - } - std::make_heap(min_heap.begin(), min_heap.end(), - std::greater{}); + min_heap.push(p, std::min(nrAssignments, N)); } else { // If p is larger than the smallest element, // then we insert into the max heap. - if (p > min_heap.at(0)) { + if (p > min_heap.top()) { for (size_t i = 0; i < std::min(nrAssignments, N); ++i) { if (min_heap.size() == N) { - std::pop_heap(min_heap.begin(), min_heap.end(), - std::greater{}); - min_heap.pop_back(); + min_heap.pop(); } - min_heap.push_back(p); - std::make_heap(min_heap.begin(), min_heap.end(), - std::greater{}); + min_heap.push(p); } } } @@ -403,7 +449,7 @@ namespace gtsam { }; this->visitWith(op); - double threshold = min_heap.at(0); + double threshold = min_heap.top(); // Now threshold the decision tree size_t total = 0; From ae43b2ade7a6f88332146aca59a27a17f8eb17b8 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 6 Nov 2024 19:23:26 -0500 Subject: [PATCH 4/4] make MinHeap more efficient by calling push_heap instead of make_heap --- gtsam/discrete/DecisionTreeFactor.cpp | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index c8efc5fa5a..caedab713c 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -363,13 +363,15 @@ namespace gtsam { /// Push value onto the heap void push(double x) { v_.push_back(x); - std::make_heap(v_.begin(), v_.end(), std::greater{}); + std::push_heap(v_.begin(), v_.end(), std::greater{}); } /// Push value `x`, `n` number of times. void push(double x, size_t n) { - v_.insert(v_.end(), n, x); - std::make_heap(v_.begin(), v_.end(), std::greater{}); + for (size_t i = 0; i < n; ++i) { + v_.push_back(x); + std::push_heap(v_.begin(), v_.end(), std::greater{}); + } } /// Pop the top value of the heap. @@ -390,10 +392,11 @@ namespace gtsam { */ void print(const std::string& s = "") { std::cout << (s.empty() ? "" : s + " "); - for (size_t i = 0; i < v_.size() - 1; i++) { - std::cout << v_.at(i) << ","; + for (size_t i = 0; i < v_.size(); i++) { + std::cout << v_.at(i); + if (v_.size() > 1 && i < v_.size() - 1) std::cout << ", "; } - std::cout << v_.at(v_.size() - 1) << std::endl; + std::cout << std::endl; } /// Return true if heap is empty.