Skip to content

Commit

Permalink
Support custom filtering during graph generation
Browse files Browse the repository at this point in the history
- Listen to keyboard interrupts during graph generation called from Python
- Improve AArch64 ASM by preventing memory reads
  • Loading branch information
benruijl committed Feb 7, 2025
1 parent 728ee05 commit e74596f
Show file tree
Hide file tree
Showing 4 changed files with 298 additions and 105 deletions.
72 changes: 60 additions & 12 deletions src/api/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ use crate::{
CompileOptions, CompiledEvaluator, EvaluationFn, ExpressionEvaluator, FunctionMap,
InlineASM, OptimizationSettings,
},
graph::Graph,
graph::{GenerationSettings, Graph},
id::{
Condition, ConditionResult, Evaluate, Match, MatchSettings, MatchStack, Pattern,
PatternAtomTreeIterator, PatternOrMap, PatternRestriction, Relation, ReplaceIterator,
Expand Down Expand Up @@ -11513,7 +11513,8 @@ impl PythonGraph {
/// of vertex connections.
///
/// Returns the canonical form of the graph and the size of its automorphism group (including edge permutations).
#[pyo3(signature = (external_edges, vertex_signatures, max_vertices = None, max_loops = None, max_bridges = None, allow_self_loops = None))]
#[pyo3(signature = (external_edges, vertex_signatures, max_vertices = None, max_loops = None,
max_bridges = None, allow_self_loops = None, filter_fn = None))]
#[classmethod]
fn generate(
_cls: &Bound<'_, PyType>,
Expand All @@ -11526,6 +11527,7 @@ impl PythonGraph {
max_loops: Option<usize>,
max_bridges: Option<usize>,
allow_self_loops: Option<bool>,
filter_fn: Option<PyObject>,
) -> PyResult<HashMap<PythonGraph, PythonExpression>> {
if max_vertices.is_none() && max_loops.is_none() {
return Err(exceptions::PyValueError::new_err(
Expand All @@ -11546,17 +11548,63 @@ impl PythonGraph {
})
.collect();

Ok(Graph::generate(
&external_edges,
&vertex_signatures,
max_vertices,
max_loops,
max_bridges,
allow_self_loops.unwrap_or(false),
let mut settings = GenerationSettings::new();
if let Some(max_vertices) = max_vertices {
settings = settings.max_vertices(max_vertices);
}

if let Some(max_loops) = max_loops {
settings = settings.max_loops(max_loops);
}

if let Some(max_bridges) = max_bridges {
settings = settings.max_loops(max_bridges);
}

if let Some(allow_self_loops) = allow_self_loops {
settings = settings.allow_self_loops(allow_self_loops);
}

let abort = Arc::new(std::sync::atomic::AtomicBool::new(false));

if let Some(filter_fn) = filter_fn {
let abort = abort.clone();
settings = settings.filter_fn(Box::new(move |g, v| {
Python::with_gil(|py| {
match filter_fn.call(py, (Self { graph: g.clone() }, v), None) {
Ok(r) => r
.extract::<bool>(py)
.expect("Match map does not return a boolean"),
Err(e) => {
if e.is_instance_of::<exceptions::PyKeyboardInterrupt>(py) {
abort.store(true, std::sync::atomic::Ordering::Relaxed);
false
} else {
panic!("Bad callback function: {}", e);
}
}
}
})
}));
}

settings = settings.abort_check(Box::new(move || {
if abort.load(std::sync::atomic::Ordering::Relaxed) {
true
} else {
Python::with_gil(|py| py.check_signals())
.map(|_| false)
.unwrap_or(true)
}
}));

Ok(
Graph::generate(&external_edges, &vertex_signatures, &settings)
.unwrap_or_else(|e| e)
.into_iter()
.map(|(k, v)| (Self { graph: k }, Atom::new_num(v).into()))
.collect(),
)
.into_iter()
.map(|(k, v)| (Self { graph: k }, Atom::new_num(v).into()))
.collect())
}

/// Convert the graph to a graphviz dot string.
Expand Down
71 changes: 34 additions & 37 deletions src/evaluate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1379,6 +1379,8 @@ impl<T: std::fmt::Display> ExpressionEvaluator<T> {
asm_flavour: InlineASM,
out: &mut String,
) -> bool {
let mut second_index = 0;

macro_rules! get_input {
($i:expr) => {
if $i < self.param_count {
Expand All @@ -1400,7 +1402,14 @@ impl<T: std::fmt::Display> ExpressionEvaluator<T> {
($i:expr) => {
match asm_flavour {
InlineASM::X64 => {
format_addr!($i)
if $i < self.param_count {
format!("{}(%2)", $i * 8)
} else if $i < self.reserved_indices {
format!("{}(%1)", ($i - self.param_count) * 8)
} else {
// TODO: subtract reserved indices
format!("{}(%0)", $i * 8)
}
}
InlineASM::AArch64 => {
if $i < self.param_count {
Expand All @@ -1412,6 +1421,7 @@ impl<T: std::fmt::Display> ExpressionEvaluator<T> {
let shift = d.min(12);
let coeff = dest / (1 << shift);
let rest = dest - (coeff << shift);
second_index = 0;
*out += &format!(
"\t\t\"add x8, %2, {}, lsl {}\\n\\t\"\n",
coeff, shift
Expand All @@ -1427,6 +1437,7 @@ impl<T: std::fmt::Display> ExpressionEvaluator<T> {
let shift = d.min(12);
let coeff = dest / (1 << shift);
let rest = dest - (coeff << shift);
second_index = 0;
*out += &format!(
"\t\t\"add x8, %1, {}, lsl {}\\n\\t\"\n",
coeff, shift
Expand All @@ -1438,18 +1449,23 @@ impl<T: std::fmt::Display> ExpressionEvaluator<T> {
} else {
// TODO: subtract reserved indices
let dest = $i * 8;
if dest > 32760 {
if dest > 32760 && (dest < second_index || dest > 32760 + second_index)
{
let d = dest.ilog2();
let shift = d.min(12);
let coeff = dest / (1 << shift);
let rest = dest - (coeff << shift);
second_index = coeff << shift;
let rest = dest - second_index;
*out += &format!(
"\t\t\"add x8, %0, {}, lsl {}\\n\\t\"\n",
coeff, shift
);
format!("[x8, {}]", rest)
} else {
} else if dest <= 32760 {
format!("[%0, {}]", dest)
} else {
let offset = dest - second_index;
format!("[x8, {}]", offset)
}
}
}
Expand All @@ -1458,34 +1474,6 @@ impl<T: std::fmt::Display> ExpressionEvaluator<T> {
};
}

macro_rules! format_addr {
($i:expr) => {
match asm_flavour {
InlineASM::X64 => {
if $i < self.param_count {
format!("{}(%2)", $i * 8)
} else if $i < self.reserved_indices {
format!("{}(%1)", ($i - self.param_count) * 8)
} else {
// TODO: subtract reserved indices
format!("{}(%0)", $i * 8)
}
}
InlineASM::AArch64 => {
if $i < self.param_count {
format!("[%2, {}]", $i * 8)
} else if $i < self.reserved_indices {
format!("[%1, {}]", ($i - self.param_count) * 8)
} else {
// TODO: subtract reserved indices
format!("[%0, {}]", $i * 8)
}
}
InlineASM::None => unreachable!(),
}
};
}

macro_rules! end_asm_block {
($in_block: expr) => {
if $in_block {
Expand Down Expand Up @@ -2283,6 +2271,8 @@ impl<T: std::fmt::Display> ExpressionEvaluator<T> {
asm_flavour: InlineASM,
out: &mut String,
) -> bool {
let mut second_index = 0;

macro_rules! get_input {
($i:expr) => {
if $i < self.param_count {
Expand All @@ -2305,15 +2295,15 @@ impl<T: std::fmt::Display> ExpressionEvaluator<T> {
match asm_flavour {
InlineASM::X64 => {
if $i < self.param_count {
(format!("{}(%2)", $i * 16), "NA".to_owned())
(format!("{}(%2)", $i * 16), String::new())
} else if $i < self.reserved_indices {
(
format!("{}(%1)", ($i - self.param_count) * 16),
"NA".to_owned(),
)
} else {
// TODO: subtract reserved indices
(format!("{}(%0)", $i * 16), "NA".to_owned())
(format!("{}(%0)", $i * 16), String::new())
}
}
InlineASM::AArch64 => {
Expand All @@ -2326,6 +2316,7 @@ impl<T: std::fmt::Display> ExpressionEvaluator<T> {
let shift = d.min(12);
let coeff = dest / (1 << shift);
let rest = dest - (coeff << shift);
second_index = 0;
*out += &format!(
"\t\t\"add x8, %2, {}, lsl {}\\n\\t\"\n",
coeff, shift
Expand All @@ -2341,6 +2332,7 @@ impl<T: std::fmt::Display> ExpressionEvaluator<T> {
let shift = d.min(12);
let coeff = dest / (1 << shift);
let rest = dest - (coeff << shift);
second_index = 0;
*out += &format!(
"\t\t\"add x8, %1, {}, lsl {}\\n\\t\"\n",
coeff, shift
Expand All @@ -2352,18 +2344,23 @@ impl<T: std::fmt::Display> ExpressionEvaluator<T> {
} else {
// TODO: subtract reserved indices
let dest = $i * 16;
if dest > 32760 {
if dest > 32760 && (dest < second_index || dest > 32760 + second_index)
{
let d = dest.ilog2();
let shift = d.min(12);
let coeff = dest / (1 << shift);
let rest = dest - (coeff << shift);
second_index = coeff << shift;
let rest = dest - second_index;
*out += &format!(
"\t\t\"add x8, %0, {}, lsl {}\\n\\t\"\n",
coeff, shift
);
(format!("[x8, {}]", rest), format!("[x8, {}]", rest + 8))
} else {
} else if dest <= 32760 {
(format!("[%0, {}]", dest), format!("[%0, {}]", dest + 8))
} else {
let offset = dest - second_index;
(format!("[x8, {}]", offset), format!("[x8, {}]", offset + 8))
}
}
}
Expand Down
Loading

0 comments on commit e74596f

Please sign in to comment.