Skip to content

Commit

Permalink
faster greedy dag extraction.
Browse files Browse the repository at this point in the history
  • Loading branch information
TrevorHansen committed Sep 12, 2023
1 parent 95e8a74 commit 4dcc3bc
Show file tree
Hide file tree
Showing 5 changed files with 221 additions and 5 deletions.
13 changes: 13 additions & 0 deletions plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,23 @@ def process(js, extractors=[]):
assert len(extractors) == 2
e1, e2 = extractors

e1_cummulative=0
e2_cummulative=0

summaries = {}

for name, d in by_name.items():
try:
if d[e1]["tree"] != d[e2]["tree"]:
print(name, d[e1]["tree"], d[e2]["tree"]);

tree_ratio = d[e1]["tree"] / d[e2]["tree"]
dag_ratio = d[e1]["dag"] / d[e2]["dag"]
micros_ratio = max(1, d[e1]["micros"]) / max(1, d[e2]["micros"])

e1_cummulative += d[e1]["micros"];
e2_cummulative += d[e2]["micros"];

summaries[name] = {
"tree": tree_ratio,
"dag": dag_ratio,
Expand All @@ -47,6 +57,9 @@ def process(js, extractors=[]):
print(f"Error processing {name}")
raise e

print(f"Cummulative time for {e1}: {e1_cummulative/1000:.0f}ms")
print(f"Cummulative time for {e2}: {e2_cummulative/1000:.0f}ms")

print(f"{e1} / {e2}")

print("geo mean")
Expand Down
4 changes: 2 additions & 2 deletions src/extract/greedy_dag.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use super::*;

struct CostSet {
costs: im_rc::HashMap<ClassId, Cost>,
costs: std::collections::HashMap<ClassId, Cost>,
total: Cost,
choice: NodeId,
}
Expand All @@ -25,7 +25,7 @@ impl Extractor for GreedyDagExtractor {
'node_loop: for (node_id, node) in &nodes {
let cid = egraph.nid_to_cid(node_id);
let mut cost_set = CostSet {
costs: im_rc::HashMap::new(),
costs: std::collections::HashMap::new(),
total: Cost::default(),
choice: node_id.clone(),
};
Expand Down
201 changes: 201 additions & 0 deletions src/extract/greedy_dag_1.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
// Calculates the cost where shared nodes are just costed once,
// For example (+ (* x x ) (* x x )) has one mulitplication
// included in the cost.

use super::*;

struct CostSet {
costs: std::collections::HashMap<ClassId, Cost>,
total: Cost,
choice: NodeId,
}

pub struct FasterGreedyDagExtractor;

impl FasterGreedyDagExtractor {
fn calculate_cost_set(
egraph: &EGraph,
node_id: NodeId,
costs: &HashMap<ClassId, CostSet>,
) -> CostSet {
let node = &egraph[&node_id];

let cid = egraph.nid_to_cid(&node_id);

let mut desc = 0;
let mut children_cost = Cost::default();
for child in &node.children {
let child_cid = egraph.nid_to_cid(child);
let cs = costs.get(child_cid).unwrap();
desc += cs.costs.len();
children_cost += cs.total;
}

let mut cost_set = CostSet {
costs: std::collections::HashMap::with_capacity(desc),
total: Cost::default(),
choice: node_id.clone(),
};

for child in &node.children {
let child_cid = egraph.nid_to_cid(child);
cost_set
.costs
.extend(costs.get(child_cid).unwrap().costs.clone());
}

let contains = cost_set.costs.contains_key(&cid.clone());
cost_set.costs.insert(cid.clone(), node.cost); // this node.

if contains {
cost_set.total = INFINITY;
} else {
if cost_set.costs.len() == desc + 1 {
// No extra duplicates are found, so the cost is the current
// nodes cost + the children's cost.
cost_set.total = children_cost + node.cost;
} else {
cost_set.total = cost_set.costs.values().sum();
}
};

cost_set
}
}

impl FasterGreedyDagExtractor {
fn check(egraph: &EGraph, node_id: NodeId, costs: &HashMap<ClassId, CostSet>) {
let cid = egraph.nid_to_cid(&node_id);
let previous = costs.get(cid).unwrap().total;
let cs = Self::calculate_cost_set(egraph, node_id, costs);
println!("{} {}", cs.total, previous);
assert!(cs.total >= previous);
}
}

impl Extractor for FasterGreedyDagExtractor {
fn extract(&self, egraph: &EGraph, _roots: &[ClassId]) -> ExtractionResult {
// 1. build map from class to parent nodes
let mut parents = IndexMap::<ClassId, Vec<NodeId>>::default();
let n2c = |nid: &NodeId| egraph.nid_to_cid(nid);

for class in egraph.classes().values() {
parents.insert(class.id.clone(), Vec::new());
}
for class in egraph.classes().values() {
for node in &class.nodes {
for c in &egraph[node].children {
parents[n2c(c)].push(node.clone());
}
}
}

// 2. start analysis from leaves
let mut analysis_pending = UniqueQueue::default();

for class in egraph.classes().values() {
for node in &class.nodes {
if egraph[node].is_leaf() {
analysis_pending.insert(node.clone());
}
}
}

// 3. analyse from leaves towards parents until fixpoint
let mut costs = HashMap::<ClassId, CostSet>::default();

while let Some(node_id) = analysis_pending.pop() {
let class_id = n2c(&node_id);
let node = &egraph[&node_id];
if node.children.iter().all(|c| costs.contains_key(n2c(c))) {
let lookup = costs.get(class_id);
let mut prev_cost = INFINITY;
if lookup.is_some() {
prev_cost = lookup.unwrap().total;
}

let cost_set = Self::calculate_cost_set(egraph, node_id.clone(), &costs);
if cost_set.total < prev_cost {
costs.insert(class_id.clone(), cost_set);
analysis_pending.extend(parents[class_id].iter().cloned());
}
} else {
analysis_pending.insert(node_id.clone());
}
}

/*
for class in egraph.classes().values() {
for node in &class.nodes {
Self::check(&egraph, node.clone(), &costs);
}
}
*/

let mut result = ExtractionResult::default();
for (cid, cost_set) in costs {
result.choose(cid, cost_set.choice);
}

result
}
}

/** A data structure to maintain a queue of unique elements.
Notably, insert/pop operations have O(1) expected amortized runtime complexity.
*/
#[derive(Clone)]
#[cfg_attr(feature = "serde-1", derive(Serialize, Deserialize))]
pub(crate) struct UniqueQueue<T>
where
T: Eq + std::hash::Hash + Clone,
{
set: std::collections::HashSet<T>, // hashbrown::
queue: std::collections::VecDeque<T>,
}

impl<T> Default for UniqueQueue<T>
where
T: Eq + std::hash::Hash + Clone,
{
fn default() -> Self {
UniqueQueue {
set: std::collections::HashSet::default(),
queue: std::collections::VecDeque::new(),
}
}
}

impl<T> UniqueQueue<T>
where
T: Eq + std::hash::Hash + Clone,
{
pub fn insert(&mut self, t: T) {
if self.set.insert(t.clone()) {
self.queue.push_back(t);
}
}

pub fn extend<I>(&mut self, iter: I)
where
I: IntoIterator<Item = T>,
{
for t in iter.into_iter() {
self.insert(t);
}
}

pub fn pop(&mut self) -> Option<T> {
let res = self.queue.pop_front();
res.as_ref().map(|t| self.set.remove(t));
res
}

#[allow(dead_code)]
pub fn is_empty(&self) -> bool {
let r = self.queue.is_empty();
debug_assert_eq!(r, self.set.is_empty());
r
}
}
1 change: 1 addition & 0 deletions src/extract/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ pub use crate::*;

pub mod bottom_up;
pub mod greedy_dag;
pub mod greedy_dag_1;

#[cfg(feature = "ilp-cbc")]
pub mod ilp_cbc;
Expand Down
7 changes: 4 additions & 3 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@ fn main() {
env_logger::init();

let extractors: IndexMap<&str, Box<dyn Extractor>> = [
("bottom-up", extract::bottom_up::BottomUpExtractor.boxed()),
(
"greedy-dag",
extract::greedy_dag::GreedyDagExtractor.boxed(),
),
#[cfg(feature = "ilp-cbc")]
("ilp-cbc", extract::ilp_cbc::CbcExtractor.boxed()),
(
"faster-greedy-dag",
extract::greedy_dag_1::FasterGreedyDagExtractor.boxed(),
)
]
.into_iter()
.collect();
Expand Down

0 comments on commit 4dcc3bc

Please sign in to comment.