Skip to content

Commit

Permalink
Trying to reduce lookup inputs count by grouping non-intersecting sel…
Browse files Browse the repository at this point in the history
…ectors.
  • Loading branch information
martun committed Dec 17, 2024
1 parent 01fead9 commit ce7f055
Show file tree
Hide file tree
Showing 2 changed files with 163 additions and 16 deletions.
175 changes: 159 additions & 16 deletions crypto3/libs/blueprint/include/nil/blueprint/bbf/gate_optimizer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,13 @@ namespace nil {
std::ostream& operator<<(std::ostream& os, const optimized_gates<FieldType>& gates) {
for (const auto& [selector, id]: gates.selectors_) {
auto iter = gates.constraint_list.find(id);
if (iter != gates.constraint_list.end()) {
os << "Selector #" << id << " " << selector << std::endl;
for (const auto &constraint : iter->second) {
os << constraint << std::endl;
}
os << "--------------------------------------------------------------" << std::endl;
}
if (iter != gates.constraint_list.end()) {
os << "Selector #" << id << " " << selector << std::endl;
for (const auto &constraint : iter->second) {
os << constraint << std::endl;
}
os << "--------------------------------------------------------------" << std::endl;
}
}
return os;
}
Expand Down Expand Up @@ -192,15 +192,158 @@ namespace nil {
optimized_gates<FieldType> result = context_to_gates();
// optimized_gates<FieldType> result = gates_storage_;
// std::cout << "Before: \n\n" << result << std::endl;
// optimize_selectors_by_shifting(result);
optimize_selectors_by_shifting(result);
optimize_lookups_by_grouping(result);
// std::cout << "After: \n\n" << result << std::endl;
return result;
}


private:

/** Brooks' algorithm for graph coloring.
* \param[in] adj - Adjacency list of the graph.
* \returns A map that maps the vertex id to its color.
*/
std::unordered_map<size_t, size_t> colorGraph(const std::vector<std::vector<size_t>>& adj) const {
size_t V = adj.size();
std::unordered_map<size_t, size_t> color;

// Sort vertices by degree in descending order
std::vector<size_t> vertices(V);
for (size_t i = 0; i < V; ++i)
vertices[i] = i;

std::sort(vertices.begin(), vertices.end(),
[&adj](size_t a, size_t b) { return adj[a].size() > adj[b].size(); });

// Color vertices
for (size_t u : vertices) {
// Collect colors of adjacent vertices
std::unordered_set<size_t> usedColors;
for (size_t v : adj[u]) {
if (color.count(v)) {
usedColors.insert(color[v]);
}
}

// Find first available color
size_t currColor = 0;
while (usedColors.count(currColor)) {
currColor++;
}
color[u] = currColor;
}

return color;
}

/** Creates and returns a graph in the form of an adjucency list.
*/
std::vector<std::vector<size_t>> create_selector_intersection_graph(
const optimized_gates<FieldType>& gates,
const std::vector<size_t>& used_selectors,
const std::map<size_t, size_t>& selector_id_to_index) {
std::vector<std::vector<size_t>> adj;
// Create the graph.
adj.resize(used_selectors.size());
for (const auto& [row_list1, selector_id1]: gates.selectors_) {
if (selector_id_to_index.find(selector_id1) == selector_id_to_index.end())
continue;
for (const auto& [row_list2, selector_id2]: gates.selectors_) {
if (selector_id2 >= selector_id1 || selector_id_to_index.find(selector_id2) == selector_id_to_index.end())
continue;
if (row_list1.intersects(row_list2)) {
// Add an edge.
adj[selector_id_to_index[selector_id1]].push_back(selector_id_to_index[selector_id2]);
adj[selector_id_to_index[selector_id2]].push_back(selector_id_to_index[selector_id1]);
}
}
}
return adj;
}

std::unordered_map<size_t, size_t> group_selectors(
const std::vector<std::vector<size_t>>& graph,
const std::unordered_map<size_t, size_t>& selector_id_to_index,
const std::vector<size_t>& used_selectors) {
std::vector<std::vector<size_t>> graph_subset = get_subset(
graph, selector_id_to_index, used_selectors);

std::unordered_map<size_t, size_t> coloring = colorGraph(graph_subset);

// Now run over the returned coloring and map it back.
std::unordered_map<size_t, size_t> result;
for (const auto& [id, group_id]: coloring) {
result[used_selectors[id]] = group_id;
}
return result;
}

/** This function tries to reduce the number of lookups by grouping them. If 2 lookups use non-intersecting
* selectors, they can be merged into 1 like.
* Imagine lookup inputs {L0 ... Lm} with selector s1, and {l0 ... lm} with selector s2, then we can merge them into
* lookup inputs { s1 * L0 + s2 * l0, ... , s1 * Lm + s2 * lm } with selector that selects all the rows.
* We cannot optimally group the selectors into the minimal number of groups, that's an NP-complete problem
* called graph coloring problem. We will use Brooks' algorithm, it's some simple heuristic thing.
*/
void optimize_lookups_by_grouping(optimized_gates<FieldType>& gates) {
std::vector<size_t> used_selectors;
std::unordered_map<size_t, size_t> selector_id_to_index;

for (const auto& [row_list, selector_id]: gates.selectors_) {
if (gates.lookup_constraints.find(selector_id) != gates.lookup_constraints.end()) {
used_selectors.push_back(selector_id);
selector_id_to_index[selector_id] = used_selectors.size() - 1;
}
}

// Create an adjacency list of the whole large graph, since taking intersections of selectors is not super fast.
std::vector<std::vector<size_t>> adj = create_selector_intersection_graph(
gates, used_selectors, selector_id_to_index);

// For each table, create the list of used selectors.
std::unordered_map<std::string, std::vector<size_t>> selectors_per_table;
for(const auto& [selector_id, lookup_list] : gates.lookup_constraints) {
for(const auto& single_lookup_constraint : lookup_list) {
const auto& table_name = single_lookup_constraint.first;
selectors_per_table[table_name].push_back(selector_id);
}
}

// For each table, group the selectors.

// Maps table name to a [map of selector id -> # of the group it belongs to].
std::unordered_map<std::string, std::unordered_map<size_t, size_t>> selector_groups;
std::unordered_map<std::string, std::unordered_map<size_t, size_t>> group_sizes;
for (const auto& [table_name, selectors] : selectors_per_table) {
// Maps group_id -> # of selectors in it.
selector_groups[table_name] = group_selectors(adj, selector_id_to_index, selectors);
// Count the size of each group, we need to not touch groups of size 1.
for (const auto& [selector_id, group_index]: selector_groups[table_name]) {
group_sizes[table_name][group_index]++;
}
}

std::unordered_map<std::string, std::unordered_map<size_t, typename context_type::lookup_input_constraints_type>> merged_lookups;
// Now merge all the lookups on selectors in the same group.
for (const auto& [selector_id, lookup_list] : gates.lookup_constraints) {
std::vector<lookup_constraint_type> lookup_gate;
for(const auto& single_lookup_constraint : lookup_list) {
const std::string& table_name = single_lookup_constraint.first;
size_t group_id = selector_groups[table_name][selector_id];
// If the group size is 1, don't touch it.
if (group_sizes[table_name][group_id] == 1)
merged_lookups[table_name][selector_id] = single_lookup_constraint.second;
else {
// selector_by_id is a non-existant map, we will add it later.
merged_lookups[table_name][PLONK_SPECIAL_SELECTOR_ALL_USABLE_ROWS_SELECTED] +=
selector_by_id[selector_id] * single_lookup_constraint.second;
}
}
}
}

/** This function tries to reduce the number of selectors required by rotating the constraints by +-1.
*/
void optimize_selectors_by_shifting(optimized_gates<FieldType>& gates) {
Expand All @@ -212,11 +355,11 @@ namespace nil {
std::vector<std::pair<size_t, int>> chosen_selectors = choose_selectors(
left_shifts, right_shifts);

//std::cout << "The following selector shifts were selected: \n";
//for (size_t i = 0; i < chosen_selectors.size(); ++i) {
// std::cout << "#" << i << " -> " << "#" << chosen_selectors[i].first << " shifted " << chosen_selectors[i].second << std::endl;
//}
//std::cout << std::endl;
//std::cout << "The following selector shifts were selected: \n";
//for (size_t i = 0; i < chosen_selectors.size(); ++i) {
// std::cout << "#" << i << " -> " << "#" << chosen_selectors[i].first << " shifted " << chosen_selectors[i].second << std::endl;
//}
//std::cout << std::endl;

// Maps the old selector ID to the new one, only for the selectors to be used.
std::map<size_t, size_t> new_selector_mapping;
Expand Down Expand Up @@ -298,7 +441,7 @@ namespace nil {

gates.constraint_list = std::move(result.constraint_list);
gates.lookup_constraints = std::move(result.lookup_constraints);
gates.selectors_ = result.selectors_;
gates.selectors_ = result.selectors_;
}

/**
Expand Down Expand Up @@ -329,7 +472,7 @@ namespace nil {
}

// For each node go to the left and right as far as possible.
// Then run over the chain of selectors and make the decisions.
// Then run over the chain of selectors and make the decisions.
for (size_t i = 0; i < N; ++i) {
// We may already have a decision for the current node.
if (chosen_shifts[i].first != -1)
Expand Down Expand Up @@ -377,7 +520,7 @@ namespace nil {
chosen_shifts[chain[j - 1]] = {chain[j - 1], 0};
chosen_shifts[chain[j]] = {chain[j - 1], -1};

// Check if chain[j - 2] can be skipped by rotating it to the right.
// Check if chain[j - 2] can be skipped by rotating it to the right.
if (j >= 2 && right_shifts[chain[j - 2]] == chain[j - 1]) {
chosen_shifts[chain[j - 2]] = {chain[j - 1], +1};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,10 @@ namespace nil {
return used_rows_.none();
}

bool intersects(const row_selector& other) const {
return used_rows_.intersects(other.used_rows_);
}

// TODO: delete this, if not used.
/*
row_selector& operator|=(const row_selector& other) {
Expand Down

0 comments on commit ce7f055

Please sign in to comment.