From 03522c127ad89990b876f7bfadd8f807bf7167c9 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 12 Jun 2024 15:31:08 +0100 Subject: [PATCH 01/13] BREAKING: Move ExtensionError into hugr-core/src/hugr/ --- hugr-core/src/hugr.rs | 13 +++++++++++-- hugr-core/src/hugr/validate.rs | 13 ++----------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/hugr-core/src/hugr.rs b/hugr-core/src/hugr.rs index 25aa28214..e35dc5254 100644 --- a/hugr-core/src/hugr.rs +++ b/hugr-core/src/hugr.rs @@ -13,7 +13,6 @@ use std::collections::VecDeque; use std::iter; pub(crate) use self::hugrmut::HugrMut; -use self::validate::ExtensionError; pub use self::validate::ValidationError; pub use ident::{IdentList, InvalidIdentifier}; @@ -25,7 +24,7 @@ use thiserror::Error; pub use self::views::{HugrView, RootTagged}; use crate::core::NodeIndex; -use crate::extension::ExtensionRegistry; +use crate::extension::{ExtensionRegistry, ExtensionSet}; use crate::ops::custom::resolve_extension_ops; use crate::ops::OpTag; pub use crate::ops::{OpType, DEFAULT_OPTYPE}; @@ -203,6 +202,16 @@ impl Hugr { } } +#[derive(Debug, Clone, PartialEq, Error)] +#[error("Parent node {parent} has extensions {parent_extensions} that are too restrictive for child node {child}, they must include child extensions {child_extensions}")] +/// An error in the extension deltas. +pub struct ExtensionError { + parent: Node, + parent_extensions: ExtensionSet, + child: Node, + child_extensions: ExtensionSet, +} + /// Errors that can occur while manipulating a Hugr. /// /// TODO: Better descriptions, not just re-exporting portgraph errors. diff --git a/hugr-core/src/hugr/validate.rs b/hugr-core/src/hugr/validate.rs index be2e06003..bf2f9fa2d 100644 --- a/hugr-core/src/hugr/validate.rs +++ b/hugr-core/src/hugr/validate.rs @@ -9,7 +9,7 @@ use petgraph::visit::{Topo, Walker}; use portgraph::{LinkView, PortView}; use thiserror::Error; -use crate::extension::{ExtensionRegistry, ExtensionSet, SignatureError}; +use crate::extension::{ExtensionRegistry, SignatureError}; use crate::ops::custom::{resolve_opaque_op, CustomOp, CustomOpError}; use crate::ops::validate::{ChildrenEdgeData, ChildrenValidationError, EdgeValidationError}; @@ -19,6 +19,7 @@ use crate::types::{EdgeKind, FunctionType}; use crate::{Direction, Hugr, Node, Port}; use super::views::{HierarchyView, HugrView, SiblingGraph}; +use super::ExtensionError; /// Structure keeping track of pre-computed information used in the validation /// process. @@ -815,15 +816,5 @@ pub enum InterGraphEdgeError { }, } -#[derive(Debug, Clone, PartialEq, Error)] -#[error("Parent node {parent} has extensions {parent_extensions} that are too restrictive for child node {child}, they must include child extensions {child_extensions}")] -/// An error in the extension deltas. -pub struct ExtensionError { - parent: Node, - parent_extensions: ExtensionSet, - child: Node, - child_extensions: ExtensionSet, -} - #[cfg(test)] pub(crate) mod test; From d134f88edba5a73e5aee2a1b5c8755379959aec7 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 12 Jun 2024 21:57:52 +0100 Subject: [PATCH 02/13] Add ExtensionSet::TO_BE_INFERRED, and check it's gone in validation --- hugr-core/src/extension.rs | 2 ++ hugr-core/src/hugr/validate.rs | 11 ++++++++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/hugr-core/src/extension.rs b/hugr-core/src/extension.rs index deecf1cbb..260e33b9d 100644 --- a/hugr-core/src/extension.rs +++ b/hugr-core/src/extension.rs @@ -406,6 +406,8 @@ pub enum ExtensionBuildError { pub struct ExtensionSet(BTreeSet); impl ExtensionSet { + pub const TO_BE_INFERRED: ExtensionId = ExtensionId::new_unchecked("TO_BE_INFERRED"); + /// Creates a new empty extension set. pub const fn new() -> Self { Self(BTreeSet::new()) diff --git a/hugr-core/src/hugr/validate.rs b/hugr-core/src/hugr/validate.rs index bf2f9fa2d..45099fff5 100644 --- a/hugr-core/src/hugr/validate.rs +++ b/hugr-core/src/hugr/validate.rs @@ -9,7 +9,7 @@ use petgraph::visit::{Topo, Walker}; use portgraph::{LinkView, PortView}; use thiserror::Error; -use crate::extension::{ExtensionRegistry, SignatureError}; +use crate::extension::{ExtensionRegistry, ExtensionSet, SignatureError}; use crate::ops::custom::{resolve_opaque_op, CustomOp, CustomOpError}; use crate::ops::validate::{ChildrenEdgeData, ChildrenValidationError, EdgeValidationError}; @@ -60,6 +60,12 @@ impl Hugr { pub fn validate_extensions(&self) -> Result<(), ValidationError> { for parent in self.nodes() { let parent_op = self.get_optype(parent); + if parent_op + .extension_delta() + .contains(&ExtensionSet::TO_BE_INFERRED) + { + return Err(ValidationError::ExtensionsNotInferred { node: parent }); + } let parent_extensions = match parent_op.inner_function_type() { Some(FunctionType { extension_reqs, .. }) => extension_reqs, None => match parent_op.tag() { @@ -744,6 +750,9 @@ pub enum ValidationError { /// There are errors in the extension deltas. #[error(transparent)] ExtensionError(#[from] ExtensionError), + /// A node claims to still be awaiting extension inference. Perhaps it is not acted upon by inference... + #[error("Node {node:?} needs a concrete ExtensionSet - inference will provide this for TailLoop/Conditional/CFG/DFG/BasicBlock only")] + ExtensionsNotInferred { node: Node }, /// Error in a node signature #[error("Error in signature of node {node:?}: {cause}")] SignatureError { node: Node, cause: SignatureError }, From e835f56160aaf8dc8444eadf60b30301a6bcfc90 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 12 Jun 2024 22:00:09 +0100 Subject: [PATCH 03/13] Inference algorithm --- hugr-core/src/hugr.rs | 61 ++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 58 insertions(+), 3 deletions(-) diff --git a/hugr-core/src/hugr.rs b/hugr-core/src/hugr.rs index e35dc5254..7b61ca6f4 100644 --- a/hugr-core/src/hugr.rs +++ b/hugr-core/src/hugr.rs @@ -26,7 +26,7 @@ pub use self::views::{HugrView, RootTagged}; use crate::core::NodeIndex; use crate::extension::{ExtensionRegistry, ExtensionSet}; use crate::ops::custom::resolve_extension_ops; -use crate::ops::OpTag; +use crate::ops::{OpTag, OpTrait}; pub use crate::ops::{OpType, DEFAULT_OPTYPE}; use crate::{Direction, Node}; @@ -91,7 +91,7 @@ impl Hugr { self.validate_no_extensions(extension_registry)?; #[cfg(feature = "extension_inference")] { - self.infer_extensions()?; + self.infer_extensions(false)?; self.validate_extensions()?; } Ok(()) @@ -99,9 +99,64 @@ impl Hugr { /// Leaving this here as in the future we plan for it to infer deltas /// of container nodes e.g. [OpType::DFG]. For the moment it does nothing. - pub fn infer_extensions(&mut self) -> Result<(), ExtensionError> { + pub fn infer_extensions(&mut self, remove: bool) -> Result<(), ExtensionError> { + fn delta_mut(optype: &mut OpType) -> Option<&mut ExtensionSet> { + match optype { + OpType::DFG(dfg) => Some(&mut dfg.signature.extension_reqs), + OpType::DataflowBlock(dfb) => Some(&mut dfb.extension_delta), + OpType::TailLoop(tl) => Some(&mut tl.extension_delta), + OpType::CFG(cfg) => Some(&mut cfg.signature.extension_reqs), + OpType::Conditional(c) => Some(&mut c.extension_delta), + OpType::Case(c) => Some(&mut c.signature.extension_reqs), + //OpType::Lift(_) // Not ATM: only a single element, and we expect Lift to be removed + //OpType::FuncDefn(_) // Not at present due to the possibility of recursion + _ => None, + } + } + fn infer(h: &mut Hugr, node: Node, remove: bool) -> Result { + let child_sets = h + .children(node) + .collect::>() // Avoid borrowing h over recursive call + .into_iter() + .map(|ch| Ok((ch, infer(h, ch, remove)?))) + .collect::, _>>()?; + + let Some(es) = delta_mut(h.op_types.get_mut(node.pg_index())) else { + return Ok(h.get_optype(node).extension_delta()); + }; + if !es.contains(&ExtensionSet::TO_BE_INFERRED) { + // Can't add any new extensions... + if !remove { + return Ok(es.clone()); // Can't remove either, so nothing to do + } + child_sets.iter().try_for_each(|(ch, ch_exts)| { + if !es.is_superset(ch_exts) { + return Err(ExtensionError { + parent: node, + parent_extensions: es.clone(), + child: *ch, + child_extensions: ch_exts.clone(), + }); + } + Ok(()) + })?; + }; + let ch_d = ExtensionSet::union_over(child_sets.into_iter().map(|(_, e)| e)); + let merged = if remove { ch_d } else { ch_d.union(es.clone()) }; + *es = ExtensionSet::singleton(&ExtensionSet::TO_BE_INFERRED).missing_from(&merged); + + Ok(es.clone()) + } + infer(self, self.root(), remove)?; Ok(()) } + + // Note: tests + // * all combinations of (remove or not, TO_BE_INFERRED present or absent, success(inferred-set) or failure (possible only if no TO_BE_INFERRED) ) + // * parent - child - grandchild tests: + // X - Y + INFER - X (ok with remove, but fails w/out remove) + // X - Y + INFER - Y or X - INFER - Y (mid fails against parent with just Y, regardless of remove) + // X - INFER - X (ok with-or-without remove) } /// Internal API for HUGRs, not intended for use by users. From d63931d1134e23a74a7f3f638e58b070d8084883 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Thu, 13 Jun 2024 10:12:27 +0100 Subject: [PATCH 04/13] If inferring, never remove --- hugr-core/src/hugr.rs | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/hugr-core/src/hugr.rs b/hugr-core/src/hugr.rs index 7b61ca6f4..8d1f47d9d 100644 --- a/hugr-core/src/hugr.rs +++ b/hugr-core/src/hugr.rs @@ -114,7 +114,7 @@ impl Hugr { } } fn infer(h: &mut Hugr, node: Node, remove: bool) -> Result { - let child_sets = h + let mut child_sets = h .children(node) .collect::>() // Avoid borrowing h over recursive call .into_iter() @@ -124,11 +124,10 @@ impl Hugr { let Some(es) = delta_mut(h.op_types.get_mut(node.pg_index())) else { return Ok(h.get_optype(node).extension_delta()); }; - if !es.contains(&ExtensionSet::TO_BE_INFERRED) { - // Can't add any new extensions... - if !remove { - return Ok(es.clone()); // Can't remove either, so nothing to do - } + if es.contains(&ExtensionSet::TO_BE_INFERRED) { + // Do not remove anything from current delta - any other elements are a lower bound + child_sets.push((node, es.clone())); // "child_sets" now misnamed but we discard fst + } else if remove { child_sets.iter().try_for_each(|(ch, ch_exts)| { if !es.is_superset(ch_exts) { return Err(ExtensionError { @@ -140,9 +139,10 @@ impl Hugr { } Ok(()) })?; - }; - let ch_d = ExtensionSet::union_over(child_sets.into_iter().map(|(_, e)| e)); - let merged = if remove { ch_d } else { ch_d.union(es.clone()) }; + } else { + return Ok(es.clone()); // Can't neither add nor remove, so nothing to do + } + let merged = ExtensionSet::union_over(child_sets.into_iter().map(|(_, e)| e)); *es = ExtensionSet::singleton(&ExtensionSet::TO_BE_INFERRED).missing_from(&merged); Ok(es.clone()) From cd45bc9a642ffc73d56cb93677463cf122f59779 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Thu, 13 Jun 2024 11:15:53 +0100 Subject: [PATCH 05/13] Move TO_BE_INFERRED out of ExtensionSet into top-level of extension.rs --- hugr-core/src/extension.rs | 15 +++++++++++++-- hugr-core/src/hugr.rs | 6 +++--- hugr-core/src/hugr/validate.rs | 7 ++----- 3 files changed, 18 insertions(+), 10 deletions(-) diff --git a/hugr-core/src/extension.rs b/hugr-core/src/extension.rs index 260e33b9d..17475daad 100644 --- a/hugr-core/src/extension.rs +++ b/hugr-core/src/extension.rs @@ -405,9 +405,20 @@ pub enum ExtensionBuildError { #[derive(Clone, Debug, Default, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize)] pub struct ExtensionSet(BTreeSet); -impl ExtensionSet { - pub const TO_BE_INFERRED: ExtensionId = ExtensionId::new_unchecked("TO_BE_INFERRED"); +/// A special ExtensionId which indicates that the delta of a non-Function +/// container node should be computed by extension inference. +/// Usable only in non-Function container nodes +/// ([Case], [CFG], [Conditional], [DataflowBlock], [DFG], [TailLoop]) +/// +/// [Case]: crate::ops::Case +/// [CFG]: crate::ops::CFG +/// [Conditional]: crate::ops::Conditional +/// [DataflowBlock]: crate::ops::DataflowBlock +/// [DFG]: crate::ops::DFG +/// [TailLoop]: crate::ops::TailLoop +pub const TO_BE_INFERRED: ExtensionId = ExtensionId::new_unchecked("TO_BE_INFERRED"); +impl ExtensionSet { /// Creates a new empty extension set. pub const fn new() -> Self { Self(BTreeSet::new()) diff --git a/hugr-core/src/hugr.rs b/hugr-core/src/hugr.rs index 8d1f47d9d..df2455015 100644 --- a/hugr-core/src/hugr.rs +++ b/hugr-core/src/hugr.rs @@ -24,7 +24,7 @@ use thiserror::Error; pub use self::views::{HugrView, RootTagged}; use crate::core::NodeIndex; -use crate::extension::{ExtensionRegistry, ExtensionSet}; +use crate::extension::{ExtensionRegistry, ExtensionSet, TO_BE_INFERRED}; use crate::ops::custom::resolve_extension_ops; use crate::ops::{OpTag, OpTrait}; pub use crate::ops::{OpType, DEFAULT_OPTYPE}; @@ -124,7 +124,7 @@ impl Hugr { let Some(es) = delta_mut(h.op_types.get_mut(node.pg_index())) else { return Ok(h.get_optype(node).extension_delta()); }; - if es.contains(&ExtensionSet::TO_BE_INFERRED) { + if es.contains(&TO_BE_INFERRED) { // Do not remove anything from current delta - any other elements are a lower bound child_sets.push((node, es.clone())); // "child_sets" now misnamed but we discard fst } else if remove { @@ -143,7 +143,7 @@ impl Hugr { return Ok(es.clone()); // Can't neither add nor remove, so nothing to do } let merged = ExtensionSet::union_over(child_sets.into_iter().map(|(_, e)| e)); - *es = ExtensionSet::singleton(&ExtensionSet::TO_BE_INFERRED).missing_from(&merged); + *es = ExtensionSet::singleton(&TO_BE_INFERRED).missing_from(&merged); Ok(es.clone()) } diff --git a/hugr-core/src/hugr/validate.rs b/hugr-core/src/hugr/validate.rs index 45099fff5..8560195d9 100644 --- a/hugr-core/src/hugr/validate.rs +++ b/hugr-core/src/hugr/validate.rs @@ -9,7 +9,7 @@ use petgraph::visit::{Topo, Walker}; use portgraph::{LinkView, PortView}; use thiserror::Error; -use crate::extension::{ExtensionRegistry, ExtensionSet, SignatureError}; +use crate::extension::{ExtensionRegistry, SignatureError, TO_BE_INFERRED}; use crate::ops::custom::{resolve_opaque_op, CustomOp, CustomOpError}; use crate::ops::validate::{ChildrenEdgeData, ChildrenValidationError, EdgeValidationError}; @@ -60,10 +60,7 @@ impl Hugr { pub fn validate_extensions(&self) -> Result<(), ValidationError> { for parent in self.nodes() { let parent_op = self.get_optype(parent); - if parent_op - .extension_delta() - .contains(&ExtensionSet::TO_BE_INFERRED) - { + if parent_op.extension_delta().contains(&TO_BE_INFERRED) { return Err(ValidationError::ExtensionsNotInferred { node: parent }); } let parent_extensions = match parent_op.inner_function_type() { From f3a05c4195be1d862302b452cbf4672660187f9d Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Thu, 13 Jun 2024 15:23:00 +0100 Subject: [PATCH 06/13] inference tests --- hugr-core/src/hugr.rs | 169 ++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 161 insertions(+), 8 deletions(-) diff --git a/hugr-core/src/hugr.rs b/hugr-core/src/hugr.rs index df2455015..7e7552583 100644 --- a/hugr-core/src/hugr.rs +++ b/hugr-core/src/hugr.rs @@ -150,13 +150,6 @@ impl Hugr { infer(self, self.root(), remove)?; Ok(()) } - - // Note: tests - // * all combinations of (remove or not, TO_BE_INFERRED present or absent, success(inferred-set) or failure (possible only if no TO_BE_INFERRED) ) - // * parent - child - grandchild tests: - // X - Y + INFER - X (ok with remove, but fails w/out remove) - // X - Y + INFER - Y or X - INFER - Y (mid fails against parent with just Y, regardless of remove) - // X - INFER - X (ok with-or-without remove) } /// Internal API for HUGRs, not intended for use by users. @@ -285,7 +278,13 @@ pub enum HugrError { #[cfg(test)] mod test { - use super::{Hugr, HugrView}; + #[cfg(feature = "extension_inference")] + use super::ValidationError; + use super::{ExtensionError, Hugr, HugrMut, HugrView, Node}; + use crate::extension::{ExtensionId, ExtensionSet, EMPTY_REG, TO_BE_INFERRED}; + use crate::types::{FunctionType, Type}; + use crate::{const_extension_ids, ops, type_row}; + use rstest::rstest; #[test] fn impls_send_and_sync() { @@ -304,4 +303,158 @@ mod test { let hugr = simple_dfg_hugr(); assert_matches!(hugr.get_io(hugr.root()), Some(_)); } + + const_extension_ids! { + const XA: ExtensionId = "EXT_A"; + const XB: ExtensionId = "EXT_B"; + } + + #[rstest] + #[case([], XA.into())] + #[case([XA], XA.into())] + #[case([XB], ExtensionSet::from_iter([XA, XB]))] + + fn infer_single_delta( + #[case] parent: impl IntoIterator, + #[values(true, false)] remove: bool, // makes no difference when inferring + #[case] result: ExtensionSet, + ) { + let parent = ExtensionSet::from_iter(parent).union(TO_BE_INFERRED.into()); + let (mut h, _) = build_ext_cfg(parent); + h.infer_extensions(remove).unwrap(); + assert_eq!(h, build_ext_cfg(result).0); + } + + #[test] + fn infer_removes_from_delta() { + let parent = ExtensionSet::from_iter([XA, XB]); + let mut h = build_ext_cfg(parent.clone()).0; + let backup = h.clone(); + h.infer_extensions(false).unwrap(); + assert_eq!(h, backup); // did nothing + h.infer_extensions(true).unwrap(); + assert_eq!(h, build_ext_cfg(XA.into()).0); + } + + #[test] + fn infer_bad_remove() { + let (mut h, mid) = build_ext_cfg(XB.into()); + let backup = h.clone(); + h.infer_extensions(false).unwrap(); + assert_eq!(h, backup); // did nothing + let val_res = h.validate(&EMPTY_REG); + let expected_err = ExtensionError { + parent: h.root(), + parent_extensions: XB.into(), + child: mid, + child_extensions: XA.into(), + }; + #[cfg(feature = "extension_inference")] + assert_eq!( + val_res, + Err(ValidationError::ExtensionError(expected_err.clone())) + ); + #[cfg(not(feature = "extension_inference"))] + assert!(val_res.is_ok()); + + let inf_res = h.infer_extensions(true); + assert_eq!(inf_res, Err(expected_err)); + } + + fn build_ext_cfg(parent: ExtensionSet) -> (Hugr, Node) { + let ty = Type::new_function(FunctionType::new_endo(type_row![])); + let mut h = Hugr::new(ops::DFG { + signature: FunctionType::new_endo(ty.clone()).with_extension_delta(parent.clone()), + }); + let root = h.root(); + let mid = add_inliftout(&mut h, root, ty); + (h, mid) + } + + fn add_inliftout(h: &mut Hugr, p: Node, ty: Type) -> Node { + let inp = h.add_node_with_parent( + p, + ops::Input { + types: ty.clone().into(), + }, + ); + let out = h.add_node_with_parent( + p, + ops::Output { + types: ty.clone().into(), + }, + ); + let mid = h.add_node_with_parent( + p, + ops::Lift { + type_row: ty.into(), + new_extension: XA, + }, + ); + h.connect(inp, 0, mid, 0); + h.connect(mid, 0, out, 0); + mid + } + + #[rstest] + #[case([XA], [XB, TO_BE_INFERRED], false, [XA, XB])] + #[case([XA], [TO_BE_INFERRED], true, [XA])] + #[case([XB], [TO_BE_INFERRED], false, [XA])] + #[case([XB], [XA, TO_BE_INFERRED], false, [XA])] + #[case([XA, XB], [XB, TO_BE_INFERRED], true, [XA, XB])] + #[case([XA, XB], [TO_BE_INFERRED], true, [XA])] + #[case([TO_BE_INFERRED], [TO_BE_INFERRED, XB], true, [XA, XB])] + // This one just tests removal: + #[case([XA], [XA, XB], true, [XA])] + // TODO: Consider adding a separate expected-grandparent so we can have something different? + fn infer_three_generations( + #[case] grandparent: impl IntoIterator, + #[case] parent: impl IntoIterator, + #[case] success: bool, + #[case] result: impl IntoIterator, + ) { + let ty = Type::new_function(FunctionType::new_endo(type_row![])); + let grandparent = ExtensionSet::from_iter(grandparent); + let result = ExtensionSet::from_iter(result); + let root_ty = ops::Conditional { + sum_rows: vec![type_row![]], + other_inputs: ty.clone().into(), + outputs: ty.clone().into(), + extension_delta: grandparent.clone(), + }; + let mut h = Hugr::new(root_ty.clone()); + let p = h.add_node_with_parent( + h.root(), + ops::Case { + signature: FunctionType::new_endo(ty.clone()) + .with_extension_delta(ExtensionSet::from_iter(parent)), + }, + ); + add_inliftout(&mut h, p, ty.clone()); + assert!(h.validate_extensions().is_err()); + let inf_res = h.infer_extensions(true); + if success { + assert!(inf_res.is_ok()); + let expected_p = ops::Case { + signature: FunctionType::new_endo(ty).with_extension_delta(result.clone()), + }; + assert!(h.get_optype(p) == &expected_p.into()); + let expected_gp = ops::Conditional { + extension_delta: result, + ..root_ty + }; + assert!(h.root_type() == &expected_gp.into()) + // rest should be unchanged... + } else { + assert_eq!( + inf_res, + Err(ExtensionError { + parent: h.root(), + parent_extensions: grandparent, + child: p, + child_extensions: result + }) + ); + } + } } From bf61478e67fecc39dd340e397d80ddbb1ddf5f22 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 24 Jun 2024 13:35:03 +0100 Subject: [PATCH 07/13] Make TO_BE_INFERRED an illegal ExtensionId --- hugr-core/src/extension.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hugr-core/src/extension.rs b/hugr-core/src/extension.rs index 17475daad..f3d8e48be 100644 --- a/hugr-core/src/extension.rs +++ b/hugr-core/src/extension.rs @@ -416,7 +416,7 @@ pub struct ExtensionSet(BTreeSet); /// [DataflowBlock]: crate::ops::DataflowBlock /// [DFG]: crate::ops::DFG /// [TailLoop]: crate::ops::TailLoop -pub const TO_BE_INFERRED: ExtensionId = ExtensionId::new_unchecked("TO_BE_INFERRED"); +pub const TO_BE_INFERRED: ExtensionId = ExtensionId::new_unchecked(".TO_BE_INFERRED"); impl ExtensionSet { /// Creates a new empty extension set. From 25027f51f6b854c8a2df63773705095bd08f9f7d Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 24 Jun 2024 13:37:14 +0100 Subject: [PATCH 08/13] Comment and message for ExtensionsNotInferred --- hugr-core/src/hugr/validate.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/hugr-core/src/hugr/validate.rs b/hugr-core/src/hugr/validate.rs index 8560195d9..c36169134 100644 --- a/hugr-core/src/hugr/validate.rs +++ b/hugr-core/src/hugr/validate.rs @@ -747,8 +747,8 @@ pub enum ValidationError { /// There are errors in the extension deltas. #[error(transparent)] ExtensionError(#[from] ExtensionError), - /// A node claims to still be awaiting extension inference. Perhaps it is not acted upon by inference... - #[error("Node {node:?} needs a concrete ExtensionSet - inference will provide this for TailLoop/Conditional/CFG/DFG/BasicBlock only")] + /// A node claims to still be awaiting extension inference. Perhaps it is not acted upon by inference. + #[error("Node {node:?} needs a concrete ExtensionSet - inference will provide this for Case/CFG/Conditional/DataflowBlock/DFG/TailLoop only")] ExtensionsNotInferred { node: Node }, /// Error in a node signature #[error("Error in signature of node {node:?}: {cause}")] From 0e90ed6467eae37b40aa6f82428d8fd640311023 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 24 Jun 2024 13:58:56 +0100 Subject: [PATCH 09/13] doc for infer_extensions, moving list of op types --- hugr-core/src/extension.rs | 12 +++--------- hugr-core/src/hugr.rs | 21 +++++++++++++++++++-- 2 files changed, 22 insertions(+), 11 deletions(-) diff --git a/hugr-core/src/extension.rs b/hugr-core/src/extension.rs index f3d8e48be..0dfd9a13b 100644 --- a/hugr-core/src/extension.rs +++ b/hugr-core/src/extension.rs @@ -406,16 +406,10 @@ pub enum ExtensionBuildError { pub struct ExtensionSet(BTreeSet); /// A special ExtensionId which indicates that the delta of a non-Function -/// container node should be computed by extension inference. -/// Usable only in non-Function container nodes -/// ([Case], [CFG], [Conditional], [DataflowBlock], [DFG], [TailLoop]) +/// container node should be computed by extension inference. See [`infer_extensions`] +/// which lists the container nodes to which this can be applied. /// -/// [Case]: crate::ops::Case -/// [CFG]: crate::ops::CFG -/// [Conditional]: crate::ops::Conditional -/// [DataflowBlock]: crate::ops::DataflowBlock -/// [DFG]: crate::ops::DFG -/// [TailLoop]: crate::ops::TailLoop +/// [`infer_extensions`]: crate::hugr::Hugr::infer_extensions pub const TO_BE_INFERRED: ExtensionId = ExtensionId::new_unchecked(".TO_BE_INFERRED"); impl ExtensionSet { diff --git a/hugr-core/src/hugr.rs b/hugr-core/src/hugr.rs index 7e7552583..b4f100260 100644 --- a/hugr-core/src/hugr.rs +++ b/hugr-core/src/hugr.rs @@ -97,8 +97,25 @@ impl Hugr { Ok(()) } - /// Leaving this here as in the future we plan for it to infer deltas - /// of container nodes e.g. [OpType::DFG]. For the moment it does nothing. + /// Infers an extension-delta for any non-function container node + /// whose current [extension_delta] contains [TO_BE_INFERRED]. The inferred delta + /// will be the smallest delta compatible with its children and that includes any + /// other [ExtensionId]s in the current delta. + /// + /// If `remove` is true, for such container nodes *without* [TO_BE_INFERRED], + /// ExtensionIds are removed from the delta if they are *not* used by any child node. + /// + /// The non-function container nodes are: + /// [Case], [CFG], [Conditional], [DataflowBlock], [DFG], [TailLoop] + /// + /// [Case]: crate::ops::Case + /// [CFG]: crate::ops::CFG + /// [Conditional]: crate::ops::Conditional + /// [DataflowBlock]: crate::ops::DataflowBlock + /// [DFG]: crate::ops::DFG + /// [TailLoop]: crate::ops::TailLoop + /// [extension_delta]: crate::ops::OpType::extension_delta + /// [ExtensionId]: crate::extension::ExtensionId pub fn infer_extensions(&mut self, remove: bool) -> Result<(), ExtensionError> { fn delta_mut(optype: &mut OpType) -> Option<&mut ExtensionSet> { match optype { From d36e34f755ab88f6e37e6ea6caee21e4667afc5a Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 24 Jun 2024 14:10:21 +0100 Subject: [PATCH 10/13] rename build_ext_(c=>d)fg --- hugr-core/src/hugr.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/hugr-core/src/hugr.rs b/hugr-core/src/hugr.rs index b4f100260..5a95f49f4 100644 --- a/hugr-core/src/hugr.rs +++ b/hugr-core/src/hugr.rs @@ -337,25 +337,25 @@ mod test { #[case] result: ExtensionSet, ) { let parent = ExtensionSet::from_iter(parent).union(TO_BE_INFERRED.into()); - let (mut h, _) = build_ext_cfg(parent); + let (mut h, _) = build_ext_dfg(parent); h.infer_extensions(remove).unwrap(); - assert_eq!(h, build_ext_cfg(result).0); + assert_eq!(h, build_ext_dfg(result).0); } #[test] fn infer_removes_from_delta() { let parent = ExtensionSet::from_iter([XA, XB]); - let mut h = build_ext_cfg(parent.clone()).0; + let mut h = build_ext_dfg(parent.clone()).0; let backup = h.clone(); h.infer_extensions(false).unwrap(); assert_eq!(h, backup); // did nothing h.infer_extensions(true).unwrap(); - assert_eq!(h, build_ext_cfg(XA.into()).0); + assert_eq!(h, build_ext_dfg(XA.into()).0); } #[test] fn infer_bad_remove() { - let (mut h, mid) = build_ext_cfg(XB.into()); + let (mut h, mid) = build_ext_dfg(XB.into()); let backup = h.clone(); h.infer_extensions(false).unwrap(); assert_eq!(h, backup); // did nothing @@ -378,7 +378,7 @@ mod test { assert_eq!(inf_res, Err(expected_err)); } - fn build_ext_cfg(parent: ExtensionSet) -> (Hugr, Node) { + fn build_ext_dfg(parent: ExtensionSet) -> (Hugr, Node) { let ty = Type::new_function(FunctionType::new_endo(type_row![])); let mut h = Hugr::new(ops::DFG { signature: FunctionType::new_endo(ty.clone()).with_extension_delta(parent.clone()), From cd96d781fce87ed8cc9e09841b79057a0a60b6f8 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 24 Jun 2024 14:24:44 +0100 Subject: [PATCH 11/13] infer_three_generations: comment each case, and reorder --- hugr-core/src/hugr.rs | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/hugr-core/src/hugr.rs b/hugr-core/src/hugr.rs index 5a95f49f4..60ce7eab2 100644 --- a/hugr-core/src/hugr.rs +++ b/hugr-core/src/hugr.rs @@ -414,16 +414,22 @@ mod test { } #[rstest] - #[case([XA], [XB, TO_BE_INFERRED], false, [XA, XB])] + // Base case success: delta inferred for parent equals grandparent. #[case([XA], [TO_BE_INFERRED], true, [XA])] + // Success: delta inferred for parent is subset of grandparent + #[case([XA, XB], [TO_BE_INFERRED], true, [XA])] + // Base case failure: infers [XA] for parent but grandparent has disjoint et #[case([XB], [TO_BE_INFERRED], false, [XA])] + // Failure: as previous, but extra "lower bound" on parent that has no effect #[case([XB], [XA, TO_BE_INFERRED], false, [XA])] + // Failure: grandparent ok wrt. child but parent specifies extra lower-bound XB + #[case([XA], [XB, TO_BE_INFERRED], false, [XA, XB])] + // Success: grandparent includes extra XB required for parent's "lower bound" #[case([XA, XB], [XB, TO_BE_INFERRED], true, [XA, XB])] - #[case([XA, XB], [TO_BE_INFERRED], true, [XA])] + // Success: grandparent is also inferred so can include 'extra' XB from parent #[case([TO_BE_INFERRED], [TO_BE_INFERRED, XB], true, [XA, XB])] - // This one just tests removal: + // No inference: extraneous XB in parent is removed so all become [XA]. #[case([XA], [XA, XB], true, [XA])] - // TODO: Consider adding a separate expected-grandparent so we can have something different? fn infer_three_generations( #[case] grandparent: impl IntoIterator, #[case] parent: impl IntoIterator, From 6aa019376ff3dc3f3b2ccecb8d2aaad5d274eba8 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 24 Jun 2024 14:28:22 +0100 Subject: [PATCH 12/13] infer_three_generations: Check expected Hugr by backing up and mutating --- hugr-core/src/hugr.rs | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/hugr-core/src/hugr.rs b/hugr-core/src/hugr.rs index 60ce7eab2..d58a48a50 100644 --- a/hugr-core/src/hugr.rs +++ b/hugr-core/src/hugr.rs @@ -295,6 +295,7 @@ pub enum HugrError { #[cfg(test)] mod test { + use super::internal::HugrMutInternals; #[cfg(feature = "extension_inference")] use super::ValidationError; use super::{ExtensionError, Hugr, HugrMut, HugrView, Node}; @@ -455,19 +456,22 @@ mod test { ); add_inliftout(&mut h, p, ty.clone()); assert!(h.validate_extensions().is_err()); + let backup = h.clone(); let inf_res = h.infer_extensions(true); if success { assert!(inf_res.is_ok()); let expected_p = ops::Case { signature: FunctionType::new_endo(ty).with_extension_delta(result.clone()), }; - assert!(h.get_optype(p) == &expected_p.into()); + let mut expected = backup; + expected.replace_op(p, expected_p).unwrap(); let expected_gp = ops::Conditional { extension_delta: result, ..root_ty }; - assert!(h.root_type() == &expected_gp.into()) - // rest should be unchanged... + expected.replace_op(h.root(), expected_gp).unwrap(); + + assert_eq!(h, expected); } else { assert_eq!( inf_res, From 4c5bf9cd3a5b540ed2f97378699e32c2631c317d Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 24 Jun 2024 17:06:30 +0100 Subject: [PATCH 13/13] Comment typo --- hugr-core/src/hugr.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hugr-core/src/hugr.rs b/hugr-core/src/hugr.rs index d58a48a50..978b575c5 100644 --- a/hugr-core/src/hugr.rs +++ b/hugr-core/src/hugr.rs @@ -419,7 +419,7 @@ mod test { #[case([XA], [TO_BE_INFERRED], true, [XA])] // Success: delta inferred for parent is subset of grandparent #[case([XA, XB], [TO_BE_INFERRED], true, [XA])] - // Base case failure: infers [XA] for parent but grandparent has disjoint et + // Base case failure: infers [XA] for parent but grandparent has disjoint set #[case([XB], [TO_BE_INFERRED], false, [XA])] // Failure: as previous, but extra "lower bound" on parent that has no effect #[case([XB], [XA, TO_BE_INFERRED], false, [XA])]