Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: Infer extension deltas for Case, Cfg, Conditional, DataflowBlock, Dfg, TailLoop #1195

Merged
merged 13 commits into from
Jun 24, 2024
Merged
7 changes: 7 additions & 0 deletions hugr-core/src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,13 @@ pub enum ExtensionBuildError {
#[derive(Clone, Debug, Default, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct ExtensionSet(BTreeSet<ExtensionId>);

/// A special ExtensionId which indicates that the delta of a non-Function
/// container node should be computed by extension inference. See [`infer_extensions`]
/// which lists the container nodes to which this can be applied.
///
/// [`infer_extensions`]: crate::hugr::Hugr::infer_extensions
pub const TO_BE_INFERRED: ExtensionId = ExtensionId::new_unchecked(".TO_BE_INFERRED");

impl ExtensionSet {
/// Creates a new empty extension set.
pub const fn new() -> Self {
Expand Down
260 changes: 252 additions & 8 deletions hugr-core/src/hugr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ use std::collections::VecDeque;
use std::iter;

pub(crate) use self::hugrmut::HugrMut;
use self::validate::ExtensionError;
pub use self::validate::ValidationError;

pub use ident::{IdentList, InvalidIdentifier};
Expand All @@ -25,9 +24,9 @@ use thiserror::Error;

pub use self::views::{HugrView, RootTagged};
use crate::core::NodeIndex;
use crate::extension::ExtensionRegistry;
use crate::extension::{ExtensionRegistry, ExtensionSet, TO_BE_INFERRED};
use crate::ops::custom::resolve_extension_ops;
use crate::ops::OpTag;
use crate::ops::{OpTag, OpTrait};
pub use crate::ops::{OpType, DEFAULT_OPTYPE};
use crate::{Direction, Node};

Expand Down Expand Up @@ -92,15 +91,80 @@ impl Hugr {
self.validate_no_extensions(extension_registry)?;
#[cfg(feature = "extension_inference")]
{
self.infer_extensions()?;
self.infer_extensions(false)?;
self.validate_extensions()?;
}
Ok(())
}

/// Leaving this here as in the future we plan for it to infer deltas
/// of container nodes e.g. [OpType::DFG]. For the moment it does nothing.
pub fn infer_extensions(&mut self) -> Result<(), ExtensionError> {
/// Infers an extension-delta for any non-function container node
/// whose current [extension_delta] contains [TO_BE_INFERRED]. The inferred delta
/// will be the smallest delta compatible with its children and that includes any
/// other [ExtensionId]s in the current delta.
///
/// If `remove` is true, for such container nodes *without* [TO_BE_INFERRED],
/// ExtensionIds are removed from the delta if they are *not* used by any child node.
///
/// The non-function container nodes are:
/// [Case], [CFG], [Conditional], [DataflowBlock], [DFG], [TailLoop]
///
/// [Case]: crate::ops::Case
/// [CFG]: crate::ops::CFG
/// [Conditional]: crate::ops::Conditional
/// [DataflowBlock]: crate::ops::DataflowBlock
/// [DFG]: crate::ops::DFG
/// [TailLoop]: crate::ops::TailLoop
/// [extension_delta]: crate::ops::OpType::extension_delta
/// [ExtensionId]: crate::extension::ExtensionId
pub fn infer_extensions(&mut self, remove: bool) -> Result<(), ExtensionError> {
fn delta_mut(optype: &mut OpType) -> Option<&mut ExtensionSet> {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess this (+tests) might be better in a separate file, but this is 200 lines, a lot shorter than the old infer.rs

match optype {
OpType::DFG(dfg) => Some(&mut dfg.signature.extension_reqs),
OpType::DataflowBlock(dfb) => Some(&mut dfb.extension_delta),
OpType::TailLoop(tl) => Some(&mut tl.extension_delta),
OpType::CFG(cfg) => Some(&mut cfg.signature.extension_reqs),
OpType::Conditional(c) => Some(&mut c.extension_delta),
OpType::Case(c) => Some(&mut c.signature.extension_reqs),
//OpType::Lift(_) // Not ATM: only a single element, and we expect Lift to be removed
//OpType::FuncDefn(_) // Not at present due to the possibility of recursion
_ => None,
}
}
fn infer(h: &mut Hugr, node: Node, remove: bool) -> Result<ExtensionSet, ExtensionError> {
let mut child_sets = h
.children(node)
.collect::<Vec<_>>() // Avoid borrowing h over recursive call
.into_iter()
.map(|ch| Ok((ch, infer(h, ch, remove)?)))
.collect::<Result<Vec<_>, _>>()?;

let Some(es) = delta_mut(h.op_types.get_mut(node.pg_index())) else {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we just add a get_optype_mut(Node) to HugrMut?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe, but

  1. This raises tricky questions of what to do for a Node that is not in the Hugr. What'll actually happen here is we'll fall back to DenseUnmanagedMap's behaviour of expanding to contain an OpType for that Node (index), which I suspect is not what we'd want on a (quasi-?)public method...so at the least, it's a bit more involved that it sounds
  2. If we do that, we should probably remove replace_op as having full &mut access is clearly superior! (replace_op(idx, op) is get_mut(idx) = op). I guess that could follow, but obviously it's a breaking change, and I think better to do that all in another PR.

I think this is fine for inside crate::hugr. (h is just self but we're in an inner-fn)

return Ok(h.get_optype(node).extension_delta());
};
if es.contains(&TO_BE_INFERRED) {
// Do not remove anything from current delta - any other elements are a lower bound
child_sets.push((node, es.clone())); // "child_sets" now misnamed but we discard fst
} else if remove {
child_sets.iter().try_for_each(|(ch, ch_exts)| {
if !es.is_superset(ch_exts) {
return Err(ExtensionError {
parent: node,
parent_extensions: es.clone(),
child: *ch,
child_extensions: ch_exts.clone(),
});
}
Ok(())
})?;
} else {
return Ok(es.clone()); // Can't neither add nor remove, so nothing to do
}
let merged = ExtensionSet::union_over(child_sets.into_iter().map(|(_, e)| e));
*es = ExtensionSet::singleton(&TO_BE_INFERRED).missing_from(&merged);

Ok(es.clone())
}
infer(self, self.root(), remove)?;
Ok(())
}
}
Expand Down Expand Up @@ -203,6 +267,16 @@ impl Hugr {
}
}

#[derive(Debug, Clone, PartialEq, Error)]
#[error("Parent node {parent} has extensions {parent_extensions} that are too restrictive for child node {child}, they must include child extensions {child_extensions}")]
/// An error in the extension deltas.
pub struct ExtensionError {
parent: Node,
parent_extensions: ExtensionSet,
child: Node,
child_extensions: ExtensionSet,
}

/// Errors that can occur while manipulating a Hugr.
///
/// TODO: Better descriptions, not just re-exporting portgraph errors.
Expand All @@ -221,7 +295,14 @@ pub enum HugrError {

#[cfg(test)]
mod test {
use super::{Hugr, HugrView};
use super::internal::HugrMutInternals;
#[cfg(feature = "extension_inference")]
use super::ValidationError;
use super::{ExtensionError, Hugr, HugrMut, HugrView, Node};
use crate::extension::{ExtensionId, ExtensionSet, EMPTY_REG, TO_BE_INFERRED};
use crate::types::{FunctionType, Type};
use crate::{const_extension_ids, ops, type_row};
use rstest::rstest;

#[test]
fn impls_send_and_sync() {
Expand All @@ -240,4 +321,167 @@ mod test {
let hugr = simple_dfg_hugr();
assert_matches!(hugr.get_io(hugr.root()), Some(_));
}

const_extension_ids! {
const XA: ExtensionId = "EXT_A";
const XB: ExtensionId = "EXT_B";
}

#[rstest]
#[case([], XA.into())]
#[case([XA], XA.into())]
#[case([XB], ExtensionSet::from_iter([XA, XB]))]

fn infer_single_delta(
#[case] parent: impl IntoIterator<Item = ExtensionId>,
#[values(true, false)] remove: bool, // makes no difference when inferring
#[case] result: ExtensionSet,
) {
let parent = ExtensionSet::from_iter(parent).union(TO_BE_INFERRED.into());
let (mut h, _) = build_ext_dfg(parent);
h.infer_extensions(remove).unwrap();
assert_eq!(h, build_ext_dfg(result).0);
}

#[test]
fn infer_removes_from_delta() {
let parent = ExtensionSet::from_iter([XA, XB]);
let mut h = build_ext_dfg(parent.clone()).0;
let backup = h.clone();
h.infer_extensions(false).unwrap();
assert_eq!(h, backup); // did nothing
h.infer_extensions(true).unwrap();
assert_eq!(h, build_ext_dfg(XA.into()).0);
}

#[test]
fn infer_bad_remove() {
let (mut h, mid) = build_ext_dfg(XB.into());
let backup = h.clone();
h.infer_extensions(false).unwrap();
assert_eq!(h, backup); // did nothing
let val_res = h.validate(&EMPTY_REG);
let expected_err = ExtensionError {
parent: h.root(),
parent_extensions: XB.into(),
child: mid,
child_extensions: XA.into(),
};
#[cfg(feature = "extension_inference")]
assert_eq!(
val_res,
Err(ValidationError::ExtensionError(expected_err.clone()))
);
#[cfg(not(feature = "extension_inference"))]
assert!(val_res.is_ok());

let inf_res = h.infer_extensions(true);
assert_eq!(inf_res, Err(expected_err));
}

fn build_ext_dfg(parent: ExtensionSet) -> (Hugr, Node) {
let ty = Type::new_function(FunctionType::new_endo(type_row![]));
let mut h = Hugr::new(ops::DFG {
signature: FunctionType::new_endo(ty.clone()).with_extension_delta(parent.clone()),
});
let root = h.root();
let mid = add_inliftout(&mut h, root, ty);
(h, mid)
}

fn add_inliftout(h: &mut Hugr, p: Node, ty: Type) -> Node {
let inp = h.add_node_with_parent(
p,
ops::Input {
types: ty.clone().into(),
},
);
let out = h.add_node_with_parent(
p,
ops::Output {
types: ty.clone().into(),
},
);
let mid = h.add_node_with_parent(
p,
ops::Lift {
type_row: ty.into(),
new_extension: XA,
},
);
h.connect(inp, 0, mid, 0);
h.connect(mid, 0, out, 0);
mid
}

#[rstest]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think these cases would benefit from comments next to each saying briefly why the result is expected

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair, done

// Base case success: delta inferred for parent equals grandparent.
#[case([XA], [TO_BE_INFERRED], true, [XA])]
// Success: delta inferred for parent is subset of grandparent
#[case([XA, XB], [TO_BE_INFERRED], true, [XA])]
// Base case failure: infers [XA] for parent but grandparent has disjoint set
#[case([XB], [TO_BE_INFERRED], false, [XA])]
// Failure: as previous, but extra "lower bound" on parent that has no effect
#[case([XB], [XA, TO_BE_INFERRED], false, [XA])]
// Failure: grandparent ok wrt. child but parent specifies extra lower-bound XB
#[case([XA], [XB, TO_BE_INFERRED], false, [XA, XB])]
// Success: grandparent includes extra XB required for parent's "lower bound"
#[case([XA, XB], [XB, TO_BE_INFERRED], true, [XA, XB])]
// Success: grandparent is also inferred so can include 'extra' XB from parent
#[case([TO_BE_INFERRED], [TO_BE_INFERRED, XB], true, [XA, XB])]
// No inference: extraneous XB in parent is removed so all become [XA].
#[case([XA], [XA, XB], true, [XA])]
fn infer_three_generations(
#[case] grandparent: impl IntoIterator<Item = ExtensionId>,
#[case] parent: impl IntoIterator<Item = ExtensionId>,
#[case] success: bool,
#[case] result: impl IntoIterator<Item = ExtensionId>,
) {
let ty = Type::new_function(FunctionType::new_endo(type_row![]));
let grandparent = ExtensionSet::from_iter(grandparent);
let result = ExtensionSet::from_iter(result);
let root_ty = ops::Conditional {
sum_rows: vec![type_row![]],
other_inputs: ty.clone().into(),
outputs: ty.clone().into(),
extension_delta: grandparent.clone(),
};
let mut h = Hugr::new(root_ty.clone());
let p = h.add_node_with_parent(
h.root(),
ops::Case {
signature: FunctionType::new_endo(ty.clone())
.with_extension_delta(ExtensionSet::from_iter(parent)),
},
);
add_inliftout(&mut h, p, ty.clone());
assert!(h.validate_extensions().is_err());
let backup = h.clone();
let inf_res = h.infer_extensions(true);
if success {
assert!(inf_res.is_ok());
let expected_p = ops::Case {
signature: FunctionType::new_endo(ty).with_extension_delta(result.clone()),
};
let mut expected = backup;
expected.replace_op(p, expected_p).unwrap();
let expected_gp = ops::Conditional {
extension_delta: result,
..root_ty
};
expected.replace_op(h.root(), expected_gp).unwrap();

assert_eq!(h, expected);
} else {
assert_eq!(
inf_res,
Err(ExtensionError {
parent: h.root(),
parent_extensions: grandparent,
child: p,
child_extensions: result
})
);
}
}
}
19 changes: 8 additions & 11 deletions hugr-core/src/hugr/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use petgraph::visit::{Topo, Walker};
use portgraph::{LinkView, PortView};
use thiserror::Error;

use crate::extension::{ExtensionRegistry, ExtensionSet, SignatureError};
use crate::extension::{ExtensionRegistry, SignatureError, TO_BE_INFERRED};

use crate::ops::custom::{resolve_opaque_op, CustomOp, CustomOpError};
use crate::ops::validate::{ChildrenEdgeData, ChildrenValidationError, EdgeValidationError};
Expand All @@ -19,6 +19,7 @@ use crate::types::{EdgeKind, FunctionType};
use crate::{Direction, Hugr, Node, Port};

use super::views::{HierarchyView, HugrView, SiblingGraph};
use super::ExtensionError;

/// Structure keeping track of pre-computed information used in the validation
/// process.
Expand Down Expand Up @@ -59,6 +60,9 @@ impl Hugr {
pub fn validate_extensions(&self) -> Result<(), ValidationError> {
for parent in self.nodes() {
let parent_op = self.get_optype(parent);
if parent_op.extension_delta().contains(&TO_BE_INFERRED) {
return Err(ValidationError::ExtensionsNotInferred { node: parent });
}
let parent_extensions = match parent_op.inner_function_type() {
Some(FunctionType { extension_reqs, .. }) => extension_reqs,
None => match parent_op.tag() {
Expand Down Expand Up @@ -743,6 +747,9 @@ pub enum ValidationError {
/// There are errors in the extension deltas.
#[error(transparent)]
ExtensionError(#[from] ExtensionError),
/// A node claims to still be awaiting extension inference. Perhaps it is not acted upon by inference.
#[error("Node {node:?} needs a concrete ExtensionSet - inference will provide this for Case/CFG/Conditional/DataflowBlock/DFG/TailLoop only")]
ExtensionsNotInferred { node: Node },
/// Error in a node signature
#[error("Error in signature of node {node:?}: {cause}")]
SignatureError { node: Node, cause: SignatureError },
Expand Down Expand Up @@ -815,15 +822,5 @@ pub enum InterGraphEdgeError {
},
}

#[derive(Debug, Clone, PartialEq, Error)]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

breaking so needs mentioning in PR commit name + description

#[error("Parent node {parent} has extensions {parent_extensions} that are too restrictive for child node {child}, they must include child extensions {child_extensions}")]
/// An error in the extension deltas.
pub struct ExtensionError {
parent: Node,
parent_extensions: ExtensionSet,
child: Node,
child_extensions: ExtensionSet,
}

#[cfg(test)]
pub(crate) mod test;
Loading