diff --git a/src/extract/greedy_dag_1.rs b/src/extract/greedy_dag_1.rs index 254d51a..a328e43 100644 --- a/src/extract/greedy_dag_1.rs +++ b/src/extract/greedy_dag_1.rs @@ -17,49 +17,71 @@ impl FasterGreedyDagExtractor { egraph: &EGraph, node_id: NodeId, costs: &HashMap, + best_cost: Cost, ) -> CostSet { let node = &egraph[&node_id]; - let cid = egraph.nid_to_cid(&node_id); + // No children -> easy. + if node.children.is_empty() { + return CostSet { + costs: std::collections::HashMap::default(), + total: node.cost, + choice: node_id.clone(), + }; + } - let mut desc = 0; - let mut children_cost = Cost::default(); - for child in &node.children { - let child_cid = egraph.nid_to_cid(child); - let cs = costs.get(child_cid).unwrap(); - desc += cs.costs.len(); - children_cost += cs.total; + // Get unique classes of children. + let mut childrens_classes = node + .children + .iter() + .map(|c| egraph.nid_to_cid(&c).clone()) + .collect::>(); + childrens_classes.sort(); + childrens_classes.dedup(); + + let first_cost = costs.get(&childrens_classes[0]).unwrap(); + + if childrens_classes.len() == 1 && (node.cost + first_cost.total > best_cost) { + // Shortcut. Can't be cheaper so return junk. + return CostSet { + costs: std::collections::HashMap::default(), + total: INFINITY, + choice: node_id.clone(), + }; } - let mut cost_set = CostSet { - costs: std::collections::HashMap::with_capacity(desc), - total: Cost::default(), - choice: node_id.clone(), - }; + // Clone the biggest set and insert the others into it. + let id_of_biggest = childrens_classes + .iter() + .max_by_key(|s| costs.get(s).unwrap().costs.len()) + .unwrap(); + let mut result = costs.get(&id_of_biggest).unwrap().costs.clone(); + for child_cid in &childrens_classes { + if child_cid == id_of_biggest { + continue; + } - for child in &node.children { - let child_cid = egraph.nid_to_cid(child); - cost_set - .costs - .extend(costs.get(child_cid).unwrap().costs.clone()); + let next_cost = &costs.get(child_cid).unwrap().costs; + for (key, value) in next_cost.iter() { + result.insert(key.clone(), value.clone()); + } } - let contains = cost_set.costs.contains_key(&cid.clone()); - cost_set.costs.insert(cid.clone(), node.cost); // this node. + let cid = egraph.nid_to_cid(&node_id); + let contains = result.contains_key(&cid); + result.insert(cid.clone(), node.cost); - if contains { - cost_set.total = INFINITY; + let result_cost = if contains { + INFINITY } else { - if cost_set.costs.len() == desc + 1 { - // No extra duplicates are found, so the cost is the current - // nodes cost + the children's cost. - cost_set.total = children_cost + node.cost; - } else { - cost_set.total = cost_set.costs.values().sum(); - } + result.values().sum() }; - cost_set + return CostSet { + costs: result, + total: result_cost, + choice: node_id.clone(), + }; } } @@ -67,7 +89,7 @@ impl FasterGreedyDagExtractor { fn check(egraph: &EGraph, node_id: NodeId, costs: &HashMap) { let cid = egraph.nid_to_cid(&node_id); let previous = costs.get(cid).unwrap().total; - let cs = Self::calculate_cost_set(egraph, node_id, costs); + let cs = Self::calculate_cost_set(egraph, node_id, costs, INFINITY); println!("{} {}", cs.total, previous); assert!(cs.total >= previous); } @@ -114,7 +136,7 @@ impl Extractor for FasterGreedyDagExtractor { prev_cost = lookup.unwrap().total; } - let cost_set = Self::calculate_cost_set(egraph, node_id.clone(), &costs); + let cost_set = Self::calculate_cost_set(egraph, node_id.clone(), &costs, prev_cost); if cost_set.total < prev_cost { costs.insert(class_id.clone(), cost_set); analysis_pending.extend(parents[class_id].iter().cloned());