From 8bec8e93bcaa8917b00098837269da60e3312d6c Mon Sep 17 00:00:00 2001
From: Alan Lawrence
Date: Tue, 11 Jun 2024 18:59:48 +0100
Subject: [PATCH] feat!: Validate Extensions using hierarchy, ignore input_extensions, RIP inference (#1142)

This is probably phase 1 of 4, the later steps being:
* Remove `input_extensions` and `NodeType`
* Make every operation require its own extension (#388)
* Infer deltas for parent nodes (#640) - this may be disruptive, and/or take some subtlety, as currently every FunctionType stores an ExtensionSet, not an Option thereof
* Finally, remove the feature flag

So, this PR updates validation to ignore the input_extensions, and removes the inference algorithm that would set them. There were a few complications:
* A lot of tests create a DFGBuilder with an *empty* delta, then put things in it that have extension requirements. The inference algorithm figures out that this is all OK if it can just put that extension-requirement on the `input_extensions` of the DFG node (root) itself. So this might be a case for keeping `input_extensions` - they do reduce the need for #640 a bit.
* We can see roughly how painful life is (without delta-inference or input-extensions) by looking at the various extra extension-deltas I've had to specify here.
* I think that, given we are under the feature flag at the moment, we are OK to continue; however, the need for #640 is now somewhat increased, so discussion there is strongly encouraged, please! :)
* `TailLoop` turned out not to have an extension-delta, where it clearly needs one (just like DataflowBlock and others). Also there was an implementation of `OpParent` for it, whereas in fact we needed `DataflowParent`, as that would also have got us the `ValidateOp` for free. All of this is in 08bb5a59caeba3c6d30a970043a952e8291b6dd4.
* There was another bug in DataflowBlock::inner_signature, and some others.

....That is to say, in some ways this validation scheme is *stricter* than what we had: we lose the slight flexibility that `input_extensions` gave us in mis-specifying deltas, and the scheme's simplicity makes it much harder to accidentally omit cases from validation....

I've also kept `Hugr::infer_extensions` around as a no-op rather than removing it and later bringing it back (in #640). Maybe we should remove it, but that would be a breaking change.... OTOH I've removed `validate_with_extension_closure`, and in theory we could have that (a closure containing deltas for nested DFGs) too...
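For downstream code, the most visible change is that `ops::TailLoop` gains an `extension_delta` field and the tail-loop builders take an extra `ExtensionSet` argument. A minimal migration sketch follows, adapted from the updated `basic_loop` test further down in this patch; the import paths and the use of the prelude delta are illustrative assumptions (the delta mirrors that test, whose loop body loads a prelude constant), not new API beyond what the diff shows:

    use crate::builder::{BuildError, TailLoopBuilder}; // paths assumed; adjust to your crate
    use crate::extension::prelude::{PRELUDE_ID, USIZE_T};

    fn migrated_loop() -> Result<(), BuildError> {
        // Before: TailLoopBuilder::new(just_inputs, inputs_outputs, just_outputs)?
        // After: the loop's extension delta is a fourth, explicit argument.
        let _loop_b = TailLoopBuilder::new(
            vec![],            // just_inputs
            vec![USIZE_T],     // inputs_outputs ("rest")
            vec![USIZE_T],     // just_outputs
            PRELUDE_ID.into(), // extension delta of the loop body
        )?;
        // ...build the loop body and finish as before.
        Ok(())
    }

`Dataflow::tail_loop_builder` gains the same trailing `extension_delta` argument, as the build_traits.rs hunk below shows.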
BREAKING CHANGE: TailLoop node and associated builder functions now require specifying an ExtensionSet; extension/validate.rs deleted; some changes to Hugrs validated/rejected when the `extension_inference` feature flag is turned on --- hugr-core/src/builder/build_traits.rs | 2 + hugr-core/src/builder/tail_loop.rs | 17 +- hugr-core/src/extension.rs | 7 - hugr-core/src/extension/infer.rs | 733 ----------- hugr-core/src/extension/infer/test.rs | 1083 ----------------- hugr-core/src/extension/prelude.rs | 13 +- hugr-core/src/extension/validate.rs | 209 ---- hugr-core/src/hugr.rs | 84 +- hugr-core/src/hugr/serialize/test.rs | 9 +- hugr-core/src/hugr/validate.rs | 97 +- hugr-core/src/hugr/validate/test.rs | 581 +++++---- hugr-core/src/ops.rs | 1 - hugr-core/src/ops/constant.rs | 12 +- hugr-core/src/ops/controlflow.rs | 12 +- hugr-core/src/ops/validate.rs | 25 - hugr-passes/src/const_fold.rs | 6 +- hugr-passes/src/const_fold/test.rs | 168 +-- hugr-passes/src/merge_bbs.rs | 13 +- hugr-py/src/hugr/serialization/ops.py | 1 + .../schema/hugr_schema_strict_v1.json | 7 + specification/schema/hugr_schema_v1.json | 7 + .../schema/testing_hugr_schema_strict_v1.json | 7 + .../schema/testing_hugr_schema_v1.json | 7 + 23 files changed, 514 insertions(+), 2587 deletions(-) delete mode 100644 hugr-core/src/extension/infer.rs delete mode 100644 hugr-core/src/extension/infer/test.rs delete mode 100644 hugr-core/src/extension/validate.rs diff --git a/hugr-core/src/builder/build_traits.rs b/hugr-core/src/builder/build_traits.rs index 9703d65c3..9279a6ba5 100644 --- a/hugr-core/src/builder/build_traits.rs +++ b/hugr-core/src/builder/build_traits.rs @@ -429,6 +429,7 @@ pub trait Dataflow: Container { just_inputs: impl IntoIterator, inputs_outputs: impl IntoIterator, just_out_types: TypeRow, + extension_delta: ExtensionSet, ) -> Result, BuildError> { let (input_types, mut input_wires): (Vec, Vec) = just_inputs.into_iter().unzip(); @@ -440,6 +441,7 @@ pub trait Dataflow: Container { just_inputs: input_types.into(), just_outputs: just_out_types, rest: rest_types.into(), + extension_delta, }; // TODO: Make input extensions a parameter let (loop_node, _) = add_node_with_wires(self, tail_loop.clone(), input_wires)?; diff --git a/hugr-core/src/builder/tail_loop.rs b/hugr-core/src/builder/tail_loop.rs index 0c87d8393..2bee9bcfa 100644 --- a/hugr-core/src/builder/tail_loop.rs +++ b/hugr-core/src/builder/tail_loop.rs @@ -1,3 +1,4 @@ +use crate::extension::ExtensionSet; use crate::ops; use crate::hugr::{views::HugrView, NodeType}; @@ -74,11 +75,13 @@ impl TailLoopBuilder { just_inputs: impl Into, inputs_outputs: impl Into, just_outputs: impl Into, + extension_delta: ExtensionSet, ) -> Result { let tail_loop = ops::TailLoop { just_inputs: just_inputs.into(), just_outputs: just_outputs.into(), rest: inputs_outputs.into(), + extension_delta, }; // TODO: Allow input extensions to be specified let base = Hugr::new(NodeType::new_open(tail_loop.clone())); @@ -97,7 +100,6 @@ mod test { DataflowSubContainer, HugrBuilder, ModuleBuilder, }, extension::prelude::{ConstUsize, PRELUDE_ID, USIZE_T}, - extension::ExtensionSet, hugr::ValidationError, ops::Value, type_row, @@ -107,7 +109,8 @@ mod test { #[test] fn basic_loop() -> Result<(), BuildError> { let build_result: Result = { - let mut loop_b = TailLoopBuilder::new(vec![], vec![BIT], vec![USIZE_T])?; + let mut loop_b = + TailLoopBuilder::new(vec![], vec![BIT], vec![USIZE_T], PRELUDE_ID.into())?; let [i1] = loop_b.input_wires_arr(); let const_wire = 
loop_b.add_load_value(ConstUsize::new(1)); @@ -141,8 +144,12 @@ mod test { )? .outputs_arr(); let loop_id = { - let mut loop_b = - fbuild.tail_loop_builder(vec![(BIT, b1)], vec![], type_row![NAT])?; + let mut loop_b = fbuild.tail_loop_builder( + vec![(BIT, b1)], + vec![], + type_row![NAT], + PRELUDE_ID.into(), + )?; let signature = loop_b.loop_signature()?.clone(); let const_val = Value::true_val(); let const_wire = loop_b.add_load_const(Value::true_val()); @@ -161,7 +168,7 @@ mod test { ([type_row![], type_row![]], const_wire), vec![(BIT, b1)], output_row, - ExtensionSet::new(), + PRELUDE_ID.into(), )?; let mut branch_0 = conditional_b.case_builder(0)?; diff --git a/hugr-core/src/extension.rs b/hugr-core/src/extension.rs index 1ef9c50cd..deecf1cbb 100644 --- a/hugr-core/src/extension.rs +++ b/hugr-core/src/extension.rs @@ -19,12 +19,6 @@ use crate::types::type_param::{TypeArg, TypeArgError, TypeParam}; use crate::types::{check_typevar_decl, CustomType, Substitution, TypeBound, TypeName}; use crate::types::{FunctionType, TypeNameRef}; -#[allow(dead_code)] -mod infer; -#[cfg(feature = "extension_inference")] -pub use infer::infer_extensions; -pub use infer::{ExtensionSolution, InferExtensionError}; - mod op_def; pub use op_def::{ CustomSignatureFunc, CustomValidator, LowerFunc, OpDef, SignatureFromArgs, SignatureFunc, @@ -35,7 +29,6 @@ pub use type_def::{TypeDef, TypeDefBound}; mod const_fold; pub mod prelude; pub mod simple_op; -pub mod validate; pub use const_fold::{ConstFold, ConstFoldResult, Folder}; pub use prelude::{PRELUDE, PRELUDE_REGISTRY}; diff --git a/hugr-core/src/extension/infer.rs b/hugr-core/src/extension/infer.rs deleted file mode 100644 index bcb8608cc..000000000 --- a/hugr-core/src/extension/infer.rs +++ /dev/null @@ -1,733 +0,0 @@ -//! Inference for extension requirements on nodes of a hugr. -//! -//! Checks if the extensions requirements have a solution in terms of some -//! number of starting variables, and comes up with concrete solutions when -//! possible. -//! -//! Open extension variables can come from toplevel nodes: notionally "inputs" -//! to the graph where being wired up to a larger hugr would provide the -//! information needed to solve variables. When extension requirements of nodes -//! depend on these open variables, then the validation check for extensions -//! will succeed regardless of what the variable is instantiated to. - -use super::ExtensionSet; -use crate::{ - hugr::views::HugrView, - ops::{OpTag, OpTrait}, - types::EdgeKind, - Direction, Node, -}; - -use super::validate::ExtensionError; - -use petgraph::graph as pg; -use petgraph::{Directed, EdgeType, Undirected}; - -use std::collections::{HashMap, HashSet, VecDeque}; - -use thiserror::Error; - -/// A mapping from nodes on the hugr to extension requirement sets which have -/// been inferred for their inputs. -pub type ExtensionSolution = HashMap; - -/// Infer extensions for a hugr. This is the main API exposed by this module. 
-/// -/// Return all the solutions found for locations on the graph, these can be -/// passed to [`validate_with_extension_closure`] -/// -/// [`validate_with_extension_closure`]: crate::Hugr::validate_with_extension_closure -pub fn infer_extensions(hugr: &impl HugrView) -> Result { - let mut ctx = UnificationContext::new(hugr); - ctx.main_loop()?; - ctx.instantiate_variables(); - let all_results = ctx.main_loop()?; - let new_results = all_results - .into_iter() - .filter(|(n, _sol)| hugr.get_nodetype(*n).input_extensions().is_none()) - .collect(); - Ok(new_results) -} - -/// Metavariables don't need much -#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] -struct Meta(u32); - -#[derive(Clone, Debug, Eq, PartialEq, Hash)] -/// Things we know about metavariables -enum Constraint { - /// A variable has the same value as another variable - Equal(Meta), - /// Variable extends the value of another by a set of extensions - Plus(ExtensionSet, Meta), -} - -#[derive(Debug, Clone, PartialEq, Error)] -#[non_exhaustive] -/// Errors which arise during unification -pub enum InferExtensionError { - #[error("Mismatched extension sets {expected} and {actual}")] - /// We've solved a metavariable, then encountered a constraint - /// that says it should be something other than our solution - MismatchedConcrete { - /// The solution we were trying to insert for this meta - expected: ExtensionSet, - /// The incompatible solution that we found was already there - actual: ExtensionSet, - }, - #[error("Solved extensions {expected} at {expected_loc:?} and {actual} at {actual_loc:?} should be equal.")] - /// A version of the above with info about which nodes failed to unify - MismatchedConcreteWithLocations { - /// Where the solution we want to insert came from - expected_loc: (Node, Direction), - /// The solution we were trying to insert for this meta - expected: ExtensionSet, - /// Which node we're trying to add a solution for - actual_loc: (Node, Direction), - /// The incompatible solution that we found was already there - actual: ExtensionSet, - }, - /// A variable went unsolved that wasn't related to a parameter - #[error("Unsolved variable at location {:?}", location)] - Unsolved { - /// The location on the hugr that's associated to the unsolved meta - location: (Node, Direction), - }, - /// An extension mismatch between two nodes which are connected by an edge. - /// This should mirror (or reuse) `ValidationError`'s SrcExceedsTgtExtensions - /// and TgtExceedsSrcExtensions - #[error("Edge mismatch: {0}")] - EdgeMismatch(#[from] ExtensionError), -} - -/// A graph of metavariables connected by constraints. -/// The edges represent `Equal` constraints in the undirected graph and `Plus` -/// constraints in the directed case. -struct GraphContainer { - graph: pg::Graph, - node_map: HashMap, -} - -impl GraphContainer { - /// Add a metavariable to the graph as a node and return the `NodeIndex`. - /// If it's already there, just return the existing `NodeIndex` - fn add_or_retrieve(&mut self, m: Meta) -> pg::NodeIndex { - self.node_map.get(&m).cloned().unwrap_or_else(|| { - let ix = self.graph.add_node(m); - self.node_map.insert(m, ix); - ix - }) - } - - /// Create an edge between two nodes on the graph - fn add_edge(&mut self, src: Meta, tgt: Meta) { - let src_ix = self.add_or_retrieve(src); - let tgt_ix = self.add_or_retrieve(tgt); - self.graph.add_edge(src_ix, tgt_ix, ()); - } - - /// Return the strongly connected components of the graph in terms of - /// metavariables. 
In the undirected case, return the connected components - fn sccs(&self) -> Vec> { - petgraph::algo::tarjan_scc(&self.graph) - .into_iter() - .map(|cc| { - cc.into_iter() - .map(|n| *self.graph.node_weight(n).unwrap()) - .collect() - }) - .collect() - } -} - -impl GraphContainer { - fn new() -> Self { - GraphContainer { - graph: pg::Graph::new_undirected(), - node_map: HashMap::new(), - } - } -} - -impl GraphContainer { - fn new() -> Self { - GraphContainer { - graph: pg::Graph::new(), - node_map: HashMap::new(), - } - } -} - -type EqGraph = GraphContainer; - -/// Our current knowledge about the extensions of the graph -struct UnificationContext { - /// A list of constraints for each metavariable - constraints: HashMap>, - /// A map which says which nodes correspond to which metavariables - extensions: HashMap<(Node, Direction), Meta>, - /// Solutions to metavariables - solved: HashMap, - /// A graph which says which metavariables should be equal - eq_graph: EqGraph, - /// A mapping from metavariables which have been merged, to the meta they've - // been merged to - shunted: HashMap, - /// Variables we're allowed to include in solutionss - variables: HashSet, - /// A name for the next metavariable we create - fresh_name: u32, -} - -/// Invariant: Constraint::Plus always points to a fresh metavariable -impl UnificationContext { - /// Create a new unification context, and populate it with constraints from - /// traversing the hugr which is passed in. - fn new(hugr: &impl HugrView) -> Self { - let mut ctx = Self { - constraints: HashMap::new(), - extensions: HashMap::new(), - solved: HashMap::new(), - eq_graph: EqGraph::new(), - shunted: HashMap::new(), - variables: HashSet::new(), - fresh_name: 0, - }; - ctx.gen_constraints(hugr); - ctx - } - - /// Create a fresh metavariable, and increment `fresh_name` for next time - fn fresh_meta(&mut self) -> Meta { - let fresh = Meta(self.fresh_name); - self.fresh_name += 1; - self.constraints.insert(fresh, HashSet::new()); - fresh - } - - /// Declare a constraint on the metavariable - fn add_constraint(&mut self, m: Meta, c: Constraint) { - self.constraints.entry(m).or_default().insert(c); - } - - /// Declare that a meta has been solved - fn add_solution(&mut self, m: Meta, rs: ExtensionSet) { - let existing_sol = self.solved.insert(m, rs); - debug_assert!(existing_sol.is_none()); - } - - /// If a metavariable has been merged, return the new meta, otherwise return - /// the same meta. - /// - /// This could loop if there were a cycle in the `shunted` list, but there - /// shouldn't be, because we only ever shunt to *new* metas. - fn resolve(&self, m: Meta) -> Meta { - self.shunted.get(&m).cloned().map_or(m, |m| self.resolve(m)) - } - - /// Get the relevant constraints for a metavariable. If it's been merged, - /// get the constraints for the merged metavariable - fn get_constraints(&self, m: &Meta) -> Option<&HashSet> { - self.constraints.get(&self.resolve(*m)) - } - - /// Get the relevant solution for a metavariable. 
If it's been merged, get - /// the solution for the merged metavariable - fn get_solution(&self, m: &Meta) -> Option<&ExtensionSet> { - self.solved.get(&self.resolve(*m)) - } - - /// Return the metavariable corresponding to the given location on the - /// graph, either by making a new meta, or looking it up - fn make_or_get_meta(&mut self, node: Node, dir: Direction) -> Meta { - if let Some(m) = self.extensions.get(&(node, dir)) { - *m - } else { - let m = self.fresh_meta(); - self.extensions.insert((node, dir), m); - m - } - } - - /// Iterate over the nodes in a hugr and generate unification constraints - fn gen_constraints(&mut self, hugr: &T) - where - T: HugrView, - { - if hugr.root_type().input_extensions().is_none() { - let m_input = self.make_or_get_meta(hugr.root(), Direction::Incoming); - self.variables.insert(m_input); - } - - for node in hugr.nodes() { - let m_input = self.make_or_get_meta(node, Direction::Incoming); - let m_output = self.make_or_get_meta(node, Direction::Outgoing); - - let node_type = hugr.get_nodetype(node); - - // Add constraints for the inputs and outputs of dataflow nodes according - // to the signature of the parent node - if let Some([input, output]) = hugr.get_io(node) { - for dir in Direction::BOTH { - let m_input_node = self.make_or_get_meta(input, dir); - self.add_constraint(m_input_node, Constraint::Equal(m_input)); - let m_output_node = self.make_or_get_meta(output, dir); - // If the parent node is a FuncDefn, it will have no - // op_signature, so the Incoming and Outgoing ports will - // have equal extension requirements. - // The function that it contains, however, may have an - // extension delta, so its output shouldn't be equal to the - // FuncDefn's output. - // - // TODO: Add a constraint that the extensions of the output - // node of a FuncDefn should be those of the input node plus - // the extension delta specified in the function signature. - if node_type.tag() != OpTag::FuncDefn { - self.add_constraint(m_output_node, Constraint::Equal(m_output)); - } - } - } - - if hugr.get_optype(node).tag() == OpTag::Conditional { - for case in hugr.children(node) { - let m_case_in = self.make_or_get_meta(case, Direction::Incoming); - let m_case_out = self.make_or_get_meta(case, Direction::Outgoing); - self.add_constraint(m_case_in, Constraint::Equal(m_input)); - self.add_constraint(m_case_out, Constraint::Equal(m_output)); - } - } - - if node_type.tag() == OpTag::Cfg { - let mut children = hugr.children(node); - let entry = children.next().unwrap(); - let exit = children.next().unwrap(); - let m_entry = self.make_or_get_meta(entry, Direction::Incoming); - let m_exit = self.make_or_get_meta(exit, Direction::Outgoing); - self.add_constraint(m_input, Constraint::Equal(m_entry)); - self.add_constraint(m_output, Constraint::Equal(m_exit)); - } - - match node_type.io_extensions() { - // Input extensions are open - None => { - let delta = node_type.op().extension_delta(); - let c = if delta.is_empty() { - Constraint::Equal(m_input) - } else { - Constraint::Plus(delta, m_input) - }; - self.add_constraint(m_output, c); - } - // We have a solution for everything! - Some((input_exts, output_exts)) => { - self.add_solution(m_input, input_exts.clone()); - self.add_solution(m_output, output_exts); - } - } - } - // Separate loop so that we can assume that a metavariable has been - // added for every (Node, Direction) in the graph already. 
- for tgt_node in hugr.nodes() { - let sig = hugr.get_nodetype(tgt_node).op(); - // Incoming ports with an edge that should mean equal extension reqs - for port in hugr.node_inputs(tgt_node).filter(|src_port| { - let kind = sig.port_kind(*src_port); - kind.as_ref().is_some_and(EdgeKind::is_static) - || matches!(kind, Some(EdgeKind::Value(_)) | Some(EdgeKind::ControlFlow)) - }) { - let m_tgt = *self - .extensions - .get(&(tgt_node, Direction::Incoming)) - .unwrap(); - for (src_node, _) in hugr.linked_ports(tgt_node, port) { - let m_src = self - .extensions - .get(&(src_node, Direction::Outgoing)) - .unwrap(); - self.add_constraint(*m_src, Constraint::Equal(m_tgt)); - } - } - } - } - - /// When trying to unify two metas, check if they both correspond to - /// different ends of the same wire. If so, return an `ExtensionError`. - /// Otherwise check whether they both correspond to *some* location on the - /// graph and include that info the otherwise generic `MismatchedConcrete`. - fn report_mismatch( - &self, - m1: Meta, - m2: Meta, - rs1: ExtensionSet, - rs2: ExtensionSet, - ) -> InferExtensionError { - let loc1 = self - .extensions - .iter() - .find(|(_, m)| **m == m1 || self.resolve(**m) == m1) - .map(|a| a.0); - let loc2 = self - .extensions - .iter() - .find(|(_, m)| **m == m2 || self.resolve(**m) == m2) - .map(|a| a.0); - if let (Some((node1, dir1)), Some((node2, dir2))) = (loc1, loc2) { - // N.B. We're looking for the case where an equality constraint - // arose because the two locations are connected by an edge - - // If the directions are the same, they shouldn't be connected - // to each other. If the nodes are the same, there's no edge! - // - // TODO: It's still possible that the equality constraint - // arose because one node is a dataflow parent and the other - // is one of it's I/O nodes. In that case, the directions could be - // the same, and we should try to detect it - if dir1 != dir2 && node1 != node2 { - let [(src, src_rs), (tgt, tgt_rs)] = if *dir2 == Direction::Incoming { - [(node1, rs1.clone()), (node2, rs2.clone())] - } else { - [(node2, rs2.clone()), (node1, rs1.clone())] - }; - - return InferExtensionError::EdgeMismatch(if src_rs.is_subset(&tgt_rs) { - ExtensionError::TgtExceedsSrcExtensions { - from: *src, - from_extensions: src_rs, - to: *tgt, - to_extensions: tgt_rs, - } - } else { - ExtensionError::SrcExceedsTgtExtensions { - from: *src, - from_extensions: src_rs, - to: *tgt, - to_extensions: tgt_rs, - } - }); - } - } - if let (Some(loc1), Some(loc2)) = (loc1, loc2) { - InferExtensionError::MismatchedConcreteWithLocations { - expected_loc: *loc1, - expected: rs1, - actual_loc: *loc2, - actual: rs2, - } - } else { - InferExtensionError::MismatchedConcrete { - expected: rs1, - actual: rs2, - } - } - } - - /// Take a group of equal metas and merge them into a new, single meta. - /// - /// Returns the set of new metas created and the set of metas that were - /// merged. - fn merge_equal_metas(&mut self) -> Result<(HashSet, HashSet), InferExtensionError> { - let mut merged: HashSet = HashSet::new(); - let mut new_metas: HashSet = HashSet::new(); - for cc in self.eq_graph.sccs().into_iter() { - // Within a connected component everything is equal - let combined_meta = self.fresh_meta(); - for m in cc.iter() { - // The same meta shouldn't be shunted twice directly. 
Only - // transitively, as we still process the meta it was shunted to - if self.shunted.contains_key(m) { - continue; - } - - if let Some(cs) = self.constraints.remove(m) { - for c in cs - .iter() - .filter(|c| !matches!(c, Constraint::Equal(_))) - .cloned() - .collect::>() - .into_iter() - { - self.add_constraint(combined_meta, c.clone()); - } - merged.insert(*m); - // Record a new meta the first time that we use it; don't - // bother recording a new meta if we don't add any - // constraints. It should be safe to call this multiple times - new_metas.insert(combined_meta); - } - // Here, solved.get is equivalent to get_solution, because if - // `m` had already been shunted, we wouldn't skipped it - if let Some(solution) = self.solved.get(m) { - match self.solved.get(&combined_meta) { - Some(existing_solution) => { - if solution != existing_solution { - return Err(self.report_mismatch( - *m, - combined_meta, - solution.clone(), - existing_solution.clone(), - )); - } - } - None => { - self.solved.insert(combined_meta, solution.clone()); - } - } - } - if self.variables.contains(m) { - self.variables.insert(combined_meta); - self.variables.remove(m); - } - self.shunted.insert(*m, combined_meta); - } - } - Ok((new_metas, merged)) - } - - /// Inspect the constraints of a given metavariable and try to find a - /// solution based on those. - /// Returns whether a solution was found - fn solve_meta(&mut self, meta: Meta) -> Result { - let mut solved = false; - for c in self.get_constraints(&meta).unwrap().clone().iter() { - match c { - // Just register the equality in the EqGraph, we'll process it later - Constraint::Equal(other_meta) => { - self.eq_graph.add_edge(meta, *other_meta); - } - // N.B. If `meta` is already solved, we can't use that - // information to solve `other_meta`. This is because the Plus - // constraint only signifies a preorder. - // I.e. if meta = other_meta + 'R', it's still possible that the - // solution is meta = other_meta because we could be adding 'R' - // to a set which already contained it. - Constraint::Plus(r, other_meta) => { - if let Some(rs) = self.get_solution(other_meta) { - let rrs = rs.clone().union(r.clone()); - match self.get_solution(&meta) { - // Let's check that this is right? - Some(rs) => { - if rs != &rrs { - return Err(self.report_mismatch( - meta, - *other_meta, - rs.clone(), - rrs, - )); - } - } - None => { - self.add_solution(meta, rrs); - solved = true; - } - }; - }; - } - } - } - Ok(solved) - } - - /// Tries to return concrete extensions for each node in the graph. This only - /// works when there are no variables in the graph! - /// - /// What we really want is to give the concrete extensions where they're - /// available. When there are variables, we should leave the graph as it is, - /// but make sure that no matter what they're instantiated to, the graph - /// still makes sense (should pass the extension validation check) - fn results(&self) -> Result { - // Check that all of the metavariables associated with nodes of the - // graph are solved - let depended_upon = { - let mut h: HashMap> = HashMap::new(); - for (m, m2) in self.constraints.iter().flat_map(|(m, cs)| { - cs.iter().flat_map(|c| match c { - Constraint::Plus(_, m2) => Some((*m, self.resolve(*m2))), - _ => None, - }) - }) { - h.entry(m2).or_default().push(m); - } - h - }; - // Calculate everything dependent upon a variable. 
- // Note it would be better to find metas ALL of whose dependencies were (transitively) - // on variables, but this is more complex, and hard to define if there are cycles - // of PLUS constraints, so leaving that as a TODO until we've handled such cycles. - let mut depends_on_var = HashSet::new(); - let mut queue = VecDeque::from_iter(self.variables.iter()); - while let Some(m) = queue.pop_front() { - if depends_on_var.insert(m) { - if let Some(d) = depended_upon.get(m) { - queue.extend(d.iter()) - } - } - } - - let mut results: ExtensionSolution = HashMap::new(); - for (loc, meta) in self.extensions.iter() { - if let Some(rs) = self.get_solution(meta) { - if loc.1 == Direction::Incoming { - results.insert(loc.0, rs.clone()); - } - } else { - // Unsolved nodes must be unsolved because they depend on graph variables. - if !depends_on_var.contains(&self.resolve(*meta)) { - return Err(InferExtensionError::Unsolved { location: *loc }); - } - } - } - Ok(results) - } - - /// Iterates over a set of metas (the argument) and tries to solve - /// them. - /// Returns the metas that we solved - fn solve_constraints( - &mut self, - vars: &HashSet, - ) -> Result, InferExtensionError> { - let mut solved = HashSet::new(); - for m in vars.iter() { - if self.solve_meta(*m)? { - solved.insert(*m); - } - } - Ok(solved) - } - - /// Once the unification context is set up, attempt to infer ExtensionSets - /// for all of the metavariables in the `UnificationContext`. - /// - /// Return a mapping from locations in the graph to concrete `ExtensionSets` - /// where it was possible to infer them. If it wasn't possible to infer a - /// *concrete* `ExtensionSet`, e.g. if the ExtensionSet relies on an open - /// variable in the toplevel graph, don't include that location in the map - fn main_loop(&mut self) -> Result { - let mut remaining = HashSet::::from_iter(self.constraints.keys().cloned()); - - // Keep going as long as we're making progress (= merging and solving nodes) - loop { - // Try to solve metas with the information we have now. This may - // register new equalities on the EqGraph - let to_delete = self.solve_constraints(&remaining)?; - // Merge metas based on the equalities we just registered - let (new, merged) = self.merge_equal_metas()?; - // All of the metas for which we've made progress - let delta: HashSet = HashSet::from_iter(to_delete.union(&merged).cloned()); - - // Clean up dangling constraints on solved metavariables - to_delete.iter().for_each(|m| { - self.constraints.remove(m); - }); - // Remove solved and merged metas from remaining "to solve" list - delta.iter().for_each(|m| { - remaining.remove(m); - }); - - // If we made no progress, we're done! - if delta.is_empty() && new.is_empty() { - break; - } - remaining.extend(new) - } - self.results() - } - - /// Gather all the transitive dependencies (induced by constraints) of the - /// variables in the context. - fn search_variable_deps(&self) -> HashSet { - let mut seen = HashSet::new(); - let mut new_variables: HashSet = self.variables.clone(); - while !new_variables.is_empty() { - new_variables = new_variables - .into_iter() - .filter(|m| seen.insert(*m)) - .flat_map(|m| self.get_constraints(&m)) - .flatten() - .map(|c| match c { - Constraint::Plus(_, other) => self.resolve(*other), - Constraint::Equal(other) => self.resolve(*other), - }) - .collect(); - } - seen - } - - /// Instantiate all variables in the graph with the empty extension set, or - /// the smallest solution possible given their constraints. 
- /// This is done to solve metas which depend on variables, which allows - /// us to come up with a fully concrete solution to pass into validation. - /// - /// Nodes which loop into themselves must be considered as a "minimum" set - /// of requirements. If we have - /// 1 = 2 + X, ... - /// 2 = 1 + x, ... - /// then 1 and 2 both definitely contain X, even if we don't know what else. - /// So instead of instantiating to the empty set, we'll instantiate to `{X}` - fn instantiate_variables(&mut self) { - // A directed graph to keep track of `Plus` constraint relationships - let mut relations = GraphContainer::::new(); - let mut solutions: HashMap = HashMap::new(); - - let variable_scope = self.search_variable_deps(); - for m in variable_scope.into_iter() { - // If `m` has been merged, [`self.variables`] entry - // will have already been updated to the merged - // value by [`self.merge_equal_metas`] so we don't - // need to worry about resolving it. - if !self.solved.contains_key(&m) { - // Handle the case where the constraints for `m` contain a self - // reference, i.e. "m = Plus(E, m)", in which case the variable - // should be instantiated to E rather than the empty set. - let plus_constraints = - self.get_constraints(&m) - .unwrap() - .iter() - .cloned() - .flat_map(|c| match c { - Constraint::Plus(r, other_m) => Some((r, self.resolve(other_m))), - _ => None, - }); - - let (rs, other_ms): (Vec<_>, Vec<_>) = plus_constraints.unzip(); - let solution = ExtensionSet::union_over(rs); - let unresolved_metas = other_ms - .into_iter() - .filter(|other_m| m != *other_m) - .collect::>(); - - // If `m` doesn't depend on any other metas then we have all the - // information we need to come up with a solution for it. - relations.add_or_retrieve(m); - unresolved_metas - .iter() - .for_each(|other_m| relations.add_edge(m, *other_m)); - solutions.insert(m, solution); - } - } - - // Process the strongly-connected components. petgraph/sccs() returns these - // depended-upon before dependant, as we need. - for cc in relations.sccs() { - // Strongly connected components are looping constraint dependencies. - // This means that each metavariable in the CC has the same solution. 
- let combined_solution = cc - .iter() - .flat_map(|m| self.get_constraints(m).unwrap()) - .filter_map(|c| match c { - Constraint::Plus(_, other_m) => solutions.get(&self.resolve(*other_m)).cloned(), - Constraint::Equal(_) => None, - }) - .fold(ExtensionSet::new(), ExtensionSet::union); - - for m in cc.iter() { - self.add_solution(*m, combined_solution.clone()); - solutions.insert(*m, combined_solution.clone()); - } - } - self.variables = HashSet::new(); - } -} - -#[cfg(test)] -mod test; diff --git a/hugr-core/src/extension/infer/test.rs b/hugr-core/src/extension/infer/test.rs deleted file mode 100644 index 4888d0314..000000000 --- a/hugr-core/src/extension/infer/test.rs +++ /dev/null @@ -1,1083 +0,0 @@ -use std::error::Error; - -use super::*; -use crate::builder::{ - Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, ModuleBuilder, -}; -use crate::extension::prelude::PRELUDE_REGISTRY; -use crate::extension::prelude::QB_T; -use crate::extension::ExtensionId; -use crate::hugr::{Hugr, HugrMut, NodeType}; -use crate::macros::const_extension_ids; -use crate::ops::custom::OpaqueOp; -use crate::ops::{self, dataflow::IOTrait}; -use crate::ops::{CustomOp, Lift, OpType}; -#[cfg(feature = "extension_inference")] -use crate::{ - builder::test::closed_dfg_root_hugr, - extension::prelude::PRELUDE_ID, - hugr::{internal::HugrMutInternals, validate::ValidationError}, - ops::{dataflow::DataflowParent, handle::NodeHandle}, -}; - -use crate::type_row; -use crate::types::{FunctionType, Type, TypeRow}; - -use cool_asserts::assert_matches; -use itertools::Itertools; -use portgraph::NodeIndex; - -const NAT: Type = crate::extension::prelude::USIZE_T; - -const_extension_ids! { - const A: ExtensionId = "A"; - const B: ExtensionId = "B"; - const C: ExtensionId = "C"; - const UNKNOWN_EXTENSION: ExtensionId = "Unknown"; -} - -#[test] -// Build up a graph with some holes in its extension requirements, and infer -// them. 
-fn from_graph() -> Result<(), Box> { - let rs = ExtensionSet::from_iter([A, B, C]); - let main_sig = FunctionType::new(type_row![NAT, NAT], type_row![NAT]).with_extension_delta(rs); - - let op = ops::DFG { - signature: main_sig, - }; - - let root_node = NodeType::new_open(op); - let mut hugr = Hugr::new(root_node); - - let input = ops::Input::new(type_row![NAT, NAT]); - let output = ops::Output::new(type_row![NAT]); - - let input = hugr.add_node_with_parent(hugr.root(), input); - let output = hugr.add_node_with_parent(hugr.root(), output); - - assert_matches!(hugr.get_io(hugr.root()), Some(_)); - - let add_a_sig = FunctionType::new(type_row![NAT], type_row![NAT]).with_extension_delta(A); - - let add_b_sig = FunctionType::new(type_row![NAT], type_row![NAT]).with_extension_delta(B); - - let add_ab_sig = FunctionType::new(type_row![NAT], type_row![NAT]) - .with_extension_delta(ExtensionSet::from_iter([A, B])); - - let mult_c_sig = FunctionType::new(type_row![NAT, NAT], type_row![NAT]).with_extension_delta(C); - - let add_a = hugr.add_node_with_parent( - hugr.root(), - ops::DFG { - signature: add_a_sig, - }, - ); - let add_b = hugr.add_node_with_parent( - hugr.root(), - ops::DFG { - signature: add_b_sig, - }, - ); - let add_ab = hugr.add_node_with_parent( - hugr.root(), - ops::DFG { - signature: add_ab_sig, - }, - ); - let mult_c = hugr.add_node_with_parent( - hugr.root(), - ops::DFG { - signature: mult_c_sig, - }, - ); - - hugr.connect(input, 0, add_a, 0); - hugr.connect(add_a, 0, add_b, 0); - hugr.connect(add_b, 0, mult_c, 0); - - hugr.connect(input, 1, add_ab, 0); - hugr.connect(add_ab, 0, mult_c, 1); - - hugr.connect(mult_c, 0, output, 0); - - let solution = infer_extensions(&hugr)?; - let empty = ExtensionSet::new(); - let ab = ExtensionSet::from_iter([A, B]); - assert_eq!(*solution.get(&(hugr.root())).unwrap(), empty); - assert_eq!(*solution.get(&(mult_c)).unwrap(), ab); - assert_eq!(*solution.get(&(add_ab)).unwrap(), empty); - assert_eq!(*solution.get(&add_b).unwrap(), ExtensionSet::singleton(&A)); - Ok(()) -} - -#[test] -// Basic test that the `Plus` constraint works -fn plus() -> Result<(), InferExtensionError> { - let hugr = Hugr::default(); - let mut ctx = UnificationContext::new(&hugr); - - let metas: Vec = (2..8) - .map(|i| { - let meta = ctx.fresh_meta(); - ctx.extensions - .insert((NodeIndex::new(i).into(), Direction::Incoming), meta); - meta - }) - .collect(); - - ctx.solved.insert(metas[2], A.into()); - ctx.add_constraint(metas[1], Constraint::Equal(metas[2])); - ctx.add_constraint(metas[0], Constraint::Plus(B.into(), metas[2])); - ctx.add_constraint(metas[4], Constraint::Plus(C.into(), metas[0])); - ctx.add_constraint(metas[3], Constraint::Equal(metas[4])); - ctx.add_constraint(metas[5], Constraint::Equal(metas[0])); - ctx.main_loop()?; - - let a = ExtensionSet::singleton(&A); - let mut ab = a.clone(); - ab.insert(&B); - let mut abc = ab.clone(); - abc.insert(&C); - - assert_eq!(ctx.get_solution(&metas[0]).unwrap(), &ab); - assert_eq!(ctx.get_solution(&metas[1]).unwrap(), &a); - assert_eq!(ctx.get_solution(&metas[2]).unwrap(), &a); - assert_eq!(ctx.get_solution(&metas[3]).unwrap(), &abc); - assert_eq!(ctx.get_solution(&metas[4]).unwrap(), &abc); - assert_eq!(ctx.get_solution(&metas[5]).unwrap(), &ab); - - Ok(()) -} - -#[cfg(feature = "extension_inference")] -#[test] -// This generates a solution that causes validation to fail -// because of a missing lift node -fn missing_lift_node() { - let mut hugr = Hugr::new(NodeType::new_pure(ops::DFG { - signature: 
FunctionType::new(type_row![NAT], type_row![NAT]).with_extension_delta(A), - })); - - let input = hugr.add_node_with_parent( - hugr.root(), - NodeType::new_pure(ops::Input { - types: type_row![NAT], - }), - ); - - let output = hugr.add_node_with_parent( - hugr.root(), - NodeType::new_pure(ops::Output { - types: type_row![NAT], - }), - ); - - hugr.connect(input, 0, output, 0); - - // Fail to catch the actual error because it's a difference between I/O - // nodes and their parents and `report_mismatch` isn't yet smart enough - // to handle that. - assert_matches!( - hugr.update_validate(&PRELUDE_REGISTRY), - Err(ValidationError::CantInfer(_)) - ); -} - -#[test] -// Tests that we can succeed even when all variables don't have concrete -// extension sets, and we have an open variable at the start of the graph. -fn open_variables() -> Result<(), InferExtensionError> { - let mut ctx = UnificationContext::new(&Hugr::default()); - let a = ctx.fresh_meta(); - let b = ctx.fresh_meta(); - let ab = ctx.fresh_meta(); - // Some nonsense so that the constraints register as "live" - ctx.extensions - .insert((NodeIndex::new(2).into(), Direction::Outgoing), a); - ctx.extensions - .insert((NodeIndex::new(3).into(), Direction::Outgoing), b); - ctx.extensions - .insert((NodeIndex::new(4).into(), Direction::Incoming), ab); - ctx.variables.insert(a); - ctx.variables.insert(b); - ctx.add_constraint(ab, Constraint::Plus(A.into(), b)); - ctx.add_constraint(ab, Constraint::Plus(B.into(), a)); - let solution = ctx.main_loop()?; - // We'll only find concrete solutions for the Incoming extension reqs of - // the main node created by `Hugr::default` - assert_eq!(solution.len(), 1); - Ok(()) -} - -#[cfg(feature = "extension_inference")] -#[test] -// Infer the extensions on a child node with no inputs -fn dangling_src() -> Result<(), Box> { - let rs = ExtensionSet::singleton(&"R".try_into().unwrap()); - - let mut hugr = closed_dfg_root_hugr( - FunctionType::new(type_row![NAT], type_row![NAT]).with_extension_delta(rs.clone()), - ); - - let [input, output] = hugr.get_io(hugr.root()).unwrap(); - let add_r_sig = - FunctionType::new(type_row![NAT], type_row![NAT]).with_extension_delta(rs.clone()); - - let add_r = hugr.add_node_with_parent( - hugr.root(), - ops::DFG { - signature: add_r_sig, - }, - ); - - // Dangling thingy - let src_sig = FunctionType::new(type_row![], type_row![NAT]); - - let src = hugr.add_node_with_parent(hugr.root(), ops::DFG { signature: src_sig }); - - let mult_sig = FunctionType::new(type_row![NAT, NAT], type_row![NAT]); - // Mult has open extension requirements, which we should solve to be "R" - let mult = hugr.add_node_with_parent( - hugr.root(), - ops::DFG { - signature: mult_sig, - }, - ); - - hugr.connect(input, 0, add_r, 0); - hugr.connect(add_r, 0, mult, 0); - hugr.connect(src, 0, mult, 1); - hugr.connect(mult, 0, output, 0); - - hugr.infer_extensions()?; - assert_eq!(hugr.get_nodetype(src.node()).io_extensions().unwrap().1, rs); - assert_eq!( - hugr.get_nodetype(mult.node()).io_extensions().unwrap(), - (rs.clone(), rs) - ); - Ok(()) -} - -#[test] -fn resolve_test() -> Result<(), InferExtensionError> { - let mut ctx = UnificationContext::new(&Hugr::default()); - let m0 = ctx.fresh_meta(); - let m1 = ctx.fresh_meta(); - let m2 = ctx.fresh_meta(); - ctx.add_constraint(m0, Constraint::Equal(m1)); - ctx.main_loop()?; - let mid0 = ctx.resolve(m0); - assert_eq!(ctx.resolve(m0), ctx.resolve(m1)); - ctx.add_constraint(mid0, Constraint::Equal(m2)); - ctx.main_loop()?; - assert_eq!(ctx.resolve(m0), 
ctx.resolve(m2)); - assert_eq!(ctx.resolve(m1), ctx.resolve(m2)); - assert!(ctx.resolve(m0) != mid0); - Ok(()) -} - -fn create_with_io( - hugr: &mut Hugr, - parent: Node, - op: impl Into, - op_sig: FunctionType, -) -> Result<[Node; 3], Box> { - let op: OpType = op.into(); - - let node = hugr.add_node_with_parent(parent, op); - let input = hugr.add_node_with_parent( - node, - ops::Input { - types: op_sig.input, - }, - ); - let output = hugr.add_node_with_parent( - node, - ops::Output { - types: op_sig.output, - }, - ); - Ok([node, input, output]) -} - -#[cfg(feature = "extension_inference")] -#[test] -fn test_conditional_inference() -> Result<(), Box> { - fn build_case( - hugr: &mut Hugr, - conditional_node: Node, - op: ops::Case, - first_ext: ExtensionId, - second_ext: ExtensionId, - ) -> Result> { - let [case, case_in, case_out] = - create_with_io(hugr, conditional_node, op.clone(), op.inner_signature())?; - - let lift1 = hugr.add_node_with_parent( - case, - Lift { - type_row: type_row![NAT], - new_extension: first_ext, - }, - ); - - let lift2 = hugr.add_node_with_parent( - case, - Lift { - type_row: type_row![NAT], - new_extension: second_ext, - }, - ); - - hugr.connect(case_in, 0, lift1, 0); - hugr.connect(lift1, 0, lift2, 0); - hugr.connect(lift2, 0, case_out, 0); - - Ok(case) - } - - let sum_rows = vec![type_row![]; 2]; - let rs = ExtensionSet::from_iter([A, B]); - - let inputs = type_row![NAT]; - let outputs = type_row![NAT]; - - let op = ops::Conditional { - sum_rows, - other_inputs: inputs.clone(), - outputs: outputs.clone(), - extension_delta: rs.clone(), - }; - - let mut hugr = Hugr::new(NodeType::new_pure(op)); - let conditional_node = hugr.root(); - - let case_op = ops::Case { - signature: FunctionType::new(inputs, outputs).with_extension_delta(rs), - }; - let case0_node = build_case(&mut hugr, conditional_node, case_op.clone(), A, B)?; - - let case1_node = build_case(&mut hugr, conditional_node, case_op, B, A)?; - - hugr.infer_extensions()?; - - for node in [case0_node, case1_node, conditional_node] { - assert_eq!( - hugr.get_nodetype(node).io_extensions().unwrap().0, - ExtensionSet::new() - ); - assert_eq!( - hugr.get_nodetype(node).io_extensions().unwrap().0, - ExtensionSet::new() - ); - } - Ok(()) -} - -#[test] -fn extension_adding_sequence() -> Result<(), Box> { - let df_sig = FunctionType::new(type_row![NAT], type_row![NAT]); - - let mut hugr = Hugr::new(NodeType::new_open(ops::DFG { - signature: df_sig - .clone() - .with_extension_delta(ExtensionSet::from_iter([A, B])), - })); - - let root = hugr.root(); - let input = hugr.add_node_with_parent( - root, - ops::Input { - types: type_row![NAT], - }, - ); - let output = hugr.add_node_with_parent( - root, - ops::Output { - types: type_row![NAT], - }, - ); - - // Make identical dataflow nodes which add extension requirement "A" or "B" - let df_nodes: Vec = vec![A, A, B, B, A, B] - .into_iter() - .map(|ext| { - let dfg_sig = df_sig.clone().with_extension_delta(ext.clone()); - let [node, input, output] = create_with_io( - &mut hugr, - root, - ops::DFG { - signature: dfg_sig.clone(), - }, - dfg_sig, - ) - .unwrap(); - - let lift = hugr.add_node_with_parent( - node, - Lift { - type_row: type_row![NAT], - new_extension: ext, - }, - ); - - hugr.connect(input, 0, lift, 0); - hugr.connect(lift, 0, output, 0); - - node - }) - .collect(); - - // Connect nodes in order (0 -> 1 -> 2 ...) 
- let nodes = [vec![input], df_nodes, vec![output]].concat(); - for (src, tgt) in nodes.into_iter().tuple_windows() { - hugr.connect(src, 0, tgt, 0); - } - hugr.update_validate(&PRELUDE_REGISTRY)?; - Ok(()) -} - -fn make_opaque(extension: impl Into, signature: FunctionType) -> CustomOp { - ops::custom::OpaqueOp::new(extension.into(), "", "".into(), vec![], signature).into() -} - -fn make_block( - hugr: &mut Hugr, - bb_parent: Node, - inputs: TypeRow, - sum_rows: impl IntoIterator, - extension_delta: ExtensionSet, -) -> Result> { - let sum_rows: Vec<_> = sum_rows.into_iter().collect(); - let sum_type = Type::new_sum(sum_rows.clone()); - let dfb_sig = FunctionType::new(inputs.clone(), vec![sum_type]) - .with_extension_delta(extension_delta.clone()); - let dfb = ops::DataflowBlock { - inputs, - other_outputs: type_row![], - sum_rows, - extension_delta, - }; - let op = make_opaque(UNKNOWN_EXTENSION, dfb_sig.clone()); - - let [bb, bb_in, bb_out] = create_with_io(hugr, bb_parent, dfb, dfb_sig)?; - - let dfg = hugr.add_node_with_parent(bb, op); - - hugr.connect(bb_in, 0, dfg, 0); - hugr.connect(dfg, 0, bb_out, 0); - - Ok(bb) -} - -fn oneway(ty: Type) -> Vec { - vec![Type::new_sum([vec![ty].into()])] -} - -fn twoway(ty: Type) -> Vec { - vec![Type::new_sum([vec![ty.clone()].into(), vec![ty].into()])] -} - -fn create_entry_exit( - hugr: &mut Hugr, - root: Node, - inputs: TypeRow, - entry_variants: Vec, - entry_extensions: ExtensionSet, - exit_types: impl Into, -) -> Result<([Node; 3], Node), Box> { - let entry_sum = Type::new_sum(entry_variants.clone()); - let dfb = ops::DataflowBlock { - inputs: inputs.clone(), - other_outputs: type_row![], - sum_rows: entry_variants, - extension_delta: entry_extensions, - }; - - let exit = hugr.add_node_with_parent( - root, - ops::ExitBlock { - cfg_outputs: exit_types.into(), - }, - ); - - let entry = hugr.add_node_before(exit, dfb); - let entry_in = hugr.add_node_with_parent(entry, ops::Input { types: inputs }); - let entry_out = hugr.add_node_with_parent( - entry, - ops::Output { - types: vec![entry_sum].into(), - }, - ); - - Ok(([entry, entry_in, entry_out], exit)) -} - -/// A CFG rooted hugr adding resources at each basic block. 
-/// Looks like this: -/// -/// +-------------+ -/// | Entry | -/// | (Adds [A]) | -/// +-/---------\-+ -/// / \ -/// +-------/-----+ +-\-------------+ -/// | BB0 | | BB1 | -/// | (Adds [BC]) | | (Adds [B]) | -/// +----\--------+ +---/------\----+ -/// \ / \ -/// \ / \ -/// \ +----/-------+ +-\---------+ -/// \ | BB10 | | BB11 | -/// \ | (Adds [C]) | | (Adds [C])| -/// \ +----+-------+ +/----------+ -/// \ | / -/// +-----\-------+---------/-+ -/// | Exit | -/// +-------------------------+ -#[test] -fn infer_cfg_test() -> Result<(), Box> { - let abc = ExtensionSet::from_iter([A, B, C]); - let bc = ExtensionSet::from_iter([B, C]); - - let mut hugr = Hugr::new(NodeType::new_open(ops::CFG { - signature: FunctionType::new(type_row![NAT], type_row![NAT]).with_extension_delta(abc), - })); - - let root = hugr.root(); - - let ([entry, entry_in, entry_out], exit) = create_entry_exit( - &mut hugr, - root, - type_row![NAT], - vec![type_row![NAT], type_row![NAT]], - A.into(), - type_row![NAT], - )?; - - let mkpred = hugr.add_node_with_parent( - entry, - make_opaque( - A, - FunctionType::new(vec![NAT], twoway(NAT)).with_extension_delta(A), - ), - ); - - // Internal wiring for DFGs - hugr.connect(entry_in, 0, mkpred, 0); - hugr.connect(mkpred, 0, entry_out, 0); - - let bb0 = make_block( - &mut hugr, - root, - type_row![NAT], - vec![type_row![NAT]], - bc.clone(), - )?; - - let bb1 = make_block( - &mut hugr, - root, - type_row![NAT], - vec![type_row![NAT], type_row![NAT]], - B.into(), - )?; - - let bb10 = make_block( - &mut hugr, - root, - type_row![NAT], - vec![type_row![NAT]], - C.into(), - )?; - - let bb11 = make_block( - &mut hugr, - root, - type_row![NAT], - vec![type_row![NAT]], - C.into(), - )?; - - // CFG Wiring - hugr.connect(entry, 0, bb0, 0); - hugr.connect(entry, 0, bb1, 0); - hugr.connect(bb1, 0, bb10, 0); - hugr.connect(bb1, 0, bb11, 0); - - hugr.connect(bb0, 0, exit, 0); - hugr.connect(bb10, 0, exit, 0); - hugr.connect(bb11, 0, exit, 0); - - hugr.infer_extensions()?; - - Ok(()) -} - -/// A test case for a CFG with a node (BB2) which has multiple predecessors, -/// Like so: -/// -/// +-----------------+ -/// | Entry | -/// +------/--\-------+ -/// / \ -/// / \ -/// / \ -/// +---------/--+ +----\-------+ -/// | BB0 | | BB1 | -/// +--------\---+ +----/-------+ -/// \ / -/// \ / -/// \ / -/// +------\---/--------+ -/// | BB2 | -/// +---------+---------+ -/// | -/// +---------+----------+ -/// | Exit | -/// +--------------------+ -#[test] -fn multi_entry() -> Result<(), Box> { - let mut hugr = Hugr::new(NodeType::new_open(ops::CFG { - signature: FunctionType::new(type_row![NAT], type_row![NAT]), // maybe add extensions? 
- })); - let cfg = hugr.root(); - let ([entry, entry_in, entry_out], exit) = create_entry_exit( - &mut hugr, - cfg, - type_row![NAT], - vec![type_row![NAT], type_row![NAT]], - ExtensionSet::new(), - type_row![NAT], - )?; - - let entry_mid = hugr.add_node_with_parent( - entry, - make_opaque(UNKNOWN_EXTENSION, FunctionType::new(vec![NAT], twoway(NAT))), - ); - - hugr.connect(entry_in, 0, entry_mid, 0); - hugr.connect(entry_mid, 0, entry_out, 0); - - let bb0 = make_block( - &mut hugr, - cfg, - type_row![NAT], - vec![type_row![NAT]], - ExtensionSet::new(), - )?; - - let bb1 = make_block( - &mut hugr, - cfg, - type_row![NAT], - vec![type_row![NAT]], - ExtensionSet::new(), - )?; - - let bb2 = make_block( - &mut hugr, - cfg, - type_row![NAT], - vec![type_row![NAT]], - ExtensionSet::new(), - )?; - - hugr.connect(entry, 0, bb0, 0); - hugr.connect(entry, 1, bb1, 0); - hugr.connect(bb0, 0, bb2, 0); - hugr.connect(bb1, 0, bb2, 0); - hugr.connect(bb2, 0, exit, 0); - - hugr.update_validate(&PRELUDE_REGISTRY)?; - - Ok(()) -} - -/// Create a CFG of the form below, with the extension deltas for `Entry`, -/// `BB1`, and `BB2` specified by arguments to the function. -/// -/// +-----------+ -/// +--->| Entry | -/// | +-----+-----+ -/// | | -/// | V -/// | +------------+ -/// | | BB1 +---+ -/// | +-----+------+ | -/// | | | -/// | V | -/// | +------------+ | -/// +----+ BB2 | | -/// +------------+ | -/// | -/// +------------+ | -/// | Exit |<--+ -/// +------------+ -fn make_looping_cfg( - entry_ext: ExtensionSet, - bb1_ext: ExtensionSet, - bb2_ext: ExtensionSet, -) -> Result> { - let hugr_delta = entry_ext - .clone() - .union(bb1_ext.clone()) - .union(bb2_ext.clone()); - - let mut hugr = Hugr::new(NodeType::new_open(ops::CFG { - signature: FunctionType::new(type_row![NAT], type_row![NAT]) - .with_extension_delta(hugr_delta), - })); - - let root = hugr.root(); - - let ([entry, entry_in, entry_out], exit) = create_entry_exit( - &mut hugr, - root, - type_row![NAT], - vec![type_row![NAT]], - entry_ext.clone(), - type_row![NAT], - )?; - - let entry_dfg = hugr.add_node_with_parent( - entry, - make_opaque( - UNKNOWN_EXTENSION, - FunctionType::new(vec![NAT], oneway(NAT)).with_extension_delta(entry_ext), - ), - ); - - hugr.connect(entry_in, 0, entry_dfg, 0); - hugr.connect(entry_dfg, 0, entry_out, 0); - - let bb1 = make_block( - &mut hugr, - root, - type_row![NAT], - vec![type_row![NAT], type_row![NAT]], - bb1_ext.clone(), - )?; - - let bb2 = make_block( - &mut hugr, - root, - type_row![NAT], - vec![type_row![NAT]], - bb2_ext.clone(), - )?; - - hugr.connect(entry, 0, bb1, 0); - hugr.connect(bb1, 0, bb2, 0); - hugr.connect(bb1, 1, exit, 0); - hugr.connect(bb2, 0, entry, 0); - - Ok(hugr) -} - -#[test] -fn test_cfg_loops() -> Result<(), Box> { - let just_a = ExtensionSet::singleton(&A); - let mut variants = Vec::new(); - for entry in [ExtensionSet::new(), just_a.clone()] { - for bb1 in [ExtensionSet::new(), just_a.clone()] { - for bb2 in [ExtensionSet::new(), just_a.clone()] { - variants.push((entry.clone(), bb1.clone(), bb2.clone())); - } - } - } - for (bb0, bb1, bb2) in variants.into_iter() { - let mut hugr = make_looping_cfg(bb0, bb1, bb2)?; - hugr.update_validate(&PRELUDE_REGISTRY)?; - } - Ok(()) -} - -#[test] -#[cfg(feature = "extension_inference")] -fn test_validate_with_closure() -> Result<(), Box> { - fn dfg_hugr_with_exts(e: Option) -> (Hugr, Node, Node) { - let mut h = closed_dfg_root_hugr(FunctionType::new_endo(type_row![QB_T])); - h.replace_op(h.root(), NodeType::new(h.get_optype(h.root()).clone(), e)) - 
.unwrap(); - let [input, output] = h.get_io(h.root()).unwrap(); - (h, input, output) - } - fn identity_hugr_with_exts(e: Option) -> Hugr { - let (mut h, input, output) = dfg_hugr_with_exts(e); - h.connect(input, 0, output, 0); - h - } - - const EXT_ID: ExtensionId = ExtensionId::new_unchecked("foo"); - - let inner_open = identity_hugr_with_exts(None); - - let inner_prelude = identity_hugr_with_exts(Some(ExtensionSet::singleton(&PRELUDE_ID))); - - let inner_other = identity_hugr_with_exts(Some(ExtensionSet::singleton(&EXT_ID))); - - // All three can be inferred and validated, without writing solutions in: - for inner in [&inner_open, &inner_prelude, &inner_other] { - assert_matches!( - inner.validate(&PRELUDE_REGISTRY), - Err(ValidationError::ExtensionError(_)) - ); - - let soln = infer_extensions(inner)?; - inner.validate_with_extension_closure(soln, &PRELUDE_REGISTRY)?; - } - - // Helper builds a Hugr with extensions {PRELUDE_ID}, around argument - let build_outer_prelude = |inner: Hugr| -> Hugr { - let (mut h, input, output) = dfg_hugr_with_exts(Some(ExtensionSet::singleton(&PRELUDE_ID))); - let inner_node = h.insert_hugr(h.root(), inner).new_root; - h.connect(input, 0, inner_node, 0); - h.connect(inner_node, 0, output, 0); - h - }; - - // Building a Hugr around the inner DFG works if the inner DFG is open, - // or has the correct (prelude) extensions: - for inner in [&inner_open, &inner_prelude] { - let mut h = build_outer_prelude(inner.clone()); - h.update_validate(&PRELUDE_REGISTRY)?; - } - - // ...but fails if the inner DFG already has the 'wrong' extensions: - assert_matches!( - build_outer_prelude(inner_other.clone()).update_validate(&PRELUDE_REGISTRY), - Err(ValidationError::CantInfer(_)) - ); - - // If we do inference on the inner Hugr first, this (still) works if the - // inner DFG already had the correct input-extensions: - let mut inner_prelude_inferred = inner_prelude; - inner_prelude_inferred.update_validate(&PRELUDE_REGISTRY)?; - build_outer_prelude(inner_prelude_inferred).update_validate(&PRELUDE_REGISTRY)?; - - // But fails for previously-open inner DFG as inference - // infers an incorrect (empty) solution: - let mut inner_inferred = inner_open; - inner_inferred.update_validate(&PRELUDE_REGISTRY)?; - assert_matches!( - build_outer_prelude(inner_inferred).update_validate(&PRELUDE_REGISTRY), - Err(ValidationError::CantInfer(_)) - ); - - Ok(()) -} - -#[test] -/// A control flow graph consisting of an entry node and a single block -/// which adds a resource and links to both itself and the exit node. 
-fn simple_cfg_loop() -> Result<(), Box> { - let just_a = ExtensionSet::singleton(&A); - - let mut hugr = Hugr::new(NodeType::new( - ops::CFG { - signature: FunctionType::new(type_row![NAT], type_row![NAT]).with_extension_delta(A), - }, - Some(A.into()), - )); - - let root = hugr.root(); - - let ([entry, entry_in, entry_out], exit) = create_entry_exit( - &mut hugr, - root, - type_row![NAT], - vec![type_row![NAT]], - ExtensionSet::new(), - type_row![NAT], - )?; - - let entry_mid = hugr.add_node_with_parent( - entry, - make_opaque(UNKNOWN_EXTENSION, FunctionType::new(vec![NAT], oneway(NAT))), - ); - - hugr.connect(entry_in, 0, entry_mid, 0); - hugr.connect(entry_mid, 0, entry_out, 0); - - let bb = make_block( - &mut hugr, - root, - type_row![NAT], - vec![type_row![NAT], type_row![NAT]], - just_a.clone(), - )?; - - hugr.connect(entry, 0, bb, 0); - hugr.connect(bb, 0, bb, 0); - hugr.connect(bb, 1, exit, 0); - - hugr.update_validate(&PRELUDE_REGISTRY)?; - - Ok(()) -} - -/// This was stack-overflowing approx 50% of the time, -/// see https://github.com/CQCL/hugr/issues/633 -#[test] -fn plus_on_self() -> Result<(), Box> { - let ext = ExtensionId::new("unknown1").unwrap(); - let ft = FunctionType::new_endo(type_row![QB_T, QB_T]).with_extension_delta(ext.clone()); - let mut dfg = DFGBuilder::new(ft.clone())?; - - // While https://github.com/CQCL/hugr/issues/388 is unsolved, - // most operations have empty extension_reqs (not including their own extension). - // Define some that do. - let binop = CustomOp::new_opaque(OpaqueOp::new( - ext.clone(), - "2qb_op", - String::new(), - vec![], - ft, - )); - let unary_sig = FunctionType::new_endo(type_row![QB_T]).with_extension_delta(ext.clone()); - let unop = CustomOp::new_opaque(OpaqueOp::new( - ext, - "1qb_op", - String::new(), - vec![], - unary_sig, - )); - // Constrain q1,q2 as PLUS(ext1, inputs): - let [q1, q2] = dfg - .add_dataflow_op(binop.clone(), dfg.input_wires())? - .outputs_arr(); - // Constrain q1 as PLUS(ext2, q2): - let [q1] = dfg.add_dataflow_op(unop, [q1])?.outputs_arr(); - // Constrain q1 as EQUALS(q2) by using both together - dfg.finish_hugr_with_outputs([q1, q2], &PRELUDE_REGISTRY)?; - // The combined q1+q2 variable now has two PLUS constraints - on itself and the inputs. - Ok(()) -} - -/// [plus_on_self] had about a 50% rate of failing with stack overflow. -/// So if we run 10 times, that would succeed about 1 run in 2^10, i.e. 
<0.1% -#[test] -fn plus_on_self_10_times() { - [0; 10].iter().for_each(|_| plus_on_self().unwrap()) -} - -#[test] -// Test that logic for dealing with self-referential constraints doesn't -// fall over when a self-referencing group of metas also references a meta -// outside the group -fn sccs() { - let hugr = Hugr::default(); - let mut ctx = UnificationContext::new(&hugr); - // Make a strongly-connected component (loop) - let m1 = ctx.fresh_meta(); - let m2 = ctx.fresh_meta(); - let m3 = ctx.fresh_meta(); - ctx.add_constraint(m1, Constraint::Plus(ExtensionSet::singleton(&A), m3)); - ctx.add_constraint(m2, Constraint::Plus(ExtensionSet::singleton(&B), m1)); - ctx.add_constraint(m3, Constraint::Plus(ExtensionSet::singleton(&A), m2)); - // And a second scc - let m4 = ctx.fresh_meta(); - let m5 = ctx.fresh_meta(); - ctx.add_constraint(m4, Constraint::Plus(ExtensionSet::singleton(&C), m5)); - ctx.add_constraint(m5, Constraint::Plus(ExtensionSet::singleton(&C), m4)); - // Make second component depend upon first - ctx.add_constraint( - m4, - Constraint::Plus(ExtensionSet::singleton(&UNKNOWN_EXTENSION), m3), - ); - ctx.variables.insert(m1); - ctx.variables.insert(m4); - ctx.instantiate_variables(); - assert_eq!( - ctx.get_solution(&m1), - Some(&ExtensionSet::from_iter([A, B])) - ); - assert_eq!( - ctx.get_solution(&m4), - Some(&ExtensionSet::from_iter([A, B, C, UNKNOWN_EXTENSION])) - ); -} - -#[test] -/// Note: This test is relying on the builder's `define_function` doing the -/// right thing: it takes input resources via a [`Signature`], which it passes -/// to `create_with_io`, creating concrete resource sets. -/// Inference can still fail for a valid FuncDefn hugr created without using -/// the builder API. -fn simple_funcdefn() -> Result<(), Box> { - let mut builder = ModuleBuilder::new(); - let mut func_builder = builder.define_function( - "F", - FunctionType::new(vec![NAT], vec![NAT]) - .with_extension_delta(A) - .into(), - )?; - - let [w] = func_builder.input_wires_arr(); - let lift = func_builder.add_dataflow_op( - Lift { - type_row: type_row![NAT], - new_extension: A, - }, - [w], - )?; - let [w] = lift.outputs_arr(); - func_builder.finish_with_outputs([w])?; - builder.finish_prelude_hugr()?; - Ok(()) -} - -#[cfg(feature = "extension_inference")] -#[test] -fn funcdefn_signature_mismatch() -> Result<(), Box> { - let mut builder = ModuleBuilder::new(); - let mut func_builder = builder.define_function( - "F", - FunctionType::new(vec![NAT], vec![NAT]) - .with_extension_delta(A) - .into(), - )?; - - let [w] = func_builder.input_wires_arr(); - let lift = func_builder.add_dataflow_op( - Lift { - type_row: type_row![NAT], - new_extension: B, - }, - [w], - )?; - let [w] = lift.outputs_arr(); - func_builder.finish_with_outputs([w])?; - let result = builder.finish_prelude_hugr(); - assert_matches!( - result, - Err(ValidationError::CantInfer( - InferExtensionError::MismatchedConcreteWithLocations { .. } - )) - ); - Ok(()) -} - -#[cfg(feature = "extension_inference")] -#[test] -// Test that the difference between a FuncDefn's input and output nodes is being -// constrained to be the same as the extension delta in the FuncDefn signature. -// The FuncDefn here is declared to add resource "A", but its body just wires -// the input to the output. 
-fn funcdefn_signature_mismatch2() -> Result<(), Box> { - let mut builder = ModuleBuilder::new(); - let func_builder = builder.define_function( - "F", - FunctionType::new(vec![NAT], vec![NAT]) - .with_extension_delta(A) - .into(), - )?; - - let [w] = func_builder.input_wires_arr(); - func_builder.finish_with_outputs([w])?; - let result = builder.finish_prelude_hugr(); - assert_matches!(result, Err(ValidationError::CantInfer(..))); - Ok(()) -} diff --git a/hugr-core/src/extension/prelude.rs b/hugr-core/src/extension/prelude.rs index d30c9f88b..ad96837e2 100644 --- a/hugr-core/src/extension/prelude.rs +++ b/hugr-core/src/extension/prelude.rs @@ -452,7 +452,9 @@ mod test { assert!(error_val.equal_consts(&ConstError::new(2, "my message"))); assert!(!error_val.equal_consts(&ConstError::new(3, "my message"))); - let mut b = DFGBuilder::new(FunctionType::new_endo(type_row![])).unwrap(); + let mut b = + DFGBuilder::new(FunctionType::new_endo(type_row![]).with_extension_delta(PRELUDE_ID)) + .unwrap(); let err = b.add_load_value(error_val); @@ -486,7 +488,10 @@ mod test { ) .unwrap(); - let mut b = DFGBuilder::new(FunctionType::new_endo(type_row![QB_T, QB_T])).unwrap(); + let mut b = DFGBuilder::new( + FunctionType::new_endo(type_row![QB_T, QB_T]).with_extension_delta(PRELUDE_ID), + ) + .unwrap(); let [q0, q1] = b.input_wires_arr(); let [q0, q1] = b .add_dataflow_op(cx_gate(), [q0, q1]) @@ -524,7 +529,9 @@ mod test { #[test] /// Test print operation fn test_print() { - let mut b: DFGBuilder = DFGBuilder::new(FunctionType::new(vec![], vec![])).unwrap(); + let mut b: DFGBuilder = + DFGBuilder::new(FunctionType::new_endo(vec![]).with_extension_delta(PRELUDE_ID)) + .unwrap(); let greeting: ConstString = ConstString::new("Hello, world!".into()); let greeting_out: Wire = b.add_load_value(greeting); let print_op = PRELUDE diff --git a/hugr-core/src/extension/validate.rs b/hugr-core/src/extension/validate.rs deleted file mode 100644 index 246e4c3a3..000000000 --- a/hugr-core/src/extension/validate.rs +++ /dev/null @@ -1,209 +0,0 @@ -//! Validation routines for instantiations of a extension ops and types in a -//! Hugr. - -use std::collections::HashMap; - -use thiserror::Error; - -use super::{ExtensionSet, ExtensionSolution}; -use crate::hugr::NodeType; -use crate::{Direction, Hugr, HugrView, Node, Port}; - -/// Context for validating the extension requirements defined in a Hugr. -#[derive(Debug, Clone, Default)] -pub struct ExtensionValidator { - /// Extension requirements associated with each edge - extensions: HashMap<(Node, Direction), ExtensionSet>, -} - -impl ExtensionValidator { - /// Initialise a new extension validator, pre-computing the extension - /// requirements for each node in the Hugr. 
- /// - /// The `closure` argument is a set of extensions which doesn't actually - /// live on the graph, but is used to close the graph for validation - pub fn new(hugr: &Hugr, closure: ExtensionSolution) -> Self { - let mut extensions: HashMap<(Node, Direction), ExtensionSet> = HashMap::new(); - for (node, incoming_sol) in closure.into_iter() { - let extension_reqs = hugr - .get_nodetype(node) - .op_signature() - .map(|s| s.extension_reqs) - .unwrap_or_default(); - - let outgoing_sol = extension_reqs.union(incoming_sol.clone()); - - extensions.insert((node, Direction::Incoming), incoming_sol); - extensions.insert((node, Direction::Outgoing), outgoing_sol); - } - - let mut validator = ExtensionValidator { extensions }; - - for node in hugr.nodes() { - validator.gather_extensions(&node, hugr.get_nodetype(node)); - } - - validator - } - - /// Use the signature supplied by a dataflow node to work out the - /// extension requirements for all of its input and output edges, then put - /// those requirements in the extension validation context. - fn gather_extensions(&mut self, node: &Node, node_type: &NodeType) { - if let Some((input_exts, output_exts)) = node_type.io_extensions() { - let prev_i = self - .extensions - .insert((*node, Direction::Incoming), input_exts.clone()); - assert!(prev_i.is_none()); - let prev_o = self - .extensions - .insert((*node, Direction::Outgoing), output_exts); - assert!(prev_o.is_none()); - } - } - - /// Get the input or output extension requirements for a particular node in the Hugr. - /// - /// # Errors - /// - /// If the node extensions are missing. - fn query_extensions( - &self, - node: Node, - dir: Direction, - ) -> Result<&ExtensionSet, ExtensionError> { - self.extensions - .get(&(node, dir)) - .ok_or(ExtensionError::MissingInputExtensions(node)) - } - - /// Check that two `PortIndex` have compatible extension requirements, - /// according to the information accumulated by `gather_extensions`. - /// - /// This extension checking assumes that free extension variables - /// (e.g. implicit lifting of `A -> B` to `[R]A -> [R]B`) - /// and adding of lift nodes - /// (i.e. those which transform an edge from `A` to `[R]A`) - /// has already been done. - pub fn check_extensions_compatible( - &self, - src: &(Node, Port), - tgt: &(Node, Port), - ) -> Result<(), ExtensionError> { - let rs_src = self.query_extensions(src.0, Direction::Outgoing)?; - let rs_tgt = self.query_extensions(tgt.0, Direction::Incoming)?; - - if rs_src == rs_tgt { - Ok(()) - } else if rs_src.is_subset(rs_tgt) { - // The extra extension requirements reside in the target node. - // If so, we can fix this mismatch with a lift node - Err(ExtensionError::TgtExceedsSrcExtensionsAtPort { - from: src.0, - from_offset: src.1, - from_extensions: rs_src.clone(), - to: tgt.0, - to_offset: tgt.1, - to_extensions: rs_tgt.clone(), - }) - } else { - Err(ExtensionError::SrcExceedsTgtExtensionsAtPort { - from: src.0, - from_offset: src.1, - from_extensions: rs_src.clone(), - to: tgt.0, - to_offset: tgt.1, - to_extensions: rs_tgt.clone(), - }) - } - } - - /// Check that a pair of input and output nodes declare the same extensions - /// as in the signature of their parents. 
- #[allow(unused_variables)] - pub fn validate_io_extensions( - &self, - parent: Node, - input: Node, - output: Node, - ) -> Result<(), ExtensionError> { - #[cfg(feature = "extension_inference")] - { - let parent_input_extensions = self.query_extensions(parent, Direction::Incoming)?; - let parent_output_extensions = self.query_extensions(parent, Direction::Outgoing)?; - for dir in Direction::BOTH { - let input_extensions = self.query_extensions(input, dir)?; - let output_extensions = self.query_extensions(output, dir)?; - if parent_input_extensions != input_extensions { - return Err(ExtensionError::ParentIOExtensionMismatch { - parent, - parent_extensions: parent_input_extensions.clone(), - child: input, - child_extensions: input_extensions.clone(), - }); - }; - if parent_output_extensions != output_extensions { - return Err(ExtensionError::ParentIOExtensionMismatch { - parent, - parent_extensions: parent_output_extensions.clone(), - child: output, - child_extensions: output_extensions.clone(), - }); - }; - } - } - Ok(()) - } -} - -/// Errors that can occur while validating a Hugr. -#[derive(Debug, Clone, PartialEq, Error)] -#[allow(missing_docs)] -#[non_exhaustive] -pub enum ExtensionError { - /// Missing lift node - #[error("Extensions at target node {to:?} ({to_extensions}) exceed those at source {from:?} ({from_extensions})")] - TgtExceedsSrcExtensions { - from: Node, - from_extensions: ExtensionSet, - to: Node, - to_extensions: ExtensionSet, - }, - /// A version of the above which includes port info - #[error("Extensions at target node {to:?} ({to_offset:?}) ({to_extensions}) exceed those at source {from:?} ({from_offset:?}) ({from_extensions})")] - TgtExceedsSrcExtensionsAtPort { - from: Node, - from_offset: Port, - from_extensions: ExtensionSet, - to: Node, - to_offset: Port, - to_extensions: ExtensionSet, - }, - /// Too many extension requirements coming from src - #[error("Extensions at source node {from:?} ({from_extensions}) exceed those at target {to:?} ({to_extensions})")] - SrcExceedsTgtExtensions { - from: Node, - from_extensions: ExtensionSet, - to: Node, - to_extensions: ExtensionSet, - }, - /// A version of the above which includes port info - #[error("Extensions at source node {from:?} ({from_offset:?}) ({from_extensions}) exceed those at target {to:?} ({to_offset:?}) ({to_extensions})")] - SrcExceedsTgtExtensionsAtPort { - from: Node, - from_offset: Port, - from_extensions: ExtensionSet, - to: Node, - to_offset: Port, - to_extensions: ExtensionSet, - }, - #[error("Missing input extensions for node {0:?}")] - MissingInputExtensions(Node), - #[error("Extensions of I/O node ({child:?}) {child_extensions:?} don't match those expected by parent node ({parent:?}): {parent_extensions:?}")] - ParentIOExtensionMismatch { - parent: Node, - parent_extensions: ExtensionSet, - child: Node, - child_extensions: ExtensionSet, - }, -} diff --git a/hugr-core/src/hugr.rs b/hugr-core/src/hugr.rs index 0319df0dc..0ab25f7ba 100644 --- a/hugr-core/src/hugr.rs +++ b/hugr-core/src/hugr.rs @@ -9,12 +9,11 @@ pub mod serialize; pub mod validate; pub mod views; -#[cfg(feature = "extension_inference")] -use std::collections::HashMap; use std::collections::VecDeque; use std::iter; pub(crate) use self::hugrmut::HugrMut; +use self::validate::ExtensionError; pub use self::validate::ValidationError; pub use ident::{IdentList, InvalidIdentifier}; @@ -26,9 +25,7 @@ use thiserror::Error; pub use self::views::{HugrView, RootTagged}; use crate::core::NodeIndex; -#[cfg(feature = "extension_inference")] -use 
crate::extension::infer_extensions;
-use crate::extension::{ExtensionRegistry, ExtensionSet, ExtensionSolution, InferExtensionError};
+use crate::extension::{ExtensionRegistry, ExtensionSet};
 use crate::ops::custom::resolve_extension_ops;
 use crate::ops::{OpTag, OpTrait, OpType, DEFAULT_OPTYPE};
 use crate::types::FunctionType;
@@ -123,16 +120,6 @@ impl NodeType {
         self.input_extensions.as_ref()
     }
 
-    /// The input and output extensions for this node, if set.
-    ///
-    /// `None`` if the [Self::input_extensions] is `None`.
-    /// Otherwise, will return Some, with the output extensions computed from the node's delta
-    pub fn io_extensions(&self) -> Option<(ExtensionSet, ExtensionSet)> {
-        self.input_extensions
-            .clone()
-            .map(|e| (e.clone(), self.op.extension_delta().union(e)))
-    }
-
     /// Gets the underlying [OpType] i.e. without any [input_extensions]
     ///
     /// [input_extensions]: NodeType::input_extensions
@@ -208,37 +195,16 @@ impl Hugr {
         #[cfg(feature = "extension_inference")]
         {
             self.infer_extensions()?;
-            self.validate_extensions(HashMap::new())?;
+            self.validate_extensions()?;
         }
         Ok(())
     }
 
-    /// Infer extension requirements and add new information to `op_types` field
-    /// (if the "extension_inference" feature is on; otherwise, do nothing)
-    pub fn infer_extensions(&mut self) -> Result<(), InferExtensionError> {
-        #[cfg(feature = "extension_inference")]
-        {
-            let solution = infer_extensions(self)?;
-            self.instantiate_extensions(&solution);
-        }
+    /// Currently a no-op; retained because in future we plan for it to infer the
+    /// extension deltas of container nodes, e.g. [OpType::DFG].
+    pub fn infer_extensions(&mut self) -> Result<(), ExtensionError> {
         Ok(())
     }
-
-    #[allow(dead_code)]
-    /// Add extension requirement information to the hugr in place.
-    fn instantiate_extensions(&mut self, solution: &ExtensionSolution) {
-        // We only care about inferred _input_ extensions, because `NodeType`
-        // uses those to infer the output extensions
-        for (node, input_extensions) in solution.iter() {
-            let nodetype = self.op_types.try_get_mut(node.pg_index()).unwrap();
-            match &nodetype.input_extensions {
-                None => nodetype.input_extensions = Some(input_extensions.clone()),
-                Some(existing_ext_reqs) => {
-                    debug_assert_eq!(existing_ext_reqs, input_extensions)
-                }
-            }
-        }
-    }
 }
 
 /// Internal API for HUGRs, not intended for use by users.
@@ -359,8 +325,6 @@ pub enum HugrError { #[cfg(test)] mod test { use super::{Hugr, HugrView}; - #[cfg(feature = "extension_inference")] - use std::error::Error; #[test] fn impls_send_and_sync() { @@ -379,40 +343,4 @@ mod test { let hugr = simple_dfg_hugr(); assert_matches!(hugr.get_io(hugr.root()), Some(_)); } - - #[cfg(feature = "extension_inference")] - #[test] - fn extension_instantiation() -> Result<(), Box> { - use crate::builder::test::closed_dfg_root_hugr; - use crate::extension::ExtensionSet; - use crate::hugr::HugrMut; - use crate::ops::Lift; - use crate::type_row; - use crate::types::{FunctionType, Type}; - - const BIT: Type = crate::extension::prelude::USIZE_T; - let r = ExtensionSet::singleton(&"R".try_into().unwrap()); - - let mut hugr = closed_dfg_root_hugr( - FunctionType::new(type_row![BIT], type_row![BIT]).with_extension_delta(r.clone()), - ); - let [input, output] = hugr.get_io(hugr.root()).unwrap(); - let lift = hugr.add_node_with_parent( - hugr.root(), - Lift { - type_row: type_row![BIT], - new_extension: "R".try_into().unwrap(), - }, - ); - hugr.connect(input, 0, lift, 0); - hugr.connect(lift, 0, output, 0); - hugr.infer_extensions()?; - - assert_eq!( - hugr.get_nodetype(lift).input_extensions().unwrap(), - &ExtensionSet::new() - ); - assert_eq!(hugr.get_nodetype(output).input_extensions().unwrap(), &r); - Ok(()) - } } diff --git a/hugr-core/src/hugr/serialize/test.rs b/hugr-core/src/hugr/serialize/test.rs index f50556498..fdd58e398 100644 --- a/hugr-core/src/hugr/serialize/test.rs +++ b/hugr-core/src/hugr/serialize/test.rs @@ -11,7 +11,7 @@ use crate::ops::custom::{ExtensionOp, OpaqueOp}; use crate::ops::{self, dataflow::IOTrait, Input, Module, Noop, Output, Value, DFG}; use crate::std_extensions::arithmetic::float_types::FLOAT64_TYPE; use crate::std_extensions::arithmetic::int_ops::INT_OPS_REGISTRY; -use crate::std_extensions::arithmetic::int_types::{int_custom_type, ConstInt, INT_TYPES}; +use crate::std_extensions::arithmetic::int_types::{self, int_custom_type, ConstInt, INT_TYPES}; use crate::std_extensions::logic::NotOp; use crate::types::{ type_param::TypeParam, FunctionType, PolyFuncType, SumType, Type, TypeArg, TypeBound, @@ -351,8 +351,11 @@ fn hierarchy_order() -> Result<(), Box> { #[test] fn constants_roundtrip() -> Result<(), Box> { - let mut builder = - DFGBuilder::new(FunctionType::new(vec![], vec![INT_TYPES[4].clone()])).unwrap(); + let mut builder = DFGBuilder::new( + FunctionType::new(vec![], vec![INT_TYPES[4].clone()]) + .with_extension_delta(int_types::EXTENSION_ID), + ) + .unwrap(); let w = builder.add_load_value(ConstInt::new_s(4, -2).unwrap()); let hugr = builder.finish_hugr_with_outputs([w], &INT_OPS_REGISTRY)?; diff --git a/hugr-core/src/hugr/validate.rs b/hugr-core/src/hugr/validate.rs index 00a8cae7c..ce6532ab0 100644 --- a/hugr-core/src/hugr/validate.rs +++ b/hugr-core/src/hugr/validate.rs @@ -9,16 +9,11 @@ use petgraph::visit::{Topo, Walker}; use portgraph::{LinkView, PortView}; use thiserror::Error; -use crate::extension::validate::ExtensionValidator; -use crate::extension::SignatureError; -use crate::extension::{ - validate::ExtensionError, ExtensionRegistry, ExtensionSolution, InferExtensionError, -}; - -use crate::ops::custom::CustomOpError; -use crate::ops::custom::{resolve_opaque_op, CustomOp}; +use crate::extension::{ExtensionRegistry, ExtensionSet, SignatureError}; + +use crate::ops::custom::{resolve_opaque_op, CustomOp, CustomOpError}; use crate::ops::validate::{ChildrenEdgeData, ChildrenValidationError, EdgeValidationError}; -use 
crate::ops::{FuncDefn, OpTag, OpTrait, OpType, ValidateOp}; +use crate::ops::{FuncDefn, OpParent, OpTag, OpTrait, OpType, ValidateOp}; use crate::types::type_param::TypeParam; use crate::types::{EdgeKind, FunctionType}; use crate::{Direction, Hugr, Node, Port}; @@ -45,10 +40,9 @@ impl Hugr { /// TODO: Add a version of validation which allows for open extension /// variables (see github issue #457) pub fn validate(&self, extension_registry: &ExtensionRegistry) -> Result<(), ValidationError> { - #[cfg(feature = "extension_inference")] - self.validate_with_extension_closure(HashMap::new(), extension_registry)?; - #[cfg(not(feature = "extension_inference"))] self.validate_no_extensions(extension_registry)?; + #[cfg(feature = "extension_inference")] + self.validate_extensions()?; Ok(()) } @@ -62,48 +56,39 @@ impl Hugr { validator.validate() } - /// Validate extensions on the input and output edges of nodes. Check that - /// the target ends of edges require the extensions from the sources, and - /// check extension deltas from parent nodes are reflected in their children. - pub fn validate_extensions(&self, closure: ExtensionSolution) -> Result<(), ValidationError> { - let validator = ExtensionValidator::new(self, closure); - for src_node in self.nodes() { - let node_type = self.get_nodetype(src_node); - - // FuncDefns have no resources since they're static nodes, but the - // functions they define can have any extension delta. - if node_type.tag() != OpTag::FuncDefn { - // If this is a container with I/O nodes, check that the extension they - // define match the extensions of the container. - if let Some([input, output]) = self.get_io(src_node) { - validator.validate_io_extensions(src_node, input, output)?; - } - } - - for src_port in self.node_outputs(src_node) { - for (tgt_node, tgt_port) in self.linked_inputs(src_node, src_port) { - validator.check_extensions_compatible( - &(src_node, src_port.into()), - &(tgt_node, tgt_port.into()), - )?; + /// Validate extensions, i.e. that extension deltas from parent nodes are reflected in their children. + pub fn validate_extensions(&self) -> Result<(), ValidationError> { + for parent in self.nodes() { + let parent_op = self.get_optype(parent); + let parent_extensions = match parent_op.inner_function_type() { + Some(FunctionType { extension_reqs, .. 
}) => extension_reqs, + None => match parent_op.tag() { + OpTag::Cfg | OpTag::Conditional => parent_op.extension_delta(), + // ModuleRoot holds but does not execute its children, so allow any extensions + OpTag::ModuleRoot => continue, + _ => { + assert!(self.children(parent).next().is_none(), + "Unknown parent node type {:?} - not a DataflowParent, Module, Cfg or Conditional", + parent_op); + continue; + } + }, + }; + for child in self.children(parent) { + let child_extensions = self.get_optype(child).extension_delta(); + if !parent_extensions.is_superset(&child_extensions) { + return Err(ExtensionError { + parent, + parent_extensions, + child, + child_extensions, + } + .into()); } } } Ok(()) } - - /// Check the validity of a hugr, taking an argument of a closure for the - /// free extension variables - pub fn validate_with_extension_closure( - &self, - closure: ExtensionSolution, - extension_registry: &ExtensionRegistry, - ) -> Result<(), ValidationError> { - let mut validator = ValidationContext::new(self, extension_registry); - validator.validate()?; - self.validate_extensions(closure)?; - Ok(()) - } } impl<'a, 'b> ValidationContext<'a, 'b> { @@ -758,11 +743,9 @@ pub enum ValidationError { /// There are invalid inter-graph edges. #[error(transparent)] InterGraphEdgeError(#[from] InterGraphEdgeError), - /// There are errors in the extension declarations. + /// There are errors in the extension deltas. #[error(transparent)] ExtensionError(#[from] ExtensionError), - #[error(transparent)] - CantInfer(#[from] InferExtensionError), /// Error in a node signature #[error("Error in signature of node {node:?}: {cause}")] SignatureError { node: Node, cause: SignatureError }, @@ -835,5 +818,15 @@ pub enum InterGraphEdgeError { }, } +#[derive(Debug, Clone, PartialEq, Error)] +#[error("Parent node {parent} has extensions {parent_extensions} that are too restrictive for child node {child}, they must include child extensions {child_extensions}")] +/// An error in the extension deltas. +pub struct ExtensionError { + parent: Node, + parent_extensions: ExtensionSet, + child: Node, + child_extensions: ExtensionSet, +} + #[cfg(test)] pub(crate) mod test; diff --git a/hugr-core/src/hugr/validate/test.rs b/hugr-core/src/hugr/validate/test.rs index 8e8e1f0c5..47c197416 100644 --- a/hugr-core/src/hugr/validate/test.rs +++ b/hugr-core/src/hugr/validate/test.rs @@ -7,7 +7,7 @@ use crate::builder::{ BuildError, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder, HugrBuilder, ModuleBuilder, SubContainer, }; -use crate::extension::prelude::{BOOL_T, PRELUDE, PRELUDE_ID, USIZE_T}; +use crate::extension::prelude::{BOOL_T, PRELUDE, PRELUDE_ID, QB_T, USIZE_T}; use crate::extension::{Extension, ExtensionSet, TypeDefBound, EMPTY_REG, PRELUDE_REGISTRY}; use crate::hugr::internal::HugrMutInternals; use crate::hugr::HugrMut; @@ -152,7 +152,6 @@ fn children_restrictions() { b.update_validate(&EMPTY_REG), Err(ValidationError::NonContainerWithChildren { node, .. }) => assert_eq!(node, copy) ); - b.infer_extensions().unwrap(); b.set_parent(new_def, root); // After moving the previous definition to a valid place, @@ -865,145 +864,178 @@ fn test_polymorphic_load() -> Result<(), Box> { Ok(()) } -#[cfg(feature = "extension_inference")] -mod extension_tests { - use super::*; - use crate::extension::ExtensionSet; - use crate::macros::const_extension_ids; - - const_extension_ids! 
{ - const XA: ExtensionId = "A"; - const XB: ExtensionId = "BOOL_EXT"; - } - - const Q: Type = crate::extension::prelude::QB_T; - - /// Adds an input{BOOL_T}, tag_constant(0, BOOL_T^sum_size), tag(BOOL_T^sum_size), and - /// output{Sum{unit^sum_size}, BOOL_T} operation to a dataflow container. - /// Intended to be used to populate a BasicBlock node in a CFG. - /// - /// Returns the node indices of each of the operations. - fn add_block_children(b: &mut Hugr, parent: Node, sum_size: usize) -> (Node, Node, Node, Node) { - let const_op: ops::Const = ops::Value::unit_sum(0, sum_size as u8) - .expect("`sum_size` must be greater than 0") - .into(); - let tag_type = Type::new_unit_sum(sum_size as u8); - - let input = b.add_node_with_parent(parent, ops::Input::new(type_row![BOOL_T])); +#[test] +/// Validation errors in a controlflow subgraph. +fn cfg_children_restrictions() { + let (mut b, def) = make_simple_hugr(1); + let (_input, _output, copy) = b + .hierarchy + .children(def.pg_index()) + .map_into() + .collect_tuple() + .unwrap(); + // Write Extension annotations into the Hugr while it's still well-formed + // enough for us to compute them + b.validate(&EMPTY_REG).unwrap(); + b.replace_op( + copy, + NodeType::new_pure(ops::CFG { + signature: FunctionType::new(type_row![BOOL_T], type_row![BOOL_T]), + }), + ) + .unwrap(); + assert_matches!( + b.validate(&EMPTY_REG), + Err(ValidationError::ContainerWithoutChildren { .. }) + ); + let cfg = copy; + + // Construct a valid CFG, with one BasicBlock node and one exit node + let block = b.add_node_with_parent( + cfg, + ops::DataflowBlock { + inputs: type_row![BOOL_T], + sum_rows: vec![type_row![]], + other_outputs: type_row![BOOL_T], + extension_delta: ExtensionSet::new(), + }, + ); + let const_op: ops::Const = ops::Value::unit_sum(0, 1).unwrap().into(); + let tag_type = Type::new_unit_sum(1); + { + let input = b.add_node_with_parent(block, ops::Input::new(type_row![BOOL_T])); let output = - b.add_node_with_parent(parent, ops::Output::new(vec![tag_type.clone(), BOOL_T])); + b.add_node_with_parent(block, ops::Output::new(vec![tag_type.clone(), BOOL_T])); let tag_def = b.add_node_with_parent(b.root(), const_op); - let tag = b.add_node_with_parent(parent, ops::LoadConstant { datatype: tag_type }); + let tag = b.add_node_with_parent(block, ops::LoadConstant { datatype: tag_type }); b.connect(tag_def, 0, tag, 0); b.add_other_edge(input, tag); b.connect(tag, 0, output, 0); b.connect(input, 0, output, 1); - - (input, tag_def, tag, output) } + let exit = b.add_node_with_parent( + cfg, + ops::ExitBlock { + cfg_outputs: type_row![BOOL_T], + }, + ); + b.add_other_edge(block, exit); + assert_eq!(b.update_validate(&EMPTY_REG), Ok(())); - #[test] - /// Validation errors in a dataflow subgraph. - fn cfg_children_restrictions() { - let (mut b, def) = make_simple_hugr(1); - let (_input, _output, copy) = b - .hierarchy - .children(def.pg_index()) - .map_into() - .collect_tuple() - .unwrap(); - // Write Extension annotations into the Hugr while it's still well-formed - // enough for us to compute them - b.infer_extensions().unwrap(); - b.validate(&EMPTY_REG).unwrap(); - b.replace_op( - copy, - NodeType::new_pure(ops::CFG { - signature: FunctionType::new(type_row![BOOL_T], type_row![BOOL_T]), - }), - ) - .unwrap(); - assert_matches!( - b.validate(&EMPTY_REG), - Err(ValidationError::ContainerWithoutChildren { .. 
}) - ); - let cfg = copy; - - // Construct a valid CFG, with one BasicBlock node and one exit node - let block = b.add_node_with_parent( - cfg, - ops::DataflowBlock { - inputs: type_row![BOOL_T], - sum_rows: vec![type_row![]], - other_outputs: type_row![BOOL_T], - extension_delta: ExtensionSet::new(), - }, - ); - add_block_children(&mut b, block, 1); - let exit = b.add_node_with_parent( - cfg, - ops::ExitBlock { - cfg_outputs: type_row![BOOL_T], - }, - ); - b.add_other_edge(block, exit); - assert_eq!(b.update_validate(&EMPTY_REG), Ok(())); + // Test malformed errors - // Test malformed errors + // Add an internal exit node + let exit2 = b.add_node_after( + exit, + ops::ExitBlock { + cfg_outputs: type_row![BOOL_T], + }, + ); + assert_matches!( + b.validate(&EMPTY_REG), + Err(ValidationError::InvalidChildren { parent, source: ChildrenValidationError::InternalExitChildren { child, .. }, .. }) + => {assert_eq!(parent, cfg); assert_eq!(child, exit2.pg_index())} + ); + b.remove_node(exit2); - // Add an internal exit node - let exit2 = b.add_node_after( - exit, - ops::ExitBlock { - cfg_outputs: type_row![BOOL_T], - }, - ); - assert_matches!( - b.validate(&EMPTY_REG), - Err(ValidationError::InvalidChildren { parent, source: ChildrenValidationError::InternalExitChildren { child, .. }, .. }) - => {assert_eq!(parent, cfg); assert_eq!(child, exit2.pg_index())} - ); - b.remove_node(exit2); - - // Change the types in the BasicBlock node to work on qubits instead of bits - b.replace_op( - block, - NodeType::new_pure(ops::DataflowBlock { - inputs: type_row![Q], - sum_rows: vec![type_row![]], - other_outputs: type_row![Q], - extension_delta: ExtensionSet::new(), - }), - ) - .unwrap(); - let mut block_children = b.hierarchy.children(block.pg_index()); - let block_input = block_children.next().unwrap().into(); - let block_output = block_children.next_back().unwrap().into(); - b.replace_op( - block_input, - NodeType::new_pure(ops::Input::new(type_row![Q])), - ) - .unwrap(); - b.replace_op( - block_output, - NodeType::new_pure(ops::Output::new(type_row![Type::new_unit_sum(1), Q])), - ) - .unwrap(); - assert_matches!( - b.validate(&EMPTY_REG), - Err(ValidationError::InvalidEdges { parent, source: EdgeValidationError::CFGEdgeSignatureMismatch { .. }, .. }) - => assert_eq!(parent, cfg) - ); + // Change the types in the BasicBlock node to work on qubits instead of bits + b.replace_op( + block, + NodeType::new_pure(ops::DataflowBlock { + inputs: type_row![QB_T], + sum_rows: vec![type_row![]], + other_outputs: type_row![QB_T], + extension_delta: ExtensionSet::new(), + }), + ) + .unwrap(); + let mut block_children = b.hierarchy.children(block.pg_index()); + let block_input = block_children.next().unwrap().into(); + let block_output = block_children.next_back().unwrap().into(); + b.replace_op( + block_input, + NodeType::new_pure(ops::Input::new(type_row![QB_T])), + ) + .unwrap(); + b.replace_op( + block_output, + NodeType::new_pure(ops::Output::new(type_row![Type::new_unit_sum(1), QB_T])), + ) + .unwrap(); + assert_matches!( + b.validate(&EMPTY_REG), + Err(ValidationError::InvalidEdges { parent, source: EdgeValidationError::CFGEdgeSignatureMismatch { .. }, .. 
}) + => assert_eq!(parent, cfg) + ); +} + +#[test] +// /->->\ +// | | +// Entry -> Middle -> Exit +fn cfg_connections() -> Result<(), Box> { + use crate::builder::CFGBuilder; + + let mut hugr = CFGBuilder::new(FunctionType::new_endo(USIZE_T))?; + let unary_pred = hugr.add_constant(Value::unary_unit_sum()); + let mut entry = hugr.simple_entry_builder(type_row![USIZE_T], 1, ExtensionSet::new())?; + let p = entry.load_const(&unary_pred); + let ins = entry.input_wires(); + let entry = entry.finish_with_outputs(p, ins)?; + + let mut middle = hugr.simple_block_builder(FunctionType::new_endo(USIZE_T), 1)?; + let p = middle.load_const(&unary_pred); + let ins = middle.input_wires(); + let middle = middle.finish_with_outputs(p, ins)?; + + let exit = hugr.exit_block(); + hugr.branch(&entry, 0, &middle)?; + hugr.branch(&middle, 0, &exit)?; + let mut h = hugr.finish_hugr(&PRELUDE_REGISTRY)?; + + h.connect(middle.node(), 0, middle.node(), 0); + assert_eq!( + h.validate(&PRELUDE_REGISTRY), + Err(ValidationError::TooManyConnections { + node: middle.node(), + port: Port::new(Direction::Outgoing, 0), + port_kind: EdgeKind::ControlFlow + }) + ); + Ok(()) +} + +#[cfg(feature = "extension_inference")] +mod extension_tests { + use self::ops::handle::{BasicBlockID, TailLoopID}; + + use super::*; + use crate::builder::handle::Outputs; + use crate::builder::{BlockBuilder, BuildHandle, CFGBuilder, DFGWrapper, TailLoopBuilder}; + use crate::extension::ExtensionSet; + use crate::macros::const_extension_ids; + use crate::Wire; + + const_extension_ids! { + const XA: ExtensionId = "A"; + const XB: ExtensionId = "BOOL_EXT"; } - #[test] - fn parent_io_mismatch() { - // The DFG node declares that it has an empty extension delta, - // but it's child graph adds extension "XB", causing a mismatch. - let mut hugr = Hugr::new(NodeType::new_pure(ops::DFG { - signature: FunctionType::new(type_row![USIZE_T], type_row![USIZE_T]), - })); + #[rstest] + #[case::d1(|signature| ops::DFG {signature}.into())] + #[case::f1(|ft: FunctionType| ops::FuncDefn {name: "foo".to_string(), signature: ft.into()}.into())] + #[case::c1(|signature| ops::Case {signature}.into())] + fn parent_extension_mismatch( + #[case] parent_f: impl Fn(FunctionType) -> OpType, + #[values(ExtensionSet::new(), XA.into())] parent_extensions: ExtensionSet, + ) { + // Child graph adds extension "XB", but the parent (in all cases) + // declares a different delta, causing a mismatch. + let parent = parent_f( + FunctionType::new_endo(USIZE_T).with_extension_delta(parent_extensions.clone()), + ); + let mut hugr = Hugr::new(NodeType::new_pure(parent)); let input = hugr.add_node_with_parent( hugr.root(), @@ -1033,162 +1065,173 @@ mod extension_tests { hugr.connect(lift, 0, output, 0); let result = hugr.validate(&PRELUDE_REGISTRY); - assert_matches!( + assert_eq!( result, - Err(ValidationError::ExtensionError( - ExtensionError::ParentIOExtensionMismatch { .. } - )) + Err(ValidationError::ExtensionError(ExtensionError { + parent: hugr.root(), + parent_extensions, + child: lift, + child_extensions: XB.into() + })) ); } - #[test] - /// A wire with no extension requirements is wired into a node which has - /// [A,BOOL_T] extensions required on its inputs and outputs. This could be fixed - /// by adding a lift node, but for validation this is an error. 
-    fn missing_lift_node() -> Result<(), BuildError> {
-        let mut module_builder = ModuleBuilder::new();
-        let mut main = module_builder.define_function(
-            "main",
-            FunctionType::new(type_row![NAT], type_row![NAT]).into(),
-        )?;
-        let [main_input] = main.input_wires_arr();
-
-        let f_builder = main.dfg_builder(
-            FunctionType::new(type_row![NAT], type_row![NAT]),
-            // Inner DFG has extension requirements that the wire wont satisfy
-            Some(ExtensionSet::from_iter([XA, XB])),
-            [main_input],
+    #[rstest]
+    #[case(XA.into(), false)]
+    #[case(ExtensionSet::new(), false)]
+    #[case(ExtensionSet::from_iter([XA, XB]), true)]
+    fn cfg_extension_mismatch(
+        #[case] parent_extensions: ExtensionSet,
+        #[case] success: bool,
+    ) -> Result<(), BuildError> {
+        let mut cfg = CFGBuilder::new(
+            FunctionType::new_endo(USIZE_T).with_extension_delta(parent_extensions.clone()),
         )?;
-        let f_inputs = f_builder.input_wires();
-        let f_handle = f_builder.finish_with_outputs(f_inputs)?;
-        let [f_output] = f_handle.outputs_arr();
-        main.finish_with_outputs([f_output])?;
-        let handle = module_builder.hugr().validate(&PRELUDE_REGISTRY);
-
-        assert_matches!(
-            handle,
-            Err(ValidationError::ExtensionError(
-                ExtensionError::TgtExceedsSrcExtensionsAtPort { .. }
-            ))
-        );
+        let mut bb = cfg.simple_entry_builder(USIZE_T.into(), 1, XB.into())?;
+        let pred = bb.add_load_value(Value::unary_unit_sum());
+        let inputs = bb.input_wires();
+        let blk = bb.finish_with_outputs(pred, inputs)?;
+        let exit = cfg.exit_block();
+        cfg.branch(&blk, 0, &exit)?;
+        let root = cfg.hugr().root();
+        let res = cfg.finish_prelude_hugr();
+        if success {
+            assert!(res.is_ok())
+        } else {
+            assert_eq!(
+                res,
+                Err(ValidationError::ExtensionError(ExtensionError {
+                    parent: root,
+                    parent_extensions,
+                    child: blk.node(),
+                    child_extensions: XB.into()
+                }))
+            );
+        }
         Ok(())
     }
 
-    #[test]
-    /// A wire with extension requirement `[A]` is wired into a an output with no
-    /// extension req. In the validation extension typechecking, we don't do any
-    /// unification, so don't allow open extension variables on the function
-    /// signature, so this fails.
-    fn too_many_extension() -> Result<(), BuildError> {
-        let mut module_builder = ModuleBuilder::new();
-
-        let main_sig = FunctionType::new(type_row![NAT], type_row![NAT]).into();
-
-        let mut main = module_builder.define_function("main", main_sig)?;
-        let [main_input] = main.input_wires_arr();
-
-        let inner_sig = FunctionType::new(type_row![NAT], type_row![NAT]).with_extension_delta(XA);
+    #[rstest]
+    #[case(XA.into(), false)]
+    #[case(ExtensionSet::new(), false)]
+    #[case(ExtensionSet::from_iter([XA, XB]), true)]
+    fn conditional_extension_mismatch(
+        #[case] parent_extensions: ExtensionSet,
+        #[case] success: bool,
+    ) {
+        // Child graph adds extension "XB", but the parent
+        // declares a different delta, in some cases causing a mismatch.
+        let parent = ops::Conditional {
+            sum_rows: vec![type_row![], type_row![]],
+            other_inputs: type_row![USIZE_T],
+            outputs: type_row![USIZE_T],
+            extension_delta: parent_extensions.clone(),
+        };
+        let mut hugr = Hugr::new(NodeType::new_pure(parent));
+
+        // First case with no delta should be ok in all cases. Second one may not be.
+ let [_, child] = [None, Some(XB)].map(|case_ext| { + let case_exts = ExtensionSet::from_iter(case_ext.clone()); + let case = hugr.add_node_with_parent( + hugr.root(), + ops::Case { + signature: FunctionType::new_endo(USIZE_T) + .with_extension_delta(case_exts.clone()), + }, + ); - let f_builder = main.dfg_builder(inner_sig, Some(ExtensionSet::new()), [main_input])?; - let f_inputs = f_builder.input_wires(); - let f_handle = f_builder.finish_with_outputs(f_inputs)?; - let [f_output] = f_handle.outputs_arr(); - main.finish_with_outputs([f_output])?; - let handle = module_builder.hugr().validate(&PRELUDE_REGISTRY); - assert_matches!( - handle, - Err(ValidationError::ExtensionError( - ExtensionError::SrcExceedsTgtExtensionsAtPort { .. } - )) - ); - Ok(()) + let input = hugr.add_node_with_parent( + case, + NodeType::new_pure(ops::Input { + types: type_row![USIZE_T], + }), + ); + let output = hugr.add_node_with_parent( + case, + NodeType::new( + ops::Output { + types: type_row![USIZE_T], + }, + Some(case_exts), + ), + ); + let res = match case_ext { + None => input, + Some(new_ext) => { + let lift = hugr.add_node_with_parent( + case, + NodeType::new_pure(ops::Lift { + type_row: type_row![USIZE_T], + new_extension: new_ext, + }), + ); + hugr.connect(input, 0, lift, 0); + lift + } + }; + hugr.connect(res, 0, output, 0); + case + }); + // case is the last-assigned child, i.e. the one that requires 'XB' + let result = hugr.validate(&PRELUDE_REGISTRY); + let expected = if success { + Ok(()) + } else { + Err(ValidationError::ExtensionError(ExtensionError { + parent: hugr.root(), + parent_extensions, + child, + child_extensions: XB.into(), + })) + }; + assert_eq!(result, expected); } - #[test] - /// A wire with extension requirements `[A]` and another with requirements - /// `[BOOL_T]` are both wired into a node which requires its inputs to have - /// requirements `[A,BOOL_T]`. A slightly more complex test of the error from - /// `missing_lift_node`. - fn extensions_mismatch() -> Result<(), BuildError> { - let mut module_builder = ModuleBuilder::new(); - - let all_rs = ExtensionSet::from_iter([XA, XB]); - - let main_sig = FunctionType::new(type_row![NAT], type_row![NAT]) - .with_extension_delta(all_rs.clone()) - .into(); - - let mut main = module_builder.define_function("main", main_sig)?; - - let [inp_wire] = main.input_wires_arr(); - - let [left_wire] = main - .dfg_builder( - FunctionType::new(type_row![], type_row![NAT]), - Some(XA.into()), - [], - )? - .finish_with_outputs([inp_wire])? - .outputs_arr(); - - let [right_wire] = main - .dfg_builder( - FunctionType::new(type_row![], type_row![NAT]), - Some(XB.into()), - [], - )? - .finish_with_outputs([inp_wire])? 
- .outputs_arr(); - - let builder = main.dfg_builder( - FunctionType::new(type_row![NAT, NAT], type_row![NAT]), - Some(all_rs), - [left_wire, right_wire], + #[rstest] + #[case(make_bb, |bb: &mut DFGWrapper<_,_>, outs| bb.make_tuple(outs))] + #[case(make_tailloop, |tl: &mut DFGWrapper<_,_>, outs| tl.make_break(tl.loop_signature().unwrap().clone(), outs))] + fn bb_extension_mismatch( + #[case] dfg_fn: impl Fn(Type, ExtensionSet) -> DFGWrapper, + #[case] make_pred: impl Fn(&mut DFGWrapper, Outputs) -> Result, + #[values((XA.into(), false), (ExtensionSet::new(), false), (ExtensionSet::from_iter([XA,XB]), true))] + parent_exts_success: (ExtensionSet, bool), + ) -> Result<(), BuildError> { + let (parent_extensions, success) = parent_exts_success; + let mut dfg = dfg_fn(USIZE_T, parent_extensions.clone()); + let lift = dfg.add_dataflow_op( + ops::Lift { + type_row: USIZE_T.into(), + new_extension: XB, + }, + dfg.input_wires(), )?; - let [left, _] = builder.input_wires_arr(); - let [output] = builder.finish_with_outputs([left])?.outputs_arr(); - - main.finish_with_outputs([output])?; - let handle = module_builder.hugr().validate(&PRELUDE_REGISTRY); - assert_matches!( - handle, - Err(ValidationError::ExtensionError( - ExtensionError::TgtExceedsSrcExtensionsAtPort { .. } - )) - ); + let pred = make_pred(&mut dfg, lift.outputs())?; + let root = dfg.hugr().root(); + let res = dfg.finish_prelude_hugr_with_outputs([pred]); + if success { + assert!(res.is_ok()) + } else { + assert_eq!( + res, + Err(BuildError::InvalidHUGR(ValidationError::ExtensionError( + ExtensionError { + parent: root, + parent_extensions, + child: lift.node(), + child_extensions: XB.into() + } + ))) + ); + } Ok(()) } - #[test] - fn parent_signature_mismatch() { - let main_signature = - FunctionType::new(type_row![NAT], type_row![NAT]).with_extension_delta(XA); - - let mut hugr = Hugr::new(NodeType::new_pure(ops::DFG { - signature: main_signature, - })); - let input = hugr.add_node_with_parent( - hugr.root(), - NodeType::new_pure(ops::Input { - types: type_row![NAT], - }), - ); - let output = hugr.add_node_with_parent( - hugr.root(), - NodeType::new( - ops::Output { - types: type_row![NAT], - }, - Some(XA.into()), - ), - ); - hugr.connect(input, 0, output, 0); + fn make_bb(t: Type, es: ExtensionSet) -> DFGWrapper { + BlockBuilder::new(t.clone(), None, vec![t.into()], type_row![], es).unwrap() + } - assert_matches!( - hugr.validate(&PRELUDE_REGISTRY), - Err(ValidationError::ExtensionError( - ExtensionError::TgtExceedsSrcExtensionsAtPort { .. 
} - )) - ); + fn make_tailloop(t: Type, es: ExtensionSet) -> DFGWrapper> { + let row = TypeRow::from(t); + TailLoopBuilder::new(row.clone(), type_row![], row, es).unwrap() } } diff --git a/hugr-core/src/ops.rs b/hugr-core/src/ops.rs index 2e48d1f9d..76f3e54a4 100644 --- a/hugr-core/src/ops.rs +++ b/hugr-core/src/ops.rs @@ -432,7 +432,6 @@ impl OpParent for MakeTuple {} impl OpParent for UnpackTuple {} impl OpParent for Tag {} impl OpParent for Lift {} -impl OpParent for TailLoop {} impl OpParent for CFG {} impl OpParent for Conditional {} impl OpParent for FuncDecl {} diff --git a/hugr-core/src/ops/constant.rs b/hugr-core/src/ops/constant.rs index 5f059f4f0..ea55908fa 100644 --- a/hugr-core/src/ops/constant.rs +++ b/hugr-core/src/ops/constant.rs @@ -433,6 +433,7 @@ pub type ValueNameRef = str; mod test { use super::Value; use crate::builder::test::simple_dfg_hugr; + use crate::extension::prelude::PRELUDE_ID; use crate::std_extensions::arithmetic::int_types::ConstInt; use crate::{ builder::{BuildError, DFGBuilder, Dataflow, DataflowHugr}, @@ -492,10 +493,13 @@ mod test { let pred_rows = vec![type_row![USIZE_T, FLOAT64_TYPE], Type::EMPTY_TYPEROW]; let pred_ty = SumType::new(pred_rows.clone()); - let mut b = DFGBuilder::new(FunctionType::new( - type_row![], - TypeRow::from(vec![pred_ty.clone().into()]), - ))?; + let mut b = DFGBuilder::new( + FunctionType::new(type_row![], TypeRow::from(vec![pred_ty.clone().into()])) + .with_extension_delta(ExtensionSet::from_iter([ + float_types::EXTENSION_ID, + PRELUDE_ID, + ])), + )?; let c = b.add_constant(Value::sum( 0, [ diff --git a/hugr-core/src/ops/controlflow.rs b/hugr-core/src/ops/controlflow.rs index c09333ee3..b1a6b0e37 100644 --- a/hugr-core/src/ops/controlflow.rs +++ b/hugr-core/src/ops/controlflow.rs @@ -18,6 +18,8 @@ pub struct TailLoop { pub just_outputs: TypeRow, /// Types that are appended to both input and output pub rest: TypeRow, + /// Extension requirements to execute the body + pub extension_delta: ExtensionSet, } impl_op_name!(TailLoop); @@ -32,7 +34,7 @@ impl DataflowOpTrait for TailLoop { fn signature(&self) -> FunctionType { let [inputs, outputs] = [&self.just_inputs, &self.just_outputs].map(|row| row.extend(self.rest.iter())); - FunctionType::new(inputs, outputs) + FunctionType::new(inputs, outputs).with_extension_delta(self.extension_delta.clone()) } } @@ -51,6 +53,13 @@ impl TailLoop { } } +impl DataflowParent for TailLoop { + fn inner_signature(&self) -> FunctionType { + FunctionType::new(self.body_input_row(), self.body_output_row()) + .with_extension_delta(self.extension_delta.clone()) + } +} + /// Conditional operation, defined by child `Case` nodes for each branch. 
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
 #[cfg_attr(test, derive(proptest_derive::Arbitrary))]
@@ -159,6 +168,7 @@ impl DataflowParent for DataflowBlock {
         let mut node_outputs = vec![sum_type];
         node_outputs.extend_from_slice(&self.other_outputs);
         FunctionType::new(self.inputs.clone(), TypeRow::from(node_outputs))
+            .with_extension_delta(self.extension_delta.clone())
     }
 }
 
diff --git a/hugr-core/src/ops/validate.rs b/hugr-core/src/ops/validate.rs
index c87034d60..aaee46049 100644
--- a/hugr-core/src/ops/validate.rs
+++ b/hugr-core/src/ops/validate.rs
@@ -106,31 +106,6 @@ impl ValidateOp for super::Conditional {
     }
 }
 
-impl ValidateOp for super::TailLoop {
-    fn validity_flags(&self) -> OpValidityFlags {
-        OpValidityFlags {
-            allowed_children: OpTag::DataflowChild,
-            allowed_first_child: OpTag::Input,
-            allowed_second_child: OpTag::Output,
-            requires_children: true,
-            requires_dag: true,
-            ..Default::default()
-        }
-    }
-
-    fn validate_op_children<'a>(
-        &self,
-        children: impl DoubleEndedIterator,
-    ) -> Result<(), ChildrenValidationError> {
-        validate_io_nodes(
-            &self.body_input_row(),
-            &self.body_output_row(),
-            "tail-controlled loop graph",
-            children,
-        )
-    }
-}
-
 impl ValidateOp for super::CFG {
     fn validity_flags(&self) -> OpValidityFlags {
         OpValidityFlags {
diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs
index b882699d5..88720d41e 100644
--- a/hugr-passes/src/const_fold.rs
+++ b/hugr-passes/src/const_fold.rs
@@ -2,6 +2,7 @@
 
 use std::collections::{BTreeSet, HashMap};
 
+use hugr_core::extension::ExtensionSet;
 use itertools::Itertools;
 use thiserror::Error;
 
@@ -136,7 +137,10 @@ pub fn fold_leaf_op(op: &OpType, consts: &[(IncomingPort, Value)]) -> ConstFoldR
 /// against `reg`.
 fn const_graph(consts: Vec<Value>, reg: &ExtensionRegistry) -> Hugr {
     let const_types = consts.iter().map(Value::get_type).collect_vec();
-    let mut b = DFGBuilder::new(FunctionType::new(type_row![], const_types)).unwrap();
+    let exts = ExtensionSet::union_over(consts.iter().map(Value::extension_reqs));
+    let mut b =
+        DFGBuilder::new(FunctionType::new(type_row![], const_types).with_extension_delta(exts))
+            .unwrap();
 
     let outputs = consts
         .into_iter()
diff --git a/hugr-passes/src/const_fold/test.rs b/hugr-passes/src/const_fold/test.rs
index 3f5eb2ba4..b5c0913bd 100644
--- a/hugr-passes/src/const_fold/test.rs
+++ b/hugr-passes/src/const_fold/test.rs
@@ -3,9 +3,9 @@ use hugr_core::builder::{DFGBuilder, Dataflow, DataflowHugr};
 use hugr_core::extension::prelude::{sum_with_error, ConstError, ConstString, BOOL_T, STRING_TYPE};
 use hugr_core::extension::{ExtensionRegistry, PRELUDE};
 use hugr_core::ops::Value;
-use hugr_core::std_extensions::arithmetic;
 use hugr_core::std_extensions::arithmetic::int_ops::IntOpDef;
 use hugr_core::std_extensions::arithmetic::int_types::{ConstInt, INT_TYPES};
+use hugr_core::std_extensions::arithmetic::{self, float_types, int_types};
 use hugr_core::std_extensions::logic::{self, NaryLogic, NotOp};
 use hugr_core::type_row;
 use hugr_core::types::{FunctionType, Type, TypeRow};
@@ -72,6 +72,11 @@ fn test_add(#[case] a: f64, #[case] b: f64, #[case] c: f64) {
     assert_eq!(outs.as_slice(), &[(0.into(), c)]);
 }
 
+
+fn float_fn(outputs: impl Into<TypeRow>) -> FunctionType {
+    FunctionType::new(type_row![], outputs).with_extension_delta(float_types::EXTENSION_ID)
+}
+
 #[test]
 fn test_big() {
     /* int(x.0 - x.1) == 2 */
     let sum_type = sum_with_error(INT_TYPES[5].to_owned());
-    let mut build = DFGBuilder::new(FunctionType::new(
-
type_row![], - vec![sum_type.clone().into()], - )) - .unwrap(); + let mut build = DFGBuilder::new(float_fn(vec![sum_type.clone().into()])).unwrap(); let tup = build.add_load_const(Value::tuple([f2c(5.6), f2c(3.2)])); @@ -273,7 +274,7 @@ fn test_folding_pass_issue_996() { // x6 := flt(x0, x5) // false // x7 := or(x4, x6) // true // output x7 - let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap(); + let mut build = DFGBuilder::new(float_fn(vec![BOOL_T])).unwrap(); let x0 = build.add_load_const(Value::extension(ConstF64::new(3.0))); let x1 = build.add_load_const(Value::extension(ConstF64::new(4.0))); let x2 = build.add_dataflow_op(FloatOps::fne, [x0, x1]).unwrap(); @@ -313,7 +314,7 @@ fn test_const_fold_to_nonfinite() { .unwrap(); // HUGR computing 1.0 / 1.0 - let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![FLOAT64_TYPE])).unwrap(); + let mut build = DFGBuilder::new(float_fn(vec![FLOAT64_TYPE])).unwrap(); let x0 = build.add_load_const(Value::extension(ConstF64::new(1.0))); let x1 = build.add_load_const(Value::extension(ConstF64::new(1.0))); let x2 = build.add_dataflow_op(FloatOps::fdiv, [x0, x1]).unwrap(); @@ -325,7 +326,7 @@ fn test_const_fold_to_nonfinite() { assert_eq!(h0.node_count(), 5); // HUGR computing 1.0 / 0.0 - let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![FLOAT64_TYPE])).unwrap(); + let mut build = DFGBuilder::new(float_fn(vec![FLOAT64_TYPE])).unwrap(); let x0 = build.add_load_const(Value::extension(ConstF64::new(1.0))); let x1 = build.add_load_const(Value::extension(ConstF64::new(0.0))); let x2 = build.add_dataflow_op(FloatOps::fdiv, [x0, x1]).unwrap(); @@ -334,6 +335,10 @@ fn test_const_fold_to_nonfinite() { assert_eq!(h1.node_count(), 8); } +fn int_fn(outputs: impl Into) -> FunctionType { + FunctionType::new(type_row![], outputs).with_extension_delta(int_types::EXTENSION_ID) +} + #[test] fn test_fold_iwiden_u() { // pseudocode: @@ -341,8 +346,7 @@ fn test_fold_iwiden_u() { // x0 := int_u<4>(13); // x1 := iwiden_u<4, 5>(x0); // output x1 == int_u<5>(13); - let mut build = - DFGBuilder::new(FunctionType::new(type_row![], vec![INT_TYPES[5].clone()])).unwrap(); + let mut build = DFGBuilder::new(int_fn(vec![INT_TYPES[5].clone()])).unwrap(); let x0 = build.add_load_const(Value::extension(ConstInt::new_u(4, 13).unwrap())); let x1 = build .add_dataflow_op(IntOpDef::iwiden_u.with_two_log_widths(4, 5), [x0]) @@ -365,8 +369,7 @@ fn test_fold_iwiden_s() { // x0 := int_u<4>(-3); // x1 := iwiden_u<4, 5>(x0); // output x1 == int_s<5>(-3); - let mut build = - DFGBuilder::new(FunctionType::new(type_row![], vec![INT_TYPES[5].clone()])).unwrap(); + let mut build = DFGBuilder::new(int_fn(vec![INT_TYPES[5].clone()])).unwrap(); let x0 = build.add_load_const(Value::extension(ConstInt::new_s(4, -3).unwrap())); let x1 = build .add_dataflow_op(IntOpDef::iwiden_s.with_two_log_widths(4, 5), [x0]) @@ -411,11 +414,7 @@ fn test_fold_inarrow, E: std::fmt::Debug>( // succeeds => whether to expect a int variant or an error // variant. 
     let sum_type = sum_with_error(INT_TYPES[to_log_width as usize].to_owned());
-    let mut build = DFGBuilder::new(FunctionType::new(
-        type_row![],
-        vec![sum_type.clone().into()],
-    ))
-    .unwrap();
+    let mut build = DFGBuilder::new(int_fn(vec![sum_type.clone().into()])).unwrap();
     let x0 = build.add_load_const(mk_const(from_log_width, val).unwrap().into());
     let x1 = build
         .add_dataflow_op(
@@ -452,7 +451,7 @@ fn test_fold_itobool() {
     // x0 := int_u<0>(1);
     // x1 := itobool(x0);
     // output x1 == true;
-    let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap();
+    let mut build = DFGBuilder::new(int_fn(vec![BOOL_T])).unwrap();
     let x0 = build.add_load_const(Value::extension(ConstInt::new_u(0, 1).unwrap()));
     let x1 = build
         .add_dataflow_op(IntOpDef::itobool.without_log_width(), [x0])
@@ -498,7 +497,7 @@ fn test_fold_ieq() {
     // x0, x1 := int_s<3>(-1), int_u<3>(255)
     // x2 := ieq(x0, x1)
     // output x2 == true;
-    let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap();
+    let mut build = DFGBuilder::new(int_fn(vec![BOOL_T])).unwrap();
     let x0 = build.add_load_const(Value::extension(ConstInt::new_s(3, -1).unwrap()));
     let x1 = build.add_load_const(Value::extension(ConstInt::new_u(3, 255).unwrap()));
     let x2 = build
@@ -521,7 +520,7 @@ fn test_fold_ine() {
     // x0, x1 := int_u<5>(3), int_u<5>(4)
     // x2 := ine(x0, x1)
     // output x2 == true;
-    let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap();
+    let mut build = DFGBuilder::new(int_fn(vec![BOOL_T])).unwrap();
     let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 3).unwrap()));
     let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 4).unwrap()));
     let x2 = build
@@ -544,7 +543,7 @@ fn test_fold_ilt_u() {
     // x0, x1 := int_u<5>(3), int_u<5>(4)
     // x2 := ilt_u(x0, x1)
     // output x2 == true;
-    let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap();
+    let mut build = DFGBuilder::new(int_fn(vec![BOOL_T])).unwrap();
     let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 3).unwrap()));
     let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 4).unwrap()));
     let x2 = build
@@ -567,7 +566,7 @@ fn test_fold_ilt_s() {
     // x0, x1 := int_s<5>(3), int_s<5>(-4)
     // x2 := ilt_s(x0, x1)
     // output x2 == false;
-    let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap();
+    let mut build = DFGBuilder::new(int_fn(vec![BOOL_T])).unwrap();
     let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, 3).unwrap()));
     let x1 = build.add_load_const(Value::extension(ConstInt::new_s(5, -4).unwrap()));
     let x2 = build
@@ -590,7 +589,7 @@ fn test_fold_igt_u() {
     // x0, x1 := int_u<5>(3), int_u<5>(4)
     // x2 := ilt_u(x0, x1)
     // output x2 == false;
-    let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap();
+    let mut build = DFGBuilder::new(int_fn(vec![BOOL_T])).unwrap();
     let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 3).unwrap()));
     let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 4).unwrap()));
     let x2 = build
@@ -613,7 +612,7 @@ fn test_fold_igt_s() {
     // x0, x1 := int_s<5>(3), int_s<5>(-4)
     // x2 := ilt_s(x0, x1)
     // output x2 == true;
-    let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap();
+    let mut build = DFGBuilder::new(int_fn(vec![BOOL_T])).unwrap();
     let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, 3).unwrap()));
     let x1 = build.add_load_const(Value::extension(ConstInt::new_s(5, -4).unwrap()));
     let x2 = build
@@ -636,7 +635,7 @@ fn test_fold_ile_u() {
     // x0, x1 := int_u<5>(3), int_u<5>(3)
     // x2 := ile_u(x0, x1)
     // output x2 == true;
-    let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap();
+    let mut build = DFGBuilder::new(int_fn(vec![BOOL_T])).unwrap();
     let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 3).unwrap()));
     let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 3).unwrap()));
     let x2 = build
@@ -659,7 +658,7 @@ fn test_fold_ile_s() {
     // x0, x1 := int_s<5>(-4), int_s<5>(-4)
     // x2 := ile_s(x0, x1)
     // output x2 == true;
-    let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap();
+    let mut build = DFGBuilder::new(int_fn(vec![BOOL_T])).unwrap();
     let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -4).unwrap()));
     let x1 = build.add_load_const(Value::extension(ConstInt::new_s(5, -4).unwrap()));
     let x2 = build
@@ -682,7 +681,7 @@ fn test_fold_ige_u() {
     // x0, x1 := int_u<5>(3), int_u<5>(4)
     // x2 := ilt_u(x0, x1)
     // output x2 == false;
-    let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap();
+    let mut build = DFGBuilder::new(int_fn(vec![BOOL_T])).unwrap();
     let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 3).unwrap()));
     let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 4).unwrap()));
     let x2 = build
@@ -705,7 +704,7 @@ fn test_fold_ige_s() {
     // x0, x1 := int_s<5>(3), int_s<5>(-4)
     // x2 := ilt_s(x0, x1)
     // output x2 == true;
-    let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap();
+    let mut build = DFGBuilder::new(int_fn(vec![BOOL_T])).unwrap();
     let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, 3).unwrap()));
     let x1 = build.add_load_const(Value::extension(ConstInt::new_s(5, -4).unwrap()));
     let x2 = build
@@ -728,8 +727,7 @@ fn test_fold_imax_u() {
     // x0, x1 := int_u<5>(7), int_u<5>(11);
     // x2 := imax_u(x0, x1);
     // output x2 == int_u<5>(11);
-    let mut build =
-        DFGBuilder::new(FunctionType::new(type_row![], vec![INT_TYPES[5].clone()])).unwrap();
+    let mut build = DFGBuilder::new(int_fn(vec![INT_TYPES[5].clone()])).unwrap();
     let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 7).unwrap()));
     let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 11).unwrap()));
     let x2 = build
@@ -752,8 +750,7 @@ fn test_fold_imax_s() {
     // x0, x1 := int_s<5>(-2), int_s<5>(1);
     // x2 := imax_u(x0, x1);
     // output x2 == int_s<5>(1);
-    let mut build =
-        DFGBuilder::new(FunctionType::new(type_row![], vec![INT_TYPES[5].clone()])).unwrap();
+    let mut build = DFGBuilder::new(int_fn(vec![INT_TYPES[5].clone()])).unwrap();
     let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -2).unwrap()));
     let x1 = build.add_load_const(Value::extension(ConstInt::new_s(5, 1).unwrap()));
     let x2 = build
@@ -776,8 +773,7 @@ fn test_fold_imin_u() {
     // x0, x1 := int_u<5>(7), int_u<5>(11);
     // x2 := imin_u(x0, x1);
     // output x2 == int_u<5>(7);
-    let mut build =
-        DFGBuilder::new(FunctionType::new(type_row![], vec![INT_TYPES[5].clone()])).unwrap();
+    let mut build = DFGBuilder::new(int_fn(vec![INT_TYPES[5].clone()])).unwrap();
     let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 7).unwrap()));
     let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 11).unwrap()));
     let x2 = build
@@ -800,8 +796,7 @@ fn test_fold_imin_s() {
     // x0, x1 := int_s<5>(-2), int_s<5>(1);
     // x2 := imin_u(x0, x1);
     // output x2 == int_s<5>(-2);
-    let mut build =
-        DFGBuilder::new(FunctionType::new(type_row![], vec![INT_TYPES[5].clone()])).unwrap();
+    let mut build = DFGBuilder::new(int_fn(vec![INT_TYPES[5].clone()])).unwrap();
     let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -2).unwrap()));
     let x1 = build.add_load_const(Value::extension(ConstInt::new_s(5, 1).unwrap()));
     let x2 = build
@@ -824,8 +819,7 @@ fn test_fold_iadd() {
     // x0, x1 := int_s<5>(-2), int_s<5>(1);
     // x2 := iadd(x0, x1);
     // output x2 == int_s<5>(-1);
-    let mut build =
-        DFGBuilder::new(FunctionType::new(type_row![], vec![INT_TYPES[5].clone()])).unwrap();
+    let mut build = DFGBuilder::new(int_fn(vec![INT_TYPES[5].clone()])).unwrap();
     let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -2).unwrap()));
     let x1 = build.add_load_const(Value::extension(ConstInt::new_s(5, 1).unwrap()));
     let x2 = build
@@ -848,8 +842,7 @@ fn test_fold_isub() {
     // x0, x1 := int_s<5>(-2), int_s<5>(1);
     // x2 := isub(x0, x1);
     // output x2 == int_s<5>(-3);
-    let mut build =
-        DFGBuilder::new(FunctionType::new(type_row![], vec![INT_TYPES[5].clone()])).unwrap();
+    let mut build = DFGBuilder::new(int_fn(vec![INT_TYPES[5].clone()])).unwrap();
     let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -2).unwrap()));
     let x1 = build.add_load_const(Value::extension(ConstInt::new_s(5, 1).unwrap()));
     let x2 = build
@@ -872,8 +865,7 @@ fn test_fold_ineg() {
     // x0 := int_s<5>(-2);
     // x1 := ineg(x0);
     // output x1 == int_s<5>(2);
-    let mut build =
-        DFGBuilder::new(FunctionType::new(type_row![], vec![INT_TYPES[5].clone()])).unwrap();
+    let mut build = DFGBuilder::new(int_fn(vec![INT_TYPES[5].clone()])).unwrap();
     let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -2).unwrap()));
     let x2 = build
         .add_dataflow_op(IntOpDef::ineg.with_log_width(5), [x0])
@@ -895,8 +887,7 @@ fn test_fold_imul() {
     // x0, x1 := int_s<5>(-2), int_s<5>(7);
     // x2 := imul(x0, x1);
     // output x2 == int_s<5>(-14);
-    let mut build =
-        DFGBuilder::new(FunctionType::new(type_row![], vec![INT_TYPES[5].clone()])).unwrap();
+    let mut build = DFGBuilder::new(int_fn(vec![INT_TYPES[5].clone()])).unwrap();
     let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -2).unwrap()));
     let x1 = build.add_load_const(Value::extension(ConstInt::new_s(5, 7).unwrap()));
     let x2 = build
@@ -921,11 +912,7 @@ fn test_fold_idivmod_checked_u() {
     // output x2 == error
     let intpair: TypeRow = vec![INT_TYPES[5].clone(), INT_TYPES[3].clone()].into();
     let sum_type = sum_with_error(Type::new_tuple(intpair));
-    let mut build = DFGBuilder::new(FunctionType::new(
-        type_row![],
-        vec![sum_type.clone().into()],
-    ))
-    .unwrap();
+    let mut build = DFGBuilder::new(int_fn(vec![sum_type.clone().into()])).unwrap();
     let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 20).unwrap()));
     let x1 = build.add_load_const(Value::extension(ConstInt::new_u(3, 0).unwrap()));
     let x2 = build
@@ -962,8 +949,7 @@ fn test_fold_idivmod_u() {
     // x4 := iwiden_u<3,5>(x3); // 2
     // x5 := iadd<5>(x2, x4); // 8
     // output x5 == int_u<5>(8);
-    let mut build =
-        DFGBuilder::new(FunctionType::new(type_row![], vec![INT_TYPES[5].clone()])).unwrap();
+    let mut build = DFGBuilder::new(int_fn(vec![INT_TYPES[5].clone()])).unwrap();
     let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 20).unwrap()));
     let x1 = build.add_load_const(Value::extension(ConstInt::new_u(3, 3).unwrap()));
     let [x2, x3] = build
@@ -996,11 +982,7 @@ fn test_fold_idivmod_checked_s() {
     // output x2 == error
     let intpair: TypeRow = vec![INT_TYPES[5].clone(), INT_TYPES[3].clone()].into();
     let sum_type = sum_with_error(Type::new_tuple(intpair));
-    let mut build = DFGBuilder::new(FunctionType::new(
-        type_row![],
-        vec![sum_type.clone().into()],
-    ))
-    .unwrap();
+    let mut build = DFGBuilder::new(int_fn(vec![sum_type.clone().into()])).unwrap();
     let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -20).unwrap()));
     let x1 = build.add_load_const(Value::extension(ConstInt::new_u(3, 0).unwrap()));
     let x2 = build
@@ -1038,8 +1020,7 @@ fn test_fold_idivmod_checked_s() {
 #[case(i64::MIN, 1u64 << 63, -1)]
 // c = a/b + a%b
 fn test_fold_idivmod_s(#[case] a: i64, #[case] b: u64, #[case] c: i64) {
-    let mut build =
-        DFGBuilder::new(FunctionType::new(type_row![], vec![INT_TYPES[6].clone()])).unwrap();
+    let mut build = DFGBuilder::new(int_fn(vec![INT_TYPES[6].clone()])).unwrap();
     let x0 = build.add_load_const(Value::extension(ConstInt::new_s(6, a).unwrap()));
     let x1 = build.add_load_const(Value::extension(ConstInt::new_u(6, b).unwrap()));
     let [x2, x3] = build
@@ -1067,11 +1048,7 @@ fn test_fold_idiv_checked_u() {
     // x2 := idiv_checked_u(x0, x1)
     // output x2 == error
     let sum_type = sum_with_error(INT_TYPES[5].to_owned());
-    let mut build = DFGBuilder::new(FunctionType::new(
-        type_row![],
-        vec![sum_type.clone().into()],
-    ))
-    .unwrap();
+    let mut build = DFGBuilder::new(int_fn(vec![sum_type.clone().into()])).unwrap();
     let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 20).unwrap()));
     let x1 = build.add_load_const(Value::extension(ConstInt::new_u(3, 0).unwrap()));
     let x2 = build
@@ -1103,8 +1080,7 @@ fn test_fold_idiv_u() {
     // x0, x1 := int_u<5>(20), int_u<3>(3);
     // x2 := idiv_u(x0, x1);
     // output x2 == int_u<5>(6);
-    let mut build =
-        DFGBuilder::new(FunctionType::new(type_row![], vec![INT_TYPES[5].clone()])).unwrap();
+    let mut build = DFGBuilder::new(int_fn(vec![INT_TYPES[5].clone()])).unwrap();
     let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 20).unwrap()));
     let x1 = build.add_load_const(Value::extension(ConstInt::new_u(3, 3).unwrap()));
     let x2 = build
@@ -1128,11 +1104,7 @@ fn test_fold_imod_checked_u() {
     // x2 := imod_checked_u(x0, x1)
     // output x2 == error
     let sum_type = sum_with_error(INT_TYPES[3].to_owned());
-    let mut build = DFGBuilder::new(FunctionType::new(
-        type_row![],
-        vec![sum_type.clone().into()],
-    ))
-    .unwrap();
+    let mut build = DFGBuilder::new(int_fn(vec![sum_type.clone().into()])).unwrap();
     let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 20).unwrap()));
     let x1 = build.add_load_const(Value::extension(ConstInt::new_u(3, 0).unwrap()));
     let x2 = build
@@ -1164,8 +1136,7 @@ fn test_fold_imod_u() {
     // x0, x1 := int_u<5>(20), int_u<3>(3);
     // x2 := imod_u(x0, x1);
     // output x2 == int_u<3>(2);
-    let mut build =
-        DFGBuilder::new(FunctionType::new(type_row![], vec![INT_TYPES[3].clone()])).unwrap();
+    let mut build = DFGBuilder::new(int_fn(vec![INT_TYPES[3].clone()])).unwrap();
     let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 20).unwrap()));
     let x1 = build.add_load_const(Value::extension(ConstInt::new_u(3, 3).unwrap()));
     let x2 = build
@@ -1189,11 +1160,7 @@ fn test_fold_idiv_checked_s() {
     // x2 := idiv_checked_s(x0, x1)
     // output x2 == error
     let sum_type = sum_with_error(INT_TYPES[5].to_owned());
-    let mut build = DFGBuilder::new(FunctionType::new(
-        type_row![],
-        vec![sum_type.clone().into()],
-    ))
-    .unwrap();
+    let mut build = DFGBuilder::new(int_fn(vec![sum_type.clone().into()])).unwrap();
     let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -20).unwrap()));
     let x1 = build.add_load_const(Value::extension(ConstInt::new_u(3, 0).unwrap()));
     let x2 = build
@@ -1225,8 +1192,7 @@ fn test_fold_idiv_s() {
     // x0, x1 := int_s<5>(-20), int_u<3>(3);
     // x2 := idiv_s(x0, x1);
     // output x2 == int_s<5>(-7);
-    let mut build =
-        DFGBuilder::new(FunctionType::new(type_row![], vec![INT_TYPES[5].clone()])).unwrap();
+    let mut build = DFGBuilder::new(int_fn(vec![INT_TYPES[5].clone()])).unwrap();
    let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -20).unwrap()));
     let x1 = build.add_load_const(Value::extension(ConstInt::new_u(3, 3).unwrap()));
     let x2 = build
@@ -1250,11 +1216,7 @@ fn test_fold_imod_checked_s() {
     // x2 := imod_checked_u(x0, x1)
     // output x2 == error
     let sum_type = sum_with_error(INT_TYPES[3].to_owned());
-    let mut build = DFGBuilder::new(FunctionType::new(
-        type_row![],
-        vec![sum_type.clone().into()],
-    ))
-    .unwrap();
+    let mut build = DFGBuilder::new(int_fn(vec![sum_type.clone().into()])).unwrap();
     let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -20).unwrap()));
     let x1 = build.add_load_const(Value::extension(ConstInt::new_u(3, 0).unwrap()));
     let x2 = build
@@ -1286,8 +1248,7 @@ fn test_fold_imod_s() {
     // x0, x1 := int_s<5>(-20), int_u<3>(3);
     // x2 := imod_s(x0, x1);
     // output x2 == int_u<3>(1);
-    let mut build =
-        DFGBuilder::new(FunctionType::new(type_row![], vec![INT_TYPES[3].clone()])).unwrap();
+    let mut build = DFGBuilder::new(int_fn(vec![INT_TYPES[3].clone()])).unwrap();
     let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -20).unwrap()));
     let x1 = build.add_load_const(Value::extension(ConstInt::new_u(3, 3).unwrap()));
     let x2 = build
@@ -1310,8 +1271,7 @@ fn test_fold_iabs() {
     // x0 := int_s<5>(-2);
     // x1 := iabs(x0);
     // output x1 == int_s<5>(2);
-    let mut build =
-        DFGBuilder::new(FunctionType::new(type_row![], vec![INT_TYPES[5].clone()])).unwrap();
+    let mut build = DFGBuilder::new(int_fn(vec![INT_TYPES[5].clone()])).unwrap();
     let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -2).unwrap()));
     let x2 = build
         .add_dataflow_op(IntOpDef::iabs.with_log_width(5), [x0])
@@ -1333,8 +1293,7 @@ fn test_fold_iand() {
     // x0, x1 := int_u<5>(14), int_u<5>(20);
     // x2 := iand(x0, x1);
     // output x2 == int_u<5>(4);
-    let mut build =
-        DFGBuilder::new(FunctionType::new(type_row![], vec![INT_TYPES[5].clone()])).unwrap();
+    let mut build = DFGBuilder::new(int_fn(vec![INT_TYPES[5].clone()])).unwrap();
     let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 14).unwrap()));
     let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 20).unwrap()));
     let x2 = build
@@ -1357,8 +1316,7 @@ fn test_fold_ior() {
     // x0, x1 := int_u<5>(14), int_u<5>(20);
     // x2 := ior(x0, x1);
     // output x2 == int_u<5>(30);
-    let mut build =
-        DFGBuilder::new(FunctionType::new(type_row![], vec![INT_TYPES[5].clone()])).unwrap();
+    let mut build = DFGBuilder::new(int_fn(vec![INT_TYPES[5].clone()])).unwrap();
     let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 14).unwrap()));
     let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 20).unwrap()));
     let x2 = build
@@ -1381,8 +1339,7 @@ fn test_fold_ixor() {
     // x0, x1 := int_u<5>(14), int_u<5>(20);
     // x2 := ixor(x0, x1);
     // output x2 == int_u<5>(26);
-    let mut build =
-        DFGBuilder::new(FunctionType::new(type_row![], vec![INT_TYPES[5].clone()])).unwrap();
+    let mut build = DFGBuilder::new(int_fn(vec![INT_TYPES[5].clone()])).unwrap();
     let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 14).unwrap()));
     let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 20).unwrap()));
     let x2 = build
@@ -1405,8 +1362,7 @@ fn test_fold_inot() {
     // x0 := int_u<5>(14);
     // x1 := inot(x0);
     // output x1 == int_u<5>(17);
-    let mut build =
-        DFGBuilder::new(FunctionType::new(type_row![], vec![INT_TYPES[5].clone()])).unwrap();
+    let mut build = DFGBuilder::new(int_fn(vec![INT_TYPES[5].clone()])).unwrap();
     let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 14).unwrap()));
     let x2 = build
         .add_dataflow_op(IntOpDef::inot.with_log_width(5), [x0])
@@ -1428,8 +1384,7 @@ fn test_fold_ishl() {
     // x0, x1 := int_u<5>(14), int_u<3>(3);
     // x2 := ishl(x0, x1);
     // output x2 == int_u<5>(112);
-    let mut build =
-        DFGBuilder::new(FunctionType::new(type_row![], vec![INT_TYPES[5].clone()])).unwrap();
+    let mut build = DFGBuilder::new(int_fn(vec![INT_TYPES[5].clone()])).unwrap();
     let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, 14).unwrap()));
     let x1 = build.add_load_const(Value::extension(ConstInt::new_u(3, 3).unwrap()));
     let x2 = build
@@ -1452,8 +1407,7 @@ fn test_fold_ishr() {
     // x0, x1 := int_u<5>(14), int_u<3>(3);
     // x2 := ishr(x0, x1);
     // output x2 == int_u<5>(1);
-    let mut build =
-        DFGBuilder::new(FunctionType::new(type_row![], vec![INT_TYPES[5].clone()])).unwrap();
+    let mut build = DFGBuilder::new(int_fn(vec![INT_TYPES[5].clone()])).unwrap();
     let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, 14).unwrap()));
     let x1 = build.add_load_const(Value::extension(ConstInt::new_u(3, 3).unwrap()));
     let x2 = build
@@ -1476,8 +1430,7 @@ fn test_fold_irotl() {
     // x0, x1 := int_u<5>(14), int_u<3>(61);
     // x2 := irotl(x0, x1);
     // output x2 == int_u<5>(2^30 + 2^31 + 1);
-    let mut build =
-        DFGBuilder::new(FunctionType::new(type_row![], vec![INT_TYPES[5].clone()])).unwrap();
+    let mut build = DFGBuilder::new(int_fn(vec![INT_TYPES[5].clone()])).unwrap();
     let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, 14).unwrap()));
     let x1 = build.add_load_const(Value::extension(ConstInt::new_u(3, 61).unwrap()));
     let x2 = build
@@ -1500,8 +1453,7 @@ fn test_fold_irotr() {
     // x0, x1 := int_u<5>(14), int_u<3>(3);
     // x2 := irotr(x0, x1);
     // output x2 == int_u<5>(2^30 + 2^31 + 1);
-    let mut build =
-        DFGBuilder::new(FunctionType::new(type_row![], vec![INT_TYPES[5].clone()])).unwrap();
+    let mut build = DFGBuilder::new(int_fn(vec![INT_TYPES[5].clone()])).unwrap();
     let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, 14).unwrap()));
     let x1 = build.add_load_const(Value::extension(ConstInt::new_u(3, 3).unwrap()));
     let x2 = build
@@ -1524,7 +1476,7 @@ fn test_fold_itostring_u() {
     // x0 := int_u<5>(17);
    // x1 := itostring_u(x0);
     // output x2 := "17";
-    let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![STRING_TYPE])).unwrap();
+    let mut build = DFGBuilder::new(int_fn(vec![STRING_TYPE])).unwrap();
     let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 17).unwrap()));
     let x1 = build
         .add_dataflow_op(IntOpDef::itostring_u.with_log_width(5), [x0])
@@ -1546,7 +1498,7 @@ fn test_fold_itostring_s() {
     // x0 := int_s<5>(-17);
     // x1 := itostring_s(x0);
     // output x2 := "-17";
-    let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![STRING_TYPE])).unwrap();
+    let mut build = DFGBuilder::new(int_fn(vec![STRING_TYPE])).unwrap();
     let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -17).unwrap()));
     let x1 = build
         .add_dataflow_op(IntOpDef::itostring_s.with_log_width(5), [x0])
@@ -1575,7 +1527,7 @@ fn test_fold_int_ops() {
     // x6 := ilt_s(x0, x5) // false
     // x7 := or(x4, x6) // true
     // output x7
-    let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap();
+    let mut build = DFGBuilder::new(int_fn(vec![BOOL_T])).unwrap();
     let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 3).unwrap()));
     let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 4).unwrap()));
     let x2 = build
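Note on the const_fold hunks above: the change is mechanical. Every test DFG that previously relied on inference to pick up its extension requirements now states its delta through an `int_fn` helper. The helper's definition sits earlier in `const_fold/test.rs` and is not part of this excerpt; the sketch below shows a plausible shape for it. The exact extension set in the delta is an assumption (whatever the folded ops require), not taken from the patch.

```rust
use hugr_core::std_extensions::arithmetic::int_ops;
use hugr_core::type_row;
use hugr_core::types::{FunctionType, TypeRow};

/// Sketch only: a nullary signature producing `outputs`, with the arithmetic
/// extension(s) the folded ops come from declared explicitly as the delta.
fn int_fn(outputs: impl Into<TypeRow>) -> FunctionType {
    FunctionType::new(type_row![], outputs)
        // Assumption: the int-ops extension stands in for whichever
        // extensions a given test actually needs in its delta.
        .with_extension_delta(int_ops::EXTENSION_ID)
}
```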
diff --git a/hugr-passes/src/merge_bbs.rs b/hugr-passes/src/merge_bbs.rs
index 7d1cf9d52..399e86055 100644
--- a/hugr-passes/src/merge_bbs.rs
+++ b/hugr-passes/src/merge_bbs.rs
@@ -227,14 +227,14 @@ mod test {
             FunctionType::new(loop_variants.clone(), exit_types.clone())
                 .with_extension_delta(ExtensionSet::singleton(&PRELUDE_ID)),
         )?;
-        let mut no_b1 = h.simple_entry_builder(loop_variants.clone(), 1, ExtensionSet::new())?;
+        let mut no_b1 = h.simple_entry_builder(loop_variants.clone(), 1, PRELUDE_ID.into())?;
         let n = no_b1.add_dataflow_op(Noop::new(QB_T), no_b1.input_wires())?;
         let br = lifted_unary_unit_sum(&mut no_b1);
         let no_b1 = no_b1.finish_with_outputs(br, n.outputs())?;
         let mut test_block = h.block_builder(
             loop_variants.clone(),
             vec![loop_variants.clone(), exit_types],
-            ExtensionSet::singleton(&PRELUDE_ID),
+            PRELUDE_ID.into(),
             type_row![],
         )?;
         let [test_input] = test_block.input_wires_arr();
@@ -246,7 +246,10 @@
         let loop_backedge_target = if self_loop {
             no_b1
         } else {
-            let mut no_b2 = h.simple_block_builder(FunctionType::new_endo(loop_variants), 1)?;
+            let mut no_b2 = h.simple_block_builder(
+                FunctionType::new_endo(loop_variants).with_extension_delta(PRELUDE_ID),
+                1,
+            )?;
             let n = no_b2.add_dataflow_op(Noop::new(QB_T), no_b2.input_wires())?;
             let br = lifted_unary_unit_sum(&mut no_b2);
             let nid = no_b2.finish_with_outputs(br, n.outputs())?;
@@ -339,7 +342,7 @@
         let mut bb2 = h.block_builder(
             type_row![USIZE_T, QB_T],
             vec![type_row![]],
-            ExtensionSet::new(),
+            PRELUDE_ID.into(),
             type_row![QB_T, USIZE_T],
         )?;
         let [u, q] = bb2.input_wires_arr();
@@ -349,7 +352,7 @@
         let mut bb3 = h.block_builder(
             type_row![QB_T, USIZE_T],
             vec![type_row![]],
-            ExtensionSet::new(),
+            PRELUDE_ID.into(),
            res_t.clone().into(),
         )?;
        let [q, u] = bb3.input_wires_arr();
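The merge_bbs hunks follow the same rule as the other test fixes in this PR: with inference gone, any basic block whose body contains prelude-requiring ops must declare `PRELUDE_ID` in its own delta, either as the builder's `ExtensionSet` argument or on the block's `FunctionType`. The sketch below shows the two spellings side by side; it is not part of the patch, and the import paths are assumptions.

```rust
use hugr_core::extension::prelude::{PRELUDE_ID, QB_T};
use hugr_core::extension::ExtensionSet;
use hugr_core::type_row;
use hugr_core::types::FunctionType;

/// Sketch only: the two ways the hunks above say "this block emits prelude ops".
fn prelude_delta_spellings() -> (ExtensionSet, FunctionType) {
    // Passed directly where a builder takes an ExtensionSet argument,
    // e.g. `block_builder(.., PRELUDE_ID.into(), ..)`:
    let delta: ExtensionSet = PRELUDE_ID.into();
    // Or attached to the block's signature, as `simple_block_builder` now gets:
    let sig = FunctionType::new_endo(type_row![QB_T]).with_extension_delta(PRELUDE_ID);
    (delta, sig)
}
```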
diff --git a/hugr-py/src/hugr/serialization/ops.py b/hugr-py/src/hugr/serialization/ops.py
index ed3b74086..e96dfa733 100644
--- a/hugr-py/src/hugr/serialization/ops.py
+++ b/hugr-py/src/hugr/serialization/ops.py
@@ -362,6 +362,7 @@ class TailLoop(DataflowOp):
     just_outputs: TypeRow = Field(default_factory=list)  # Types that are only output
     # Types that are appended to both input and output:
     rest: TypeRow = Field(default_factory=list)
+    extension_delta: ExtensionSet = Field(default_factory=list)
 
     def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None:
         assert in_types == out_types
diff --git a/specification/schema/hugr_schema_strict_v1.json b/specification/schema/hugr_schema_strict_v1.json
index d0c5aa92e..08f948996 100644
--- a/specification/schema/hugr_schema_strict_v1.json
+++ b/specification/schema/hugr_schema_strict_v1.json
@@ -1870,6 +1870,13 @@
         },
         "title": "Rest",
         "type": "array"
+      },
+      "extension_delta": {
+        "items": {
+          "type": "string"
+        },
+        "title": "Extension Delta",
+        "type": "array"
       }
     },
     "required": [
diff --git a/specification/schema/hugr_schema_v1.json b/specification/schema/hugr_schema_v1.json
index d5dab6428..8399c4feb 100644
--- a/specification/schema/hugr_schema_v1.json
+++ b/specification/schema/hugr_schema_v1.json
@@ -1870,6 +1870,13 @@
         },
         "title": "Rest",
         "type": "array"
+      },
+      "extension_delta": {
+        "items": {
+          "type": "string"
+        },
+        "title": "Extension Delta",
+        "type": "array"
       }
     },
     "required": [
diff --git a/specification/schema/testing_hugr_schema_strict_v1.json b/specification/schema/testing_hugr_schema_strict_v1.json
index 9c272a944..567c481dd 100644
--- a/specification/schema/testing_hugr_schema_strict_v1.json
+++ b/specification/schema/testing_hugr_schema_strict_v1.json
@@ -1947,6 +1947,13 @@
         },
         "title": "Rest",
         "type": "array"
+      },
+      "extension_delta": {
+        "items": {
+          "type": "string"
+        },
+        "title": "Extension Delta",
+        "type": "array"
       }
     },
     "required": [
diff --git a/specification/schema/testing_hugr_schema_v1.json b/specification/schema/testing_hugr_schema_v1.json
index 01c3d6bb8..47ed10526 100644
--- a/specification/schema/testing_hugr_schema_v1.json
+++ b/specification/schema/testing_hugr_schema_v1.json
@@ -1947,6 +1947,13 @@
         },
         "title": "Rest",
         "type": "array"
+      },
+      "extension_delta": {
+        "items": {
+          "type": "string"
+        },
+        "title": "Extension Delta",
+        "type": "array"
       }
     },
     "required": [
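Together with the Python model change above, the four schema hunks all add the same `extension_delta` field to the serialized `TailLoop` op: an array of extension names that defaults to the empty list. The fragment below is a hypothetical illustration of the serialized form only; the surrounding field names (`op`, `parent`) and the `"prelude"` entry are assumptions, not taken from a real serialized HUGR.

```rust
/// Hypothetical sketch of a serialized TailLoop node carrying the new field.
fn example_tail_loop_node() -> serde_json::Value {
    serde_json::json!({
        "parent": 0,                    // assumed envelope fields
        "op": "TailLoop",
        "just_inputs": [],
        "just_outputs": [],
        "rest": [],
        "extension_delta": ["prelude"]  // the field added by this patch
    })
}
```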