diff --git a/Cargo.lock b/Cargo.lock index c75ed60..3c43c60 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -14,6 +14,15 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" +[[package]] +name = "bitmaps" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "031043d04099746d8db04daf1fa424b2bc8bd69d92b25962dcde24da39ab64a2" +dependencies = [ + "typenum", +] + [[package]] name = "coin_cbc" version = "0.1.7" @@ -68,6 +77,7 @@ dependencies = [ "coin_cbc", "egraph-serialize", "env_logger", + "im-rc", "indexmap", "log", "ordered-float", @@ -80,6 +90,20 @@ version = "0.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2c6201b9ff9fd90a5a3bac2e56a830d0caa509576f0e503818ee82c181b3437a" +[[package]] +name = "im-rc" +version = "15.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af1955a75fa080c677d3972822ec4bad316169ab1cfc6c257a942c2265dbe5fe" +dependencies = [ + "bitmaps", + "rand_core", + "rand_xoshiro", + "sized-chunks", + "typenum", + "version_check", +] + [[package]] name = "indexmap" version = "2.0.0" @@ -184,6 +208,15 @@ dependencies = [ "serde", ] +[[package]] +name = "rand_xoshiro" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f97cdb2a36ed4183de61b2f824cc45c9f1037f28afe0a322e9fff4c108b5aaa" +dependencies = [ + "rand_core", +] + [[package]] name = "ryu" version = "1.0.14" @@ -222,6 +255,16 @@ dependencies = [ "serde", ] +[[package]] +name = "sized-chunks" +version = "0.6.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "16d69225bde7a69b235da73377861095455d298f2b970996eec25ddbb42b3d1e" +dependencies = [ + "bitmaps", + "typenum", +] + [[package]] name = "syn" version = "2.0.23" @@ -233,8 +276,20 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "typenum" +version = "1.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "497961ef93d974e23eb6f433eb5fe1b7930b659f06d12dec6fc44a8f554c0bba" + [[package]] name = "unicode-ident" version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "22049a19f4a68748a168c0fc439f9516686aa045927ff767eca0a85101fb6e73" + +[[package]] +name = "version_check" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" diff --git a/Cargo.toml b/Cargo.toml index 3f5a3cf..b3c5aa2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,7 @@ pico-args = {version = "0.5.0", features = ["eq-separator"]} anyhow = "1.0.71" coin_cbc = {version = "0.1.6", optional = true} +im-rc = "15.1.0" [dependencies.egraph-serialize] git = "https://github.com/egraphs-good/egraph-serialize" diff --git a/Makefile b/Makefile index 1eb5fd0..744e6e5 100644 --- a/Makefile +++ b/Makefile @@ -16,7 +16,7 @@ all: test nits bench define run-extraction TARGETS += $(1:data/%=output/%)-$(2).json $(1:data/%=output/%)-$(2).json: $(1) - mkdir -p $$(dir $$@) + @mkdir -p $$(dir $$@) $(PROGRAM) $$< --extractor=$(2) --out=$$@ endef @@ -28,7 +28,7 @@ $(foreach ext,$(EXTRACTORS),\ .PHONY: bench bench: plot.py $(TARGETS) - ./$< + ./$^ $(PROGRAM): $(SRC) cargo build $(FLAGS) diff --git a/plot.py b/plot.py index 6c52093..4cfa48f 100755 --- a/plot.py +++ b/plot.py @@ -47,15 +47,29 @@ def process(js, extractors=[]): print(f"Error processing {name}") raise e + print(f"{e1} / {e2}") + + print("geo mean") tree_summary = statistics.geometric_mean(s["tree"] for s in summaries.values()) dag_summary = statistics.geometric_mean(s["dag"] for s in summaries.values()) micros_summary = statistics.geometric_mean(s["micros"] for s in summaries.values()) - print(f"{e1} / {e2}") print(f"tree: {tree_summary:.4f}") print(f"dag: {dag_summary:.4f}") print(f"micros: {micros_summary:.4f}") + print("quantiles") + + def quantiles(key): + xs = [s[key] for s in summaries.values()] + qs = statistics.quantiles(xs, n=4) + with_extremes = [min(xs)] + qs + [max(xs)] + return ", ".join(f"{x:.4f}" for x in with_extremes) + + print(f"tree: {quantiles('tree')}") + print(f"dag: {quantiles('dag')}") + print(f"micros: {quantiles('micros')}") + if __name__ == "__main__": print() diff --git a/src/extract/greedy_dag.rs b/src/extract/greedy_dag.rs new file mode 100644 index 0000000..6f62b03 --- /dev/null +++ b/src/extract/greedy_dag.rs @@ -0,0 +1,79 @@ +use super::*; + +struct CostSet { + costs: im_rc::HashMap, + total: Cost, + choice: NodeId, +} + +pub struct GreedyDagExtractor; +impl Extractor for GreedyDagExtractor { + fn extract(&self, egraph: &EGraph, _roots: &[ClassId]) -> ExtractionResult { + let mut costs = IndexMap::::default(); + let mut keep_going = true; + + let mut nodes = egraph.nodes.clone(); + + let mut i = 0; + while keep_going { + i += 1; + println!("iteration {}", i); + keep_going = false; + + let mut to_remove = vec![]; + + 'node_loop: for (node_id, node) in &nodes { + let cid = egraph.nid_to_cid(node_id); + let mut cost_set = CostSet { + costs: im_rc::HashMap::new(), + total: Cost::default(), + choice: node_id.clone(), + }; + + // compute the cost set from the children + for child in &node.children { + let child_cid = egraph.nid_to_cid(child); + if let Some(child_cost_set) = costs.get(child_cid) { + // prevent a cycle + if child_cost_set.costs.contains_key(cid) { + continue 'node_loop; + } + cost_set.costs.extend(child_cost_set.costs.clone()); + } else { + continue 'node_loop; + } + } + + // add this node + cost_set.costs.insert(cid.clone(), node.cost); + + cost_set.total = cost_set.costs.values().sum(); + + // if the cost set is better than the current one, update it + if let Some(old_cost_set) = costs.get(cid) { + if cost_set.total < old_cost_set.total { + costs.insert(cid.clone(), cost_set); + keep_going = true; + } + } else { + costs.insert(cid.clone(), cost_set); + keep_going = true; + } + to_remove.push(node_id.clone()); + } + + // removing nodes you've "done" can speed it up a lot but makes the results much worse + if false { + for node_id in to_remove { + nodes.remove(&node_id); + } + } + } + + let mut result = ExtractionResult::default(); + for (cid, cost_set) in costs { + result.choose(cid, cost_set.choice); + } + result + } +} diff --git a/src/extract/mod.rs b/src/extract/mod.rs index f273f9a..8cb01a0 100644 --- a/src/extract/mod.rs +++ b/src/extract/mod.rs @@ -3,6 +3,7 @@ use std::collections::HashMap; pub use crate::*; pub mod bottom_up; +pub mod greedy_dag; #[cfg(feature = "ilp-cbc")] pub mod ilp_cbc; diff --git a/src/main.rs b/src/main.rs index 5bf3622..d2a488d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -20,6 +20,10 @@ fn main() { let extractors: IndexMap<&str, Box> = [ ("bottom-up", extract::bottom_up::BottomUpExtractor.boxed()), + ( + "greedy-dag", + extract::greedy_dag::GreedyDagExtractor.boxed(), + ), #[cfg(feature = "ilp-cbc")] ("ilp-cbc", extract::ilp_cbc::CbcExtractor.boxed()), ] @@ -66,6 +70,9 @@ fn main() { let result = extractor.extract(&egraph, &egraph.root_eclasses); let us = start_time.elapsed().as_micros(); + assert!(result + .find_cycles(&egraph, &egraph.root_eclasses) + .is_empty()); let tree = result.tree_cost(&egraph, &egraph.root_eclasses); let dag = result.dag_cost(&egraph, &egraph.root_eclasses);