Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into om4
Browse files Browse the repository at this point in the history
  • Loading branch information
TrevorHansen committed Feb 7, 2024
2 parents bc0619a + 2a38817 commit b2b6e2e
Show file tree
Hide file tree
Showing 57 changed files with 322 additions and 93 deletions.
80 changes: 0 additions & 80 deletions src/extract/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,6 @@ pub mod ilp_cbc;
// Allowance for floating point values to be considered equal
pub const EPSILON_ALLOWANCE: f64 = 0.00001;

// I want this to write to a tempfs file system, you'll
// want to change the path in test_save_path to something
// that works for you.
pub const ELABORATE_TESTING: bool = false;

pub fn test_save_path(name: &str) -> String {
return if ELABORATE_TESTING {
format!("/dev/shm/{}_egraph.json", name)
} else {
"".to_string()
};
}

pub trait Extractor: Sync {
fn extract(&self, egraph: &EGraph, roots: &[ClassId]) -> ExtractionResult;

Expand Down Expand Up @@ -216,70 +203,3 @@ impl ExtractionResult {
.sum::<Cost>()
}
}

use ordered_float::NotNan;
use rand::Rng;

// generates a float between 0 and 1
fn generate_random_not_nan() -> NotNan<f64> {
let mut rng: rand::prelude::ThreadRng = rand::thread_rng();
let random_float: f64 = rng.gen();
NotNan::new(random_float).unwrap()
}

//make a random egraph
pub fn generate_random_egraph() -> EGraph {
let mut rng = rand::thread_rng();
let mut egraph = EGraph::default();
let mut nodes = Vec::<Node>::default();
let mut eclass = 0;

let mut n2nid = IndexMap::<Node, NodeId>::default();
let mut count = 0;

for _ in 0..rng.gen_range(1..100) {
let mut children = Vec::<NodeId>::default();
for node in &nodes {
if rng.gen_bool(0.1) {
children.push(n2nid.get(node).unwrap().clone());
}
}

if rng.gen_bool(0.2) {
eclass += 1;
}

let node = Node {
op: "operation".to_string(),
children: children,
eclass: eclass.to_string().clone().into(),
cost: (generate_random_not_nan() * 100.0),
};

nodes.push(node.clone());
let id = "node_".to_owned() + &count.to_string();
count += 1;
egraph.add_node(id.clone(), node.clone());
n2nid.insert(node.clone(), id.clone().into());
}

//I've not seen this generate an infeasible egraph, and don't undertand why.
let len = nodes.len();
for n in &mut nodes {
if rng.gen_bool(0.5) {
n.children.push(n2nid[rng.gen_range(0..len)].clone());
}
}

// Get roots, potentially duplicate.
for _ in 1..rng.gen_range(2..11) {
egraph.root_eclasses.push(
nodes
.get(rng.gen_range(0..nodes.len()))
.unwrap()
.eclass
.clone(),
);
}
egraph
}
93 changes: 80 additions & 13 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,36 +15,100 @@ use std::path::PathBuf;
pub type Cost = NotNan<f64>;
pub const INFINITY: Cost = unsafe { NotNan::new_unchecked(std::f64::INFINITY) };

fn main() {
env_logger::init();
#[derive(PartialEq, Eq)]
enum Optimal {
Tree,
DAG,
Neither,
}

struct ExtractorDetail {
extractor: Box<dyn Extractor>,
optimal: Optimal,
use_for_bench: bool,
}

let extractors: IndexMap<&str, Box<dyn Extractor>> = [
("bottom-up", extract::bottom_up::BottomUpExtractor.boxed()),
fn extractors() -> IndexMap<&'static str, ExtractorDetail> {
let extractors: IndexMap<&'static str, ExtractorDetail> = [
(
"faster-bottom-up",
extract::faster_bottom_up::FasterBottomUpExtractor.boxed(),
"bottom-up",
ExtractorDetail {
extractor: extract::bottom_up::BottomUpExtractor.boxed(),
optimal: Optimal::Tree,
use_for_bench: true,
},
),
(
"greedy-dag",
extract::greedy_dag::GreedyDagExtractor.boxed(),
"faster-bottom-up",
ExtractorDetail {
extractor: extract::faster_bottom_up::FasterBottomUpExtractor.boxed(),
optimal: Optimal::Tree,
use_for_bench: true,
},
),
(
"faster-greedy-dag",
extract::faster_greedy_dag::FasterGreedyDagExtractor.boxed(),
ExtractorDetail {
extractor: extract::faster_greedy_dag::FasterGreedyDagExtractor.boxed(),
optimal: Optimal::Neither,
use_for_bench: true,
},
),
/*(
"global-greedy-dag",
ExtractorDetail {
extractor: extract::global_greedy_dag::GlobalGreedyDagExtractor.boxed(),
optimal: Optimal::Neither,
use_for_bench: true,
},
),*/
#[cfg(feature = "ilp-cbc")]
(
"ilp-cbc-timeout",
extract::ilp_cbc::CbcExtractorWithTimeout::<10>.boxed(),
ExtractorDetail {
extractor: extract::ilp_cbc::CbcExtractorWithTimeout::<10>.boxed(),
optimal: Optimal::DAG,
use_for_bench: true,
},
),
#[cfg(feature = "ilp-cbc")]
(
"ilp-cbc",
ExtractorDetail {
extractor: extract::ilp_cbc::CbcExtractor.boxed(),
optimal: Optimal::DAG,
use_for_bench: false, // takes >10 hours sometimes
},
),
#[cfg(feature = "ilp-cbc")]
(
"faster-ilp-cbc-timeout",
extract::faster_ilp_cbc::FasterCbcExtractorWithTimeout::<10>.boxed(),
ExtractorDetail {
extractor: extract::faster_ilp_cbc::FasterCbcExtractorWithTimeout::<10>.boxed(),
optimal: Optimal::DAG,
use_for_bench: true,
},
),
#[cfg(feature = "ilp-cbc")]
(
"faster-ilp-cbc",
ExtractorDetail {
extractor: extract::faster_ilp_cbc::FasterCbcExtractor::<10>.boxed(),
optimal: Optimal::DAG,
use_for_bench: true,
},
),
]
.into_iter()
.collect();
return extractors;
}

fn main() {
env_logger::init();

let mut extractors = extractors();
extractors.retain(|_, ed| ed.use_for_bench);

let mut args = pico_args::Arguments::from_env();

Expand Down Expand Up @@ -77,13 +141,13 @@ fn main() {
.with_context(|| format!("Failed to parse {filename}"))
.unwrap();

let extractor = extractors
let ed = extractors
.get(extractor_name.as_str())
.with_context(|| format!("Unknown extractor: {extractor_name}"))
.unwrap();

let start_time = std::time::Instant::now();
let result = extractor.extract(&egraph, &egraph.root_eclasses);
let result = ed.extractor.extract(&egraph, &egraph.root_eclasses);
let us = start_time.elapsed().as_micros();

result.check(&egraph);
Expand All @@ -104,3 +168,6 @@ fn main() {
)
.unwrap();
}

#[cfg(test)]
mod test;
Loading

0 comments on commit b2b6e2e

Please sign in to comment.