Skip to content

Commit

Permalink
Faster greedy sharing-aware extractor
Browse files Browse the repository at this point in the history
  • Loading branch information
TrevorHansen committed Sep 13, 2023
1 parent ee68161 commit d6e4e48
Showing 1 changed file with 54 additions and 32 deletions.
86 changes: 54 additions & 32 deletions src/extract/greedy_dag_1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,57 +17,79 @@ impl FasterGreedyDagExtractor {
egraph: &EGraph,
node_id: NodeId,
costs: &HashMap<ClassId, CostSet>,
best_cost: Cost,
) -> CostSet {
let node = &egraph[&node_id];

let cid = egraph.nid_to_cid(&node_id);
// No children -> easy.
if node.children.is_empty() {
return CostSet {
costs: std::collections::HashMap::default(),
total: node.cost,
choice: node_id.clone(),
};
}

let mut desc = 0;
let mut children_cost = Cost::default();
for child in &node.children {
let child_cid = egraph.nid_to_cid(child);
let cs = costs.get(child_cid).unwrap();
desc += cs.costs.len();
children_cost += cs.total;
// Get unique classes of children.
let mut childrens_classes = node
.children
.iter()
.map(|c| egraph.nid_to_cid(&c).clone())
.collect::<Vec<ClassId>>();
childrens_classes.sort();
childrens_classes.dedup();

let first_cost = costs.get(&childrens_classes[0]).unwrap();

if childrens_classes.len() == 1 && (node.cost + first_cost.total > best_cost) {
// Shortcut. Can't be cheaper so return junk.
return CostSet {
costs: std::collections::HashMap::default(),
total: INFINITY,
choice: node_id.clone(),
};
}

let mut cost_set = CostSet {
costs: std::collections::HashMap::with_capacity(desc),
total: Cost::default(),
choice: node_id.clone(),
};
// Clone the biggest set and insert the others into it.
let id_of_biggest = childrens_classes
.iter()
.max_by_key(|s| costs.get(s).unwrap().costs.len())
.unwrap();
let mut result = costs.get(&id_of_biggest).unwrap().costs.clone();
for child_cid in &childrens_classes {
if child_cid == id_of_biggest {
continue;
}

for child in &node.children {
let child_cid = egraph.nid_to_cid(child);
cost_set
.costs
.extend(costs.get(child_cid).unwrap().costs.clone());
let next_cost = &costs.get(child_cid).unwrap().costs;
for (key, value) in next_cost.iter() {
result.insert(key.clone(), value.clone());
}
}

let contains = cost_set.costs.contains_key(&cid.clone());
cost_set.costs.insert(cid.clone(), node.cost); // this node.
let cid = egraph.nid_to_cid(&node_id);
let contains = result.contains_key(&cid);
result.insert(cid.clone(), node.cost);

if contains {
cost_set.total = INFINITY;
let result_cost = if contains {
INFINITY
} else {
if cost_set.costs.len() == desc + 1 {
// No extra duplicates are found, so the cost is the current
// nodes cost + the children's cost.
cost_set.total = children_cost + node.cost;
} else {
cost_set.total = cost_set.costs.values().sum();
}
result.values().sum()
};

cost_set
return CostSet {
costs: result,
total: result_cost,
choice: node_id.clone(),
};
}
}

impl FasterGreedyDagExtractor {
fn check(egraph: &EGraph, node_id: NodeId, costs: &HashMap<ClassId, CostSet>) {
let cid = egraph.nid_to_cid(&node_id);
let previous = costs.get(cid).unwrap().total;
let cs = Self::calculate_cost_set(egraph, node_id, costs);
let cs = Self::calculate_cost_set(egraph, node_id, costs, INFINITY);
println!("{} {}", cs.total, previous);
assert!(cs.total >= previous);
}
Expand Down Expand Up @@ -114,7 +136,7 @@ impl Extractor for FasterGreedyDagExtractor {
prev_cost = lookup.unwrap().total;
}

let cost_set = Self::calculate_cost_set(egraph, node_id.clone(), &costs);
let cost_set = Self::calculate_cost_set(egraph, node_id.clone(), &costs, prev_cost);
if cost_set.total < prev_cost {
costs.insert(class_id.clone(), cost_set);
analysis_pending.extend(parents[class_id].iter().cloned());
Expand Down

0 comments on commit d6e4e48

Please sign in to comment.