Skip to content

Commit

Permalink
feat!: Infer extension deltas for Case, Cfg, Conditional, DataflowBlo…
Browse files Browse the repository at this point in the history
…ck, Dfg, TailLoop (#1195)

closes #640 
* Add an ExtensionId TO_BE_INFERRED
* On Case, Cfg, Conditional, DataflowBlock, Dfg, TailLoop,
`Hugr::infer_extensions` replaces TO_BE_INFERRED with the smallest set
of ExtensionIds that's correct wrt. its child nodes (i.e., the union of
the child node deltas). If there are other ExtensionIds alongside
TO_BE_INFERRED these will be kept (allowing a "lower bound" to be
specified for individual nodes)
* `Hugr::infer_extensions` also takes a `remove: bool` which, if true,
modifies deltas of the same OpTypes that do *not* have TO_BE_INFERRED,
by removing ExtensionIds. (Thus, non-inferred deltas act as upper
bounds).

BREAKING CHANGE: ExtensionError moved from hugr::validate to hugr;
Hugr::infer_extensions takes bool parameter
  • Loading branch information
acl-cqc authored Jun 24, 2024
1 parent db09193 commit ade8710
Show file tree
Hide file tree
Showing 3 changed files with 267 additions and 19 deletions.
7 changes: 7 additions & 0 deletions hugr-core/src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,13 @@ pub enum ExtensionBuildError {
#[derive(Clone, Debug, Default, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct ExtensionSet(BTreeSet<ExtensionId>);

/// A special ExtensionId which indicates that the delta of a non-Function
/// container node should be computed by extension inference. See [`infer_extensions`]
/// which lists the container nodes to which this can be applied.
///
/// [`infer_extensions`]: crate::hugr::Hugr::infer_extensions
pub const TO_BE_INFERRED: ExtensionId = ExtensionId::new_unchecked(".TO_BE_INFERRED");

impl ExtensionSet {
/// Creates a new empty extension set.
pub const fn new() -> Self {
Expand Down
260 changes: 252 additions & 8 deletions hugr-core/src/hugr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ use std::collections::VecDeque;
use std::iter;

pub(crate) use self::hugrmut::HugrMut;
use self::validate::ExtensionError;
pub use self::validate::ValidationError;

pub use ident::{IdentList, InvalidIdentifier};
Expand All @@ -25,9 +24,9 @@ use thiserror::Error;

pub use self::views::{HugrView, RootTagged};
use crate::core::NodeIndex;
use crate::extension::ExtensionRegistry;
use crate::extension::{ExtensionRegistry, ExtensionSet, TO_BE_INFERRED};
use crate::ops::custom::resolve_extension_ops;
use crate::ops::OpTag;
use crate::ops::{OpTag, OpTrait};
pub use crate::ops::{OpType, DEFAULT_OPTYPE};
use crate::{Direction, Node};

Expand Down Expand Up @@ -92,15 +91,80 @@ impl Hugr {
self.validate_no_extensions(extension_registry)?;
#[cfg(feature = "extension_inference")]
{
self.infer_extensions()?;
self.infer_extensions(false)?;
self.validate_extensions()?;
}
Ok(())
}

/// Leaving this here as in the future we plan for it to infer deltas
/// of container nodes e.g. [OpType::DFG]. For the moment it does nothing.
pub fn infer_extensions(&mut self) -> Result<(), ExtensionError> {
/// Infers an extension-delta for any non-function container node
/// whose current [extension_delta] contains [TO_BE_INFERRED]. The inferred delta
/// will be the smallest delta compatible with its children and that includes any
/// other [ExtensionId]s in the current delta.
///
/// If `remove` is true, for such container nodes *without* [TO_BE_INFERRED],
/// ExtensionIds are removed from the delta if they are *not* used by any child node.
///
/// The non-function container nodes are:
/// [Case], [CFG], [Conditional], [DataflowBlock], [DFG], [TailLoop]
///
/// [Case]: crate::ops::Case
/// [CFG]: crate::ops::CFG
/// [Conditional]: crate::ops::Conditional
/// [DataflowBlock]: crate::ops::DataflowBlock
/// [DFG]: crate::ops::DFG
/// [TailLoop]: crate::ops::TailLoop
/// [extension_delta]: crate::ops::OpType::extension_delta
/// [ExtensionId]: crate::extension::ExtensionId
pub fn infer_extensions(&mut self, remove: bool) -> Result<(), ExtensionError> {
fn delta_mut(optype: &mut OpType) -> Option<&mut ExtensionSet> {
match optype {
OpType::DFG(dfg) => Some(&mut dfg.signature.extension_reqs),
OpType::DataflowBlock(dfb) => Some(&mut dfb.extension_delta),
OpType::TailLoop(tl) => Some(&mut tl.extension_delta),
OpType::CFG(cfg) => Some(&mut cfg.signature.extension_reqs),
OpType::Conditional(c) => Some(&mut c.extension_delta),
OpType::Case(c) => Some(&mut c.signature.extension_reqs),
//OpType::Lift(_) // Not ATM: only a single element, and we expect Lift to be removed
//OpType::FuncDefn(_) // Not at present due to the possibility of recursion
_ => None,
}
}
fn infer(h: &mut Hugr, node: Node, remove: bool) -> Result<ExtensionSet, ExtensionError> {
let mut child_sets = h
.children(node)
.collect::<Vec<_>>() // Avoid borrowing h over recursive call
.into_iter()
.map(|ch| Ok((ch, infer(h, ch, remove)?)))
.collect::<Result<Vec<_>, _>>()?;

let Some(es) = delta_mut(h.op_types.get_mut(node.pg_index())) else {
return Ok(h.get_optype(node).extension_delta());
};
if es.contains(&TO_BE_INFERRED) {
// Do not remove anything from current delta - any other elements are a lower bound
child_sets.push((node, es.clone())); // "child_sets" now misnamed but we discard fst
} else if remove {
child_sets.iter().try_for_each(|(ch, ch_exts)| {
if !es.is_superset(ch_exts) {
return Err(ExtensionError {
parent: node,
parent_extensions: es.clone(),
child: *ch,
child_extensions: ch_exts.clone(),
});
}
Ok(())
})?;
} else {
return Ok(es.clone()); // Can't neither add nor remove, so nothing to do
}
let merged = ExtensionSet::union_over(child_sets.into_iter().map(|(_, e)| e));
*es = ExtensionSet::singleton(&TO_BE_INFERRED).missing_from(&merged);

Ok(es.clone())
}
infer(self, self.root(), remove)?;
Ok(())
}
}
Expand Down Expand Up @@ -203,6 +267,16 @@ impl Hugr {
}
}

#[derive(Debug, Clone, PartialEq, Error)]
#[error("Parent node {parent} has extensions {parent_extensions} that are too restrictive for child node {child}, they must include child extensions {child_extensions}")]
/// An error in the extension deltas.
pub struct ExtensionError {
parent: Node,
parent_extensions: ExtensionSet,
child: Node,
child_extensions: ExtensionSet,
}

/// Errors that can occur while manipulating a Hugr.
///
/// TODO: Better descriptions, not just re-exporting portgraph errors.
Expand All @@ -221,7 +295,14 @@ pub enum HugrError {

#[cfg(test)]
mod test {
use super::{Hugr, HugrView};
use super::internal::HugrMutInternals;
#[cfg(feature = "extension_inference")]
use super::ValidationError;
use super::{ExtensionError, Hugr, HugrMut, HugrView, Node};
use crate::extension::{ExtensionId, ExtensionSet, EMPTY_REG, TO_BE_INFERRED};
use crate::types::{FunctionType, Type};
use crate::{const_extension_ids, ops, type_row};
use rstest::rstest;

#[test]
fn impls_send_and_sync() {
Expand All @@ -240,4 +321,167 @@ mod test {
let hugr = simple_dfg_hugr();
assert_matches!(hugr.get_io(hugr.root()), Some(_));
}

const_extension_ids! {
const XA: ExtensionId = "EXT_A";
const XB: ExtensionId = "EXT_B";
}

#[rstest]
#[case([], XA.into())]
#[case([XA], XA.into())]
#[case([XB], ExtensionSet::from_iter([XA, XB]))]

fn infer_single_delta(
#[case] parent: impl IntoIterator<Item = ExtensionId>,
#[values(true, false)] remove: bool, // makes no difference when inferring
#[case] result: ExtensionSet,
) {
let parent = ExtensionSet::from_iter(parent).union(TO_BE_INFERRED.into());
let (mut h, _) = build_ext_dfg(parent);
h.infer_extensions(remove).unwrap();
assert_eq!(h, build_ext_dfg(result).0);
}

#[test]
fn infer_removes_from_delta() {
let parent = ExtensionSet::from_iter([XA, XB]);
let mut h = build_ext_dfg(parent.clone()).0;
let backup = h.clone();
h.infer_extensions(false).unwrap();
assert_eq!(h, backup); // did nothing
h.infer_extensions(true).unwrap();
assert_eq!(h, build_ext_dfg(XA.into()).0);
}

#[test]
fn infer_bad_remove() {
let (mut h, mid) = build_ext_dfg(XB.into());
let backup = h.clone();
h.infer_extensions(false).unwrap();
assert_eq!(h, backup); // did nothing
let val_res = h.validate(&EMPTY_REG);
let expected_err = ExtensionError {
parent: h.root(),
parent_extensions: XB.into(),
child: mid,
child_extensions: XA.into(),
};
#[cfg(feature = "extension_inference")]
assert_eq!(
val_res,
Err(ValidationError::ExtensionError(expected_err.clone()))
);
#[cfg(not(feature = "extension_inference"))]
assert!(val_res.is_ok());

let inf_res = h.infer_extensions(true);
assert_eq!(inf_res, Err(expected_err));
}

fn build_ext_dfg(parent: ExtensionSet) -> (Hugr, Node) {
let ty = Type::new_function(FunctionType::new_endo(type_row![]));
let mut h = Hugr::new(ops::DFG {
signature: FunctionType::new_endo(ty.clone()).with_extension_delta(parent.clone()),
});
let root = h.root();
let mid = add_inliftout(&mut h, root, ty);
(h, mid)
}

fn add_inliftout(h: &mut Hugr, p: Node, ty: Type) -> Node {
let inp = h.add_node_with_parent(
p,
ops::Input {
types: ty.clone().into(),
},
);
let out = h.add_node_with_parent(
p,
ops::Output {
types: ty.clone().into(),
},
);
let mid = h.add_node_with_parent(
p,
ops::Lift {
type_row: ty.into(),
new_extension: XA,
},
);
h.connect(inp, 0, mid, 0);
h.connect(mid, 0, out, 0);
mid
}

#[rstest]
// Base case success: delta inferred for parent equals grandparent.
#[case([XA], [TO_BE_INFERRED], true, [XA])]
// Success: delta inferred for parent is subset of grandparent
#[case([XA, XB], [TO_BE_INFERRED], true, [XA])]
// Base case failure: infers [XA] for parent but grandparent has disjoint set
#[case([XB], [TO_BE_INFERRED], false, [XA])]
// Failure: as previous, but extra "lower bound" on parent that has no effect
#[case([XB], [XA, TO_BE_INFERRED], false, [XA])]
// Failure: grandparent ok wrt. child but parent specifies extra lower-bound XB
#[case([XA], [XB, TO_BE_INFERRED], false, [XA, XB])]
// Success: grandparent includes extra XB required for parent's "lower bound"
#[case([XA, XB], [XB, TO_BE_INFERRED], true, [XA, XB])]
// Success: grandparent is also inferred so can include 'extra' XB from parent
#[case([TO_BE_INFERRED], [TO_BE_INFERRED, XB], true, [XA, XB])]
// No inference: extraneous XB in parent is removed so all become [XA].
#[case([XA], [XA, XB], true, [XA])]
fn infer_three_generations(
#[case] grandparent: impl IntoIterator<Item = ExtensionId>,
#[case] parent: impl IntoIterator<Item = ExtensionId>,
#[case] success: bool,
#[case] result: impl IntoIterator<Item = ExtensionId>,
) {
let ty = Type::new_function(FunctionType::new_endo(type_row![]));
let grandparent = ExtensionSet::from_iter(grandparent);
let result = ExtensionSet::from_iter(result);
let root_ty = ops::Conditional {
sum_rows: vec![type_row![]],
other_inputs: ty.clone().into(),
outputs: ty.clone().into(),
extension_delta: grandparent.clone(),
};
let mut h = Hugr::new(root_ty.clone());
let p = h.add_node_with_parent(
h.root(),
ops::Case {
signature: FunctionType::new_endo(ty.clone())
.with_extension_delta(ExtensionSet::from_iter(parent)),
},
);
add_inliftout(&mut h, p, ty.clone());
assert!(h.validate_extensions().is_err());
let backup = h.clone();
let inf_res = h.infer_extensions(true);
if success {
assert!(inf_res.is_ok());
let expected_p = ops::Case {
signature: FunctionType::new_endo(ty).with_extension_delta(result.clone()),
};
let mut expected = backup;
expected.replace_op(p, expected_p).unwrap();
let expected_gp = ops::Conditional {
extension_delta: result,
..root_ty
};
expected.replace_op(h.root(), expected_gp).unwrap();

assert_eq!(h, expected);
} else {
assert_eq!(
inf_res,
Err(ExtensionError {
parent: h.root(),
parent_extensions: grandparent,
child: p,
child_extensions: result
})
);
}
}
}
19 changes: 8 additions & 11 deletions hugr-core/src/hugr/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use petgraph::visit::{Topo, Walker};
use portgraph::{LinkView, PortView};
use thiserror::Error;

use crate::extension::{ExtensionRegistry, ExtensionSet, SignatureError};
use crate::extension::{ExtensionRegistry, SignatureError, TO_BE_INFERRED};

use crate::ops::custom::{resolve_opaque_op, CustomOp, CustomOpError};
use crate::ops::validate::{ChildrenEdgeData, ChildrenValidationError, EdgeValidationError};
Expand All @@ -19,6 +19,7 @@ use crate::types::{EdgeKind, FunctionType};
use crate::{Direction, Hugr, Node, Port};

use super::views::{HierarchyView, HugrView, SiblingGraph};
use super::ExtensionError;

/// Structure keeping track of pre-computed information used in the validation
/// process.
Expand Down Expand Up @@ -59,6 +60,9 @@ impl Hugr {
pub fn validate_extensions(&self) -> Result<(), ValidationError> {
for parent in self.nodes() {
let parent_op = self.get_optype(parent);
if parent_op.extension_delta().contains(&TO_BE_INFERRED) {
return Err(ValidationError::ExtensionsNotInferred { node: parent });
}
let parent_extensions = match parent_op.inner_function_type() {
Some(FunctionType { extension_reqs, .. }) => extension_reqs,
None => match parent_op.tag() {
Expand Down Expand Up @@ -743,6 +747,9 @@ pub enum ValidationError {
/// There are errors in the extension deltas.
#[error(transparent)]
ExtensionError(#[from] ExtensionError),
/// A node claims to still be awaiting extension inference. Perhaps it is not acted upon by inference.
#[error("Node {node:?} needs a concrete ExtensionSet - inference will provide this for Case/CFG/Conditional/DataflowBlock/DFG/TailLoop only")]
ExtensionsNotInferred { node: Node },
/// Error in a node signature
#[error("Error in signature of node {node:?}: {cause}")]
SignatureError { node: Node, cause: SignatureError },
Expand Down Expand Up @@ -815,15 +822,5 @@ pub enum InterGraphEdgeError {
},
}

#[derive(Debug, Clone, PartialEq, Error)]
#[error("Parent node {parent} has extensions {parent_extensions} that are too restrictive for child node {child}, they must include child extensions {child_extensions}")]
/// An error in the extension deltas.
pub struct ExtensionError {
parent: Node,
parent_extensions: ExtensionSet,
child: Node,
child_extensions: ExtensionSet,
}

#[cfg(test)]
pub(crate) mod test;

0 comments on commit ade8710

Please sign in to comment.