Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce memory consumption during discrete pruning #1898

Merged
merged 4 commits into from
Nov 7, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 98 additions & 10 deletions gtsam/discrete/DecisionTreeFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ namespace gtsam {
};

// Go through the tree
this->apply(op);
this->visitWith(op);

return probs;
}
Expand Down Expand Up @@ -349,22 +349,110 @@ namespace gtsam {
: DiscreteFactor(keys.indices(), keys.cardinalities()),
AlgebraicDecisionTree<Key>(keys, table) {}

/**
 * @brief Min-Heap of doubles used to help with pruning.
 *
 * Values are stored in a std::vector that is kept in heap order with
 * std::greater<double>, so the `top` element is always the smallest value
 * currently held.
 */
class MinHeap {
  /// Backing storage, maintained as a binary min-heap.
  std::vector<double> v_;

 public:
  /// Default constructor: creates an empty heap.
  MinHeap() = default;

  /// Push value `x` onto the heap.
  void push(double x) {
    v_.push_back(x);
    std::push_heap(v_.begin(), v_.end(), std::greater<double>{});
  }

  /// Push value `x` onto the heap `n` times (no-op when `n` is 0).
  void push(double x, size_t n) {
    for (size_t i = 0; i < n; ++i) {
      push(x);
    }
  }

  /**
   * @brief Remove and return the smallest value of the heap.
   *
   * @return The value that was at the top of the heap.
   * @throws std::out_of_range if the heap is empty (same behavior as `top`).
   */
  double pop() {
    // Read the minimum through top(): its bounds-checked access turns a pop
    // on an empty heap into an exception instead of undefined behavior.
    const double smallest = top();
    std::pop_heap(v_.begin(), v_.end(), std::greater<double>{});
    v_.pop_back();
    return smallest;
  }

  /**
   * @brief Return the smallest value of the heap without removing it.
   *
   * @throws std::out_of_range if the heap is empty.
   */
  double top() const { return v_.at(0); }

  /**
   * @brief Print the heap contents to std::cout as a comma-separated
   * sequence, in internal storage (heap) order rather than sorted order.
   *
   * @param s Optional prefix; when non-empty it is printed first, followed
   * by a single space.
   */
  void print(const std::string& s = "") const {
    std::cout << (s.empty() ? "" : s + " ");
    for (size_t i = 0; i < v_.size(); i++) {
      std::cout << v_.at(i);
      // Separator goes between elements only, never after the last one.
      if (i + 1 < v_.size()) std::cout << ", ";
    }
    std::cout << std::endl;
  }

  /// Return true if the heap holds no values.
  bool empty() const { return v_.empty(); }

  /// Return the number of values currently in the heap.
  size_t size() const { return v_.size(); }
};

/* ************************************************************************ */
DecisionTreeFactor DecisionTreeFactor::prune(size_t maxNrAssignments) const {
const size_t N = maxNrAssignments;

// Get the probabilities in the decision tree so we can threshold.
std::vector<double> probabilities = this->probabilities();
// Set of all keys
std::set<Key> allKeys(keys().begin(), keys().end());
MinHeap min_heap;

// The number of probabilities can be lower than max_leaves
if (probabilities.size() <= N) {
return *this;
}
auto op = [&](const Assignment<Key>& a, double p) {
// Get all the keys in the current assignment
std::set<Key> assignment_keys;
for (auto&& [k, _] : a) {
assignment_keys.insert(k);
}

// Find the keys missing in the assignment
std::vector<Key> diff;
std::set_difference(allKeys.begin(), allKeys.end(),
assignment_keys.begin(), assignment_keys.end(),
std::back_inserter(diff));

// Compute the total number of assignments in the (pruned) subtree
size_t nrAssignments = 1;
for (auto&& k : diff) {
nrAssignments *= cardinalities_.at(k);
}

std::sort(probabilities.begin(), probabilities.end(),
std::greater<double>{});
if (min_heap.empty()) {
min_heap.push(p, std::min(nrAssignments, N));

} else {
// If p is larger than the smallest element,
// then we insert into the min heap.
if (p > min_heap.top()) {
for (size_t i = 0; i < std::min(nrAssignments, N); ++i) {
if (min_heap.size() == N) {
min_heap.pop();
}
min_heap.push(p);
}
}
}
return p;
};
this->visitWith(op);

double threshold = probabilities[N - 1];
double threshold = min_heap.top();

// Now threshold the decision tree
size_t total = 0;
Expand Down
Loading