-
Notifications
You must be signed in to change notification settings - Fork 7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat!: Infer extension deltas for Case, Cfg, Conditional, DataflowBlock, Dfg, TailLoop #1195
Changes from all commits
03522c1
d134f88
e835f56
d63931d
cd45bc9
f3a05c4
bf61478
25027f5
0e90ed6
d36e34f
cd96d78
6aa0193
4c5bf9c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,7 +13,6 @@ use std::collections::VecDeque; | |
use std::iter; | ||
|
||
pub(crate) use self::hugrmut::HugrMut; | ||
use self::validate::ExtensionError; | ||
pub use self::validate::ValidationError; | ||
|
||
pub use ident::{IdentList, InvalidIdentifier}; | ||
|
@@ -25,9 +24,9 @@ use thiserror::Error; | |
|
||
pub use self::views::{HugrView, RootTagged}; | ||
use crate::core::NodeIndex; | ||
use crate::extension::ExtensionRegistry; | ||
use crate::extension::{ExtensionRegistry, ExtensionSet, TO_BE_INFERRED}; | ||
use crate::ops::custom::resolve_extension_ops; | ||
use crate::ops::OpTag; | ||
use crate::ops::{OpTag, OpTrait}; | ||
pub use crate::ops::{OpType, DEFAULT_OPTYPE}; | ||
use crate::{Direction, Node}; | ||
|
||
|
@@ -92,15 +91,80 @@ impl Hugr { | |
self.validate_no_extensions(extension_registry)?; | ||
#[cfg(feature = "extension_inference")] | ||
{ | ||
self.infer_extensions()?; | ||
self.infer_extensions(false)?; | ||
self.validate_extensions()?; | ||
} | ||
Ok(()) | ||
} | ||
|
||
/// Leaving this here as in the future we plan for it to infer deltas | ||
/// of container nodes e.g. [OpType::DFG]. For the moment it does nothing. | ||
pub fn infer_extensions(&mut self) -> Result<(), ExtensionError> { | ||
/// Infers an extension-delta for any non-function container node | ||
/// whose current [extension_delta] contains [TO_BE_INFERRED]. The inferred delta | ||
/// will be the smallest delta compatible with its children and that includes any | ||
/// other [ExtensionId]s in the current delta. | ||
/// | ||
/// If `remove` is true, for such container nodes *without* [TO_BE_INFERRED], | ||
/// ExtensionIds are removed from the delta if they are *not* used by any child node. | ||
/// | ||
/// The non-function container nodes are: | ||
/// [Case], [CFG], [Conditional], [DataflowBlock], [DFG], [TailLoop] | ||
/// | ||
/// [Case]: crate::ops::Case | ||
/// [CFG]: crate::ops::CFG | ||
/// [Conditional]: crate::ops::Conditional | ||
/// [DataflowBlock]: crate::ops::DataflowBlock | ||
/// [DFG]: crate::ops::DFG | ||
/// [TailLoop]: crate::ops::TailLoop | ||
/// [extension_delta]: crate::ops::OpType::extension_delta | ||
/// [ExtensionId]: crate::extension::ExtensionId | ||
pub fn infer_extensions(&mut self, remove: bool) -> Result<(), ExtensionError> { | ||
fn delta_mut(optype: &mut OpType) -> Option<&mut ExtensionSet> { | ||
match optype { | ||
OpType::DFG(dfg) => Some(&mut dfg.signature.extension_reqs), | ||
OpType::DataflowBlock(dfb) => Some(&mut dfb.extension_delta), | ||
OpType::TailLoop(tl) => Some(&mut tl.extension_delta), | ||
OpType::CFG(cfg) => Some(&mut cfg.signature.extension_reqs), | ||
OpType::Conditional(c) => Some(&mut c.extension_delta), | ||
OpType::Case(c) => Some(&mut c.signature.extension_reqs), | ||
//OpType::Lift(_) // Not ATM: only a single element, and we expect Lift to be removed | ||
//OpType::FuncDefn(_) // Not at present due to the possibility of recursion | ||
_ => None, | ||
} | ||
} | ||
fn infer(h: &mut Hugr, node: Node, remove: bool) -> Result<ExtensionSet, ExtensionError> { | ||
let mut child_sets = h | ||
.children(node) | ||
.collect::<Vec<_>>() // Avoid borrowing h over recursive call | ||
.into_iter() | ||
.map(|ch| Ok((ch, infer(h, ch, remove)?))) | ||
.collect::<Result<Vec<_>, _>>()?; | ||
|
||
let Some(es) = delta_mut(h.op_types.get_mut(node.pg_index())) else { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should we just add a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe, but
I think this is fine for inside |
||
return Ok(h.get_optype(node).extension_delta()); | ||
}; | ||
if es.contains(&TO_BE_INFERRED) { | ||
// Do not remove anything from current delta - any other elements are a lower bound | ||
child_sets.push((node, es.clone())); // "child_sets" now misnamed but we discard fst | ||
} else if remove { | ||
child_sets.iter().try_for_each(|(ch, ch_exts)| { | ||
if !es.is_superset(ch_exts) { | ||
return Err(ExtensionError { | ||
parent: node, | ||
parent_extensions: es.clone(), | ||
child: *ch, | ||
child_extensions: ch_exts.clone(), | ||
}); | ||
} | ||
Ok(()) | ||
})?; | ||
} else { | ||
return Ok(es.clone()); // Can't neither add nor remove, so nothing to do | ||
} | ||
let merged = ExtensionSet::union_over(child_sets.into_iter().map(|(_, e)| e)); | ||
*es = ExtensionSet::singleton(&TO_BE_INFERRED).missing_from(&merged); | ||
|
||
Ok(es.clone()) | ||
} | ||
infer(self, self.root(), remove)?; | ||
Ok(()) | ||
} | ||
} | ||
|
@@ -203,6 +267,16 @@ impl Hugr { | |
} | ||
} | ||
|
||
#[derive(Debug, Clone, PartialEq, Error)] | ||
#[error("Parent node {parent} has extensions {parent_extensions} that are too restrictive for child node {child}, they must include child extensions {child_extensions}")] | ||
/// An error in the extension deltas. | ||
pub struct ExtensionError { | ||
parent: Node, | ||
parent_extensions: ExtensionSet, | ||
child: Node, | ||
child_extensions: ExtensionSet, | ||
} | ||
|
||
/// Errors that can occur while manipulating a Hugr. | ||
/// | ||
/// TODO: Better descriptions, not just re-exporting portgraph errors. | ||
|
@@ -221,7 +295,14 @@ pub enum HugrError { | |
|
||
#[cfg(test)] | ||
mod test { | ||
use super::{Hugr, HugrView}; | ||
use super::internal::HugrMutInternals; | ||
#[cfg(feature = "extension_inference")] | ||
use super::ValidationError; | ||
use super::{ExtensionError, Hugr, HugrMut, HugrView, Node}; | ||
use crate::extension::{ExtensionId, ExtensionSet, EMPTY_REG, TO_BE_INFERRED}; | ||
use crate::types::{FunctionType, Type}; | ||
use crate::{const_extension_ids, ops, type_row}; | ||
use rstest::rstest; | ||
|
||
#[test] | ||
fn impls_send_and_sync() { | ||
|
@@ -240,4 +321,167 @@ mod test { | |
let hugr = simple_dfg_hugr(); | ||
assert_matches!(hugr.get_io(hugr.root()), Some(_)); | ||
} | ||
|
||
const_extension_ids! { | ||
const XA: ExtensionId = "EXT_A"; | ||
const XB: ExtensionId = "EXT_B"; | ||
} | ||
|
||
#[rstest] | ||
#[case([], XA.into())] | ||
#[case([XA], XA.into())] | ||
#[case([XB], ExtensionSet::from_iter([XA, XB]))] | ||
|
||
fn infer_single_delta( | ||
#[case] parent: impl IntoIterator<Item = ExtensionId>, | ||
#[values(true, false)] remove: bool, // makes no difference when inferring | ||
#[case] result: ExtensionSet, | ||
) { | ||
let parent = ExtensionSet::from_iter(parent).union(TO_BE_INFERRED.into()); | ||
let (mut h, _) = build_ext_dfg(parent); | ||
h.infer_extensions(remove).unwrap(); | ||
assert_eq!(h, build_ext_dfg(result).0); | ||
} | ||
|
||
#[test] | ||
fn infer_removes_from_delta() { | ||
let parent = ExtensionSet::from_iter([XA, XB]); | ||
let mut h = build_ext_dfg(parent.clone()).0; | ||
let backup = h.clone(); | ||
h.infer_extensions(false).unwrap(); | ||
assert_eq!(h, backup); // did nothing | ||
h.infer_extensions(true).unwrap(); | ||
assert_eq!(h, build_ext_dfg(XA.into()).0); | ||
} | ||
|
||
#[test] | ||
fn infer_bad_remove() { | ||
let (mut h, mid) = build_ext_dfg(XB.into()); | ||
let backup = h.clone(); | ||
h.infer_extensions(false).unwrap(); | ||
assert_eq!(h, backup); // did nothing | ||
let val_res = h.validate(&EMPTY_REG); | ||
let expected_err = ExtensionError { | ||
parent: h.root(), | ||
parent_extensions: XB.into(), | ||
child: mid, | ||
child_extensions: XA.into(), | ||
}; | ||
#[cfg(feature = "extension_inference")] | ||
assert_eq!( | ||
val_res, | ||
Err(ValidationError::ExtensionError(expected_err.clone())) | ||
); | ||
#[cfg(not(feature = "extension_inference"))] | ||
assert!(val_res.is_ok()); | ||
|
||
let inf_res = h.infer_extensions(true); | ||
assert_eq!(inf_res, Err(expected_err)); | ||
} | ||
|
||
fn build_ext_dfg(parent: ExtensionSet) -> (Hugr, Node) { | ||
let ty = Type::new_function(FunctionType::new_endo(type_row![])); | ||
let mut h = Hugr::new(ops::DFG { | ||
signature: FunctionType::new_endo(ty.clone()).with_extension_delta(parent.clone()), | ||
}); | ||
let root = h.root(); | ||
let mid = add_inliftout(&mut h, root, ty); | ||
(h, mid) | ||
} | ||
|
||
fn add_inliftout(h: &mut Hugr, p: Node, ty: Type) -> Node { | ||
let inp = h.add_node_with_parent( | ||
p, | ||
ops::Input { | ||
types: ty.clone().into(), | ||
}, | ||
); | ||
let out = h.add_node_with_parent( | ||
p, | ||
ops::Output { | ||
types: ty.clone().into(), | ||
}, | ||
); | ||
let mid = h.add_node_with_parent( | ||
p, | ||
ops::Lift { | ||
type_row: ty.into(), | ||
new_extension: XA, | ||
}, | ||
); | ||
h.connect(inp, 0, mid, 0); | ||
h.connect(mid, 0, out, 0); | ||
mid | ||
} | ||
|
||
#[rstest] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i think these cases would benefit from comments next to each saying briefly why the result is expected There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fair, done |
||
// Base case success: delta inferred for parent equals grandparent. | ||
#[case([XA], [TO_BE_INFERRED], true, [XA])] | ||
// Success: delta inferred for parent is subset of grandparent | ||
#[case([XA, XB], [TO_BE_INFERRED], true, [XA])] | ||
// Base case failure: infers [XA] for parent but grandparent has disjoint set | ||
#[case([XB], [TO_BE_INFERRED], false, [XA])] | ||
// Failure: as previous, but extra "lower bound" on parent that has no effect | ||
#[case([XB], [XA, TO_BE_INFERRED], false, [XA])] | ||
// Failure: grandparent ok wrt. child but parent specifies extra lower-bound XB | ||
#[case([XA], [XB, TO_BE_INFERRED], false, [XA, XB])] | ||
// Success: grandparent includes extra XB required for parent's "lower bound" | ||
#[case([XA, XB], [XB, TO_BE_INFERRED], true, [XA, XB])] | ||
// Success: grandparent is also inferred so can include 'extra' XB from parent | ||
#[case([TO_BE_INFERRED], [TO_BE_INFERRED, XB], true, [XA, XB])] | ||
// No inference: extraneous XB in parent is removed so all become [XA]. | ||
#[case([XA], [XA, XB], true, [XA])] | ||
fn infer_three_generations( | ||
#[case] grandparent: impl IntoIterator<Item = ExtensionId>, | ||
#[case] parent: impl IntoIterator<Item = ExtensionId>, | ||
#[case] success: bool, | ||
#[case] result: impl IntoIterator<Item = ExtensionId>, | ||
) { | ||
let ty = Type::new_function(FunctionType::new_endo(type_row![])); | ||
let grandparent = ExtensionSet::from_iter(grandparent); | ||
let result = ExtensionSet::from_iter(result); | ||
let root_ty = ops::Conditional { | ||
sum_rows: vec![type_row![]], | ||
other_inputs: ty.clone().into(), | ||
outputs: ty.clone().into(), | ||
extension_delta: grandparent.clone(), | ||
}; | ||
let mut h = Hugr::new(root_ty.clone()); | ||
let p = h.add_node_with_parent( | ||
h.root(), | ||
ops::Case { | ||
signature: FunctionType::new_endo(ty.clone()) | ||
.with_extension_delta(ExtensionSet::from_iter(parent)), | ||
}, | ||
); | ||
add_inliftout(&mut h, p, ty.clone()); | ||
assert!(h.validate_extensions().is_err()); | ||
let backup = h.clone(); | ||
let inf_res = h.infer_extensions(true); | ||
if success { | ||
assert!(inf_res.is_ok()); | ||
let expected_p = ops::Case { | ||
signature: FunctionType::new_endo(ty).with_extension_delta(result.clone()), | ||
}; | ||
let mut expected = backup; | ||
expected.replace_op(p, expected_p).unwrap(); | ||
let expected_gp = ops::Conditional { | ||
extension_delta: result, | ||
..root_ty | ||
}; | ||
expected.replace_op(h.root(), expected_gp).unwrap(); | ||
|
||
assert_eq!(h, expected); | ||
} else { | ||
assert_eq!( | ||
inf_res, | ||
Err(ExtensionError { | ||
parent: h.root(), | ||
parent_extensions: grandparent, | ||
child: p, | ||
child_extensions: result | ||
}) | ||
); | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,7 +9,7 @@ use petgraph::visit::{Topo, Walker}; | |
use portgraph::{LinkView, PortView}; | ||
use thiserror::Error; | ||
|
||
use crate::extension::{ExtensionRegistry, ExtensionSet, SignatureError}; | ||
use crate::extension::{ExtensionRegistry, SignatureError, TO_BE_INFERRED}; | ||
|
||
use crate::ops::custom::{resolve_opaque_op, CustomOp, CustomOpError}; | ||
use crate::ops::validate::{ChildrenEdgeData, ChildrenValidationError, EdgeValidationError}; | ||
|
@@ -19,6 +19,7 @@ use crate::types::{EdgeKind, FunctionType}; | |
use crate::{Direction, Hugr, Node, Port}; | ||
|
||
use super::views::{HierarchyView, HugrView, SiblingGraph}; | ||
use super::ExtensionError; | ||
|
||
/// Structure keeping track of pre-computed information used in the validation | ||
/// process. | ||
|
@@ -59,6 +60,9 @@ impl Hugr { | |
pub fn validate_extensions(&self) -> Result<(), ValidationError> { | ||
for parent in self.nodes() { | ||
let parent_op = self.get_optype(parent); | ||
if parent_op.extension_delta().contains(&TO_BE_INFERRED) { | ||
return Err(ValidationError::ExtensionsNotInferred { node: parent }); | ||
} | ||
let parent_extensions = match parent_op.inner_function_type() { | ||
Some(FunctionType { extension_reqs, .. }) => extension_reqs, | ||
None => match parent_op.tag() { | ||
|
@@ -743,6 +747,9 @@ pub enum ValidationError { | |
/// There are errors in the extension deltas. | ||
#[error(transparent)] | ||
ExtensionError(#[from] ExtensionError), | ||
/// A node claims to still be awaiting extension inference. Perhaps it is not acted upon by inference. | ||
#[error("Node {node:?} needs a concrete ExtensionSet - inference will provide this for Case/CFG/Conditional/DataflowBlock/DFG/TailLoop only")] | ||
ExtensionsNotInferred { node: Node }, | ||
/// Error in a node signature | ||
#[error("Error in signature of node {node:?}: {cause}")] | ||
SignatureError { node: Node, cause: SignatureError }, | ||
|
@@ -815,15 +822,5 @@ pub enum InterGraphEdgeError { | |
}, | ||
} | ||
|
||
#[derive(Debug, Clone, PartialEq, Error)] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. breaking so needs mentioning in PR commit name + description |
||
#[error("Parent node {parent} has extensions {parent_extensions} that are too restrictive for child node {child}, they must include child extensions {child_extensions}")] | ||
/// An error in the extension deltas. | ||
pub struct ExtensionError { | ||
parent: Node, | ||
parent_extensions: ExtensionSet, | ||
child: Node, | ||
child_extensions: ExtensionSet, | ||
} | ||
|
||
#[cfg(test)] | ||
pub(crate) mod test; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess this (+tests) might be better in a separate file, but this is 200 lines, a lot shorter than the old
infer.rs