diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index 45574f6414..caedab713c 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -239,7 +239,7 @@ namespace gtsam { }; // Go through the tree - this->apply(op); + this->visitWith(op); return probs; } @@ -349,22 +349,110 @@ namespace gtsam { : DiscreteFactor(keys.indices(), keys.cardinalities()), AlgebraicDecisionTree(keys, table) {} + /** + * @brief Min-Heap class to help with pruning. + * The `top` element is always the smallest value. + */ + class MinHeap { + std::vector v_; + + public: + /// Default constructor + MinHeap() {} + + /// Push value onto the heap + void push(double x) { + v_.push_back(x); + std::push_heap(v_.begin(), v_.end(), std::greater{}); + } + + /// Push value `x`, `n` number of times. + void push(double x, size_t n) { + for (size_t i = 0; i < n; ++i) { + v_.push_back(x); + std::push_heap(v_.begin(), v_.end(), std::greater{}); + } + } + + /// Pop the top value of the heap. + double pop() { + std::pop_heap(v_.begin(), v_.end(), std::greater{}); + double x = v_.back(); + v_.pop_back(); + return x; + } + + /// Return the top value of the heap without popping it. + double top() { return v_.at(0); } + + /** + * @brief Print the heap as a sequence. + * + * @param s A string to prologue the output. + */ + void print(const std::string& s = "") { + std::cout << (s.empty() ? "" : s + " "); + for (size_t i = 0; i < v_.size(); i++) { + std::cout << v_.at(i); + if (v_.size() > 1 && i < v_.size() - 1) std::cout << ", "; + } + std::cout << std::endl; + } + + /// Return true if heap is empty. + bool empty() const { return v_.empty(); } + + /// Return the size of the heap. + size_t size() const { return v_.size(); } + }; + /* ************************************************************************ */ DecisionTreeFactor DecisionTreeFactor::prune(size_t maxNrAssignments) const { const size_t N = maxNrAssignments; - // Get the probabilities in the decision tree so we can threshold. - std::vector probabilities = this->probabilities(); + // Set of all keys + std::set allKeys(keys().begin(), keys().end()); + MinHeap min_heap; - // The number of probabilities can be lower than max_leaves - if (probabilities.size() <= N) { - return *this; - } + auto op = [&](const Assignment& a, double p) { + // Get all the keys in the current assignment + std::set assignment_keys; + for (auto&& [k, _] : a) { + assignment_keys.insert(k); + } + + // Find the keys missing in the assignment + std::vector diff; + std::set_difference(allKeys.begin(), allKeys.end(), + assignment_keys.begin(), assignment_keys.end(), + std::back_inserter(diff)); + + // Compute the total number of assignments in the (pruned) subtree + size_t nrAssignments = 1; + for (auto&& k : diff) { + nrAssignments *= cardinalities_.at(k); + } - std::sort(probabilities.begin(), probabilities.end(), - std::greater{}); + if (min_heap.empty()) { + min_heap.push(p, std::min(nrAssignments, N)); + + } else { + // If p is larger than the smallest element, + // then we insert into the max heap. + if (p > min_heap.top()) { + for (size_t i = 0; i < std::min(nrAssignments, N); ++i) { + if (min_heap.size() == N) { + min_heap.pop(); + } + min_heap.push(p); + } + } + } + return p; + }; + this->visitWith(op); - double threshold = probabilities[N - 1]; + double threshold = min_heap.top(); // Now threshold the decision tree size_t total = 0;