Skip to content

Commit

Permalink
Add a simple greedy dag
Browse files Browse the repository at this point in the history
  • Loading branch information
mwillsey committed Jul 8, 2023
1 parent 5902665 commit 95e8a74
Show file tree
Hide file tree
Showing 7 changed files with 160 additions and 3 deletions.
55 changes: 55 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ pico-args = {version = "0.5.0", features = ["eq-separator"]}

anyhow = "1.0.71"
coin_cbc = {version = "0.1.6", optional = true}
im-rc = "15.1.0"

[dependencies.egraph-serialize]
git = "https://github.com/egraphs-good/egraph-serialize"
Expand Down
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ all: test nits bench
define run-extraction
TARGETS += $(1:data/%=output/%)-$(2).json
$(1:data/%=output/%)-$(2).json: $(1)
mkdir -p $$(dir $$@)
@mkdir -p $$(dir $$@)
$(PROGRAM) $$< --extractor=$(2) --out=$$@
endef

Expand All @@ -28,7 +28,7 @@ $(foreach ext,$(EXTRACTORS),\

.PHONY: bench
bench: plot.py $(TARGETS)
./$<
./$^

$(PROGRAM): $(SRC)
cargo build $(FLAGS)
Expand Down
16 changes: 15 additions & 1 deletion plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,29 @@ def process(js, extractors=[]):
print(f"Error processing {name}")
raise e

print(f"{e1} / {e2}")

print("geo mean")
tree_summary = statistics.geometric_mean(s["tree"] for s in summaries.values())
dag_summary = statistics.geometric_mean(s["dag"] for s in summaries.values())
micros_summary = statistics.geometric_mean(s["micros"] for s in summaries.values())

print(f"{e1} / {e2}")
print(f"tree: {tree_summary:.4f}")
print(f"dag: {dag_summary:.4f}")
print(f"micros: {micros_summary:.4f}")

print("quantiles")

def quantiles(key):
xs = [s[key] for s in summaries.values()]
qs = statistics.quantiles(xs, n=4)
with_extremes = [min(xs)] + qs + [max(xs)]
return ", ".join(f"{x:.4f}" for x in with_extremes)

print(f"tree: {quantiles('tree')}")
print(f"dag: {quantiles('dag')}")
print(f"micros: {quantiles('micros')}")


if __name__ == "__main__":
print()
Expand Down
79 changes: 79 additions & 0 deletions src/extract/greedy_dag.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
use super::*;

struct CostSet {
costs: im_rc::HashMap<ClassId, Cost>,
total: Cost,
choice: NodeId,
}

pub struct GreedyDagExtractor;
impl Extractor for GreedyDagExtractor {
fn extract(&self, egraph: &EGraph, _roots: &[ClassId]) -> ExtractionResult {
let mut costs = IndexMap::<ClassId, CostSet>::default();
let mut keep_going = true;

let mut nodes = egraph.nodes.clone();

let mut i = 0;
while keep_going {
i += 1;
println!("iteration {}", i);
keep_going = false;

let mut to_remove = vec![];

'node_loop: for (node_id, node) in &nodes {
let cid = egraph.nid_to_cid(node_id);
let mut cost_set = CostSet {
costs: im_rc::HashMap::new(),
total: Cost::default(),
choice: node_id.clone(),
};

// compute the cost set from the children
for child in &node.children {
let child_cid = egraph.nid_to_cid(child);
if let Some(child_cost_set) = costs.get(child_cid) {
// prevent a cycle
if child_cost_set.costs.contains_key(cid) {
continue 'node_loop;
}
cost_set.costs.extend(child_cost_set.costs.clone());
} else {
continue 'node_loop;
}
}

// add this node
cost_set.costs.insert(cid.clone(), node.cost);

cost_set.total = cost_set.costs.values().sum();

// if the cost set is better than the current one, update it
if let Some(old_cost_set) = costs.get(cid) {
if cost_set.total < old_cost_set.total {
costs.insert(cid.clone(), cost_set);
keep_going = true;
}
} else {
costs.insert(cid.clone(), cost_set);
keep_going = true;
}
to_remove.push(node_id.clone());
}

// removing nodes you've "done" can speed it up a lot but makes the results much worse
if false {
for node_id in to_remove {
nodes.remove(&node_id);
}
}
}

let mut result = ExtractionResult::default();
for (cid, cost_set) in costs {
result.choose(cid, cost_set.choice);
}
result
}
}
1 change: 1 addition & 0 deletions src/extract/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::collections::HashMap;
pub use crate::*;

pub mod bottom_up;
pub mod greedy_dag;

#[cfg(feature = "ilp-cbc")]
pub mod ilp_cbc;
Expand Down
7 changes: 7 additions & 0 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ fn main() {

let extractors: IndexMap<&str, Box<dyn Extractor>> = [
("bottom-up", extract::bottom_up::BottomUpExtractor.boxed()),
(
"greedy-dag",
extract::greedy_dag::GreedyDagExtractor.boxed(),
),
#[cfg(feature = "ilp-cbc")]
("ilp-cbc", extract::ilp_cbc::CbcExtractor.boxed()),
]
Expand Down Expand Up @@ -66,6 +70,9 @@ fn main() {
let result = extractor.extract(&egraph, &egraph.root_eclasses);

let us = start_time.elapsed().as_micros();
assert!(result
.find_cycles(&egraph, &egraph.root_eclasses)
.is_empty());
let tree = result.tree_cost(&egraph, &egraph.root_eclasses);
let dag = result.dag_cost(&egraph, &egraph.root_eclasses);

Expand Down

0 comments on commit 95e8a74

Please sign in to comment.