diff --git a/src/builder/build_traits.rs b/src/builder/build_traits.rs index 641ef1ae2..746347409 100644 --- a/src/builder/build_traits.rs +++ b/src/builder/build_traits.rs @@ -97,8 +97,7 @@ pub trait Container { ExtensionSet::new(), ))?; - let db = - DFGBuilder::create_with_io(self.hugr_mut(), f_node, body, Some(ExtensionSet::new()))?; + let db = DFGBuilder::create_with_io(self.hugr_mut(), f_node, body)?; Ok(FunctionBuilder::from_dfg_builder(db)) } @@ -297,16 +296,15 @@ pub trait Dataflow: Container { fn dfg_builder( &mut self, signature: FunctionType, - input_extensions: Option, input_wires: impl IntoIterator, ) -> Result, BuildError> { let op = ops::DFG { signature: signature.clone(), }; - let nodetype = NodeType::new(op, input_extensions.clone()); + let nodetype = NodeType::new_auto(op); let (dfg_n, _) = add_node_with_wires(self, nodetype, input_wires.into_iter().collect())?; - DFGBuilder::create_with_io(self.hugr_mut(), dfg_n, signature, input_extensions) + DFGBuilder::create_with_io(self.hugr_mut(), dfg_n, signature) } /// Return a builder for a [`crate::ops::CFG`] node, diff --git a/src/builder/cfg.rs b/src/builder/cfg.rs index 9de97bc9c..bf9fd6cf3 100644 --- a/src/builder/cfg.rs +++ b/src/builder/cfg.rs @@ -260,12 +260,7 @@ impl + AsRef> BlockBuilder { let mut node_outputs = vec![tuple_sum_type]; node_outputs.extend_from_slice(&other_outputs); let signature = FunctionType::new(inputs, TypeRow::from(node_outputs)); - let inp_ex = base - .as_ref() - .get_nodetype(block_n) - .input_extensions() - .cloned(); - let db = DFGBuilder::create_with_io(base, block_n, signature, inp_ex)?; + let db = DFGBuilder::create_with_io(base, block_n, signature)?; Ok(BlockBuilder::from_dfg_builder(db)) } diff --git a/src/builder/conditional.rs b/src/builder/conditional.rs index 1e3441968..9e9438d43 100644 --- a/src/builder/conditional.rs +++ b/src/builder/conditional.rs @@ -138,7 +138,6 @@ impl + AsRef> ConditionalBuilder { self.hugr_mut(), case_node, FunctionType::new(inputs, outputs).with_extension_delta(&extension_delta), - None, )?; Ok(CaseBuilder::from_dfg_builder(dfg_builder)) @@ -197,7 +196,7 @@ impl CaseBuilder { }; let base = Hugr::new(NodeType::new_open(op)); let root = base.root(); - let dfg_builder = DFGBuilder::create_with_io(base, root, signature, None)?; + let dfg_builder = DFGBuilder::create_with_io(base, root, signature)?; Ok(CaseBuilder::from_dfg_builder(dfg_builder)) } diff --git a/src/builder/dataflow.rs b/src/builder/dataflow.rs index 1e711737d..065fc9c56 100644 --- a/src/builder/dataflow.rs +++ b/src/builder/dataflow.rs @@ -9,7 +9,7 @@ use crate::ops; use crate::types::{FunctionType, PolyFuncType}; -use crate::extension::{ExtensionRegistry, ExtensionSet}; +use crate::extension::ExtensionRegistry; use crate::Node; use crate::{hugr::HugrMut, Hugr}; @@ -27,7 +27,6 @@ impl + AsRef> DFGBuilder { mut base: T, parent: Node, signature: FunctionType, - input_extensions: Option, ) -> Result { let num_in_wires = signature.input().len(); let num_out_wires = signature.output().len(); @@ -49,15 +48,8 @@ impl + AsRef> DFGBuilder { let output = ops::Output { types: signature.output().clone(), }; - base.as_mut() - .add_node_with_parent(parent, NodeType::new(input, input_extensions.clone()))?; - base.as_mut().add_node_with_parent( - parent, - NodeType::new( - output, - input_extensions.map(|inp| inp.union(&signature.extension_reqs)), - ), - )?; + base.as_mut().add_node_with_parent(parent, input)?; + base.as_mut().add_node_with_parent(parent, output)?; Ok(Self { base, @@ -81,7 +73,7 @@ impl DFGBuilder { }; let base = Hugr::new(NodeType::new_open(dfg_op)); let root = base.root(); - DFGBuilder::create_with_io(base, root, signature, None) + DFGBuilder::create_with_io(base, root, signature) } } @@ -156,7 +148,7 @@ impl FunctionBuilder { let base = Hugr::new(NodeType::new_pure(op)); let root = base.root(); - let db = DFGBuilder::create_with_io(base, root, body, Some(ExtensionSet::new()))?; + let db = DFGBuilder::create_with_io(base, root, body)?; Ok(Self::from_dfg_builder(db)) } } @@ -254,11 +246,8 @@ pub(crate) mod test { [int], )? .outputs_arr(); - let inner_builder = func_builder.dfg_builder( - FunctionType::new(type_row![NAT], type_row![NAT]), - None, - [int], - )?; + let inner_builder = func_builder + .dfg_builder(FunctionType::new(type_row![NAT], type_row![NAT]), [int])?; let inner_id = n_identity(inner_builder)?; func_builder.finish_with_outputs(inner_id.outputs().chain(q_out.outputs()))? @@ -380,7 +369,7 @@ pub(crate) mod test { let i1 = noop.out_wire(0); let mut nested = - f_build.dfg_builder(FunctionType::new(type_row![], type_row![BIT]), None, [])?; + f_build.dfg_builder(FunctionType::new(type_row![], type_row![BIT]), [])?; let id = nested.add_dataflow_op(LeafOp::Noop { ty: BIT }, [i1])?; @@ -403,8 +392,7 @@ pub(crate) mod test { let noop = f_build.add_dataflow_op(LeafOp::Noop { ty: QB }, [i1])?; let i1 = noop.out_wire(0); - let mut nested = - f_build.dfg_builder(FunctionType::new(type_row![], type_row![QB]), None, [])?; + let mut nested = f_build.dfg_builder(FunctionType::new(type_row![], type_row![QB]), [])?; let id_res = nested.add_dataflow_op(LeafOp::Noop { ty: QB }, [i1]); @@ -482,7 +470,7 @@ pub(crate) mod test { FunctionType::new(type_row![BIT], type_row![BIT]).with_extension_delta(&ab_extensions); // A box which adds extensions A and B, via child Lift nodes - let mut add_ab = parent.dfg_builder(add_ab_sig, Some(ExtensionSet::new()), [w])?; + let mut add_ab = parent.dfg_builder(add_ab_sig, [w])?; let [w] = add_ab.input_wires_arr(); let lift_a = add_ab.add_dataflow_op( @@ -511,7 +499,7 @@ pub(crate) mod test { // Add another node (a sibling to add_ab) which adds extension C // via a child lift node - let mut add_c = parent.dfg_builder(add_c_sig, Some(ab_extensions.clone()), [w])?; + let mut add_c = parent.dfg_builder(add_c_sig, [w])?; let [w] = add_c.input_wires_arr(); let lift_c = add_c.add_dataflow_node( NodeType::new( @@ -536,10 +524,10 @@ pub(crate) mod test { fn non_cfg_ancestor() -> Result<(), BuildError> { let unit_sig = FunctionType::new(type_row![Type::UNIT], type_row![Type::UNIT]); let mut b = DFGBuilder::new(unit_sig.clone())?; - let b_child = b.dfg_builder(unit_sig.clone(), None, [b.input().out_wire(0)])?; + let b_child = b.dfg_builder(unit_sig.clone(), [b.input().out_wire(0)])?; let b_child_in_wire = b_child.input().out_wire(0); b_child.finish_with_outputs([])?; - let b_child_2 = b.dfg_builder(unit_sig.clone(), None, [])?; + let b_child_2 = b.dfg_builder(unit_sig.clone(), [])?; // DFG block has edge coming a sibling block, which is only valid for // CFGs @@ -560,17 +548,16 @@ pub(crate) mod test { fn no_relation_edge() -> Result<(), BuildError> { let unit_sig = FunctionType::new(type_row![Type::UNIT], type_row![Type::UNIT]); let mut b = DFGBuilder::new(unit_sig.clone())?; - let mut b_child = b.dfg_builder(unit_sig.clone(), None, [b.input().out_wire(0)])?; - let b_child_child = - b_child.dfg_builder(unit_sig.clone(), None, [b_child.input().out_wire(0)])?; + let mut b_child = b.dfg_builder(unit_sig.clone(), [b.input().out_wire(0)])?; + let b_child_child = b_child.dfg_builder(unit_sig.clone(), [b_child.input().out_wire(0)])?; let b_child_child_in_wire = b_child_child.input().out_wire(0); b_child_child.finish_with_outputs([])?; b_child.finish_with_outputs([])?; - let mut b_child_2 = b.dfg_builder(unit_sig.clone(), None, [])?; + let mut b_child_2 = b.dfg_builder(unit_sig.clone(), [])?; let b_child_2_child = - b_child_2.dfg_builder(unit_sig.clone(), None, [b_child_2.input().out_wire(0)])?; + b_child_2.dfg_builder(unit_sig.clone(), [b_child_2.input().out_wire(0)])?; let res = b_child_2_child.finish_with_outputs([b_child_child_in_wire]); diff --git a/src/builder/module.rs b/src/builder/module.rs index 174b8028f..7563cf6e5 100644 --- a/src/builder/module.rs +++ b/src/builder/module.rs @@ -90,7 +90,7 @@ impl + AsRef> ModuleBuilder { NodeType::new_pure(ops::FuncDefn { name, signature }), )?; - let db = DFGBuilder::create_with_io(self.hugr_mut(), f_node, body, None)?; + let db = DFGBuilder::create_with_io(self.hugr_mut(), f_node, body)?; Ok(FunctionBuilder::from_dfg_builder(db)) } diff --git a/src/builder/tail_loop.rs b/src/builder/tail_loop.rs index bbcddade7..29a5b6709 100644 --- a/src/builder/tail_loop.rs +++ b/src/builder/tail_loop.rs @@ -21,7 +21,7 @@ impl + AsRef> TailLoopBuilder { tail_loop: &ops::TailLoop, ) -> Result { let signature = FunctionType::new(tail_loop.body_input_row(), tail_loop.body_output_row()); - let dfg_build = DFGBuilder::create_with_io(base, loop_node, signature, None)?; + let dfg_build = DFGBuilder::create_with_io(base, loop_node, signature)?; Ok(TailLoopBuilder::from_dfg_builder(dfg_build)) } diff --git a/src/extension/infer.rs b/src/extension/infer.rs index 5187f32eb..f0a441fc0 100644 --- a/src/extension/infer.rs +++ b/src/extension/infer.rs @@ -13,7 +13,7 @@ use super::ExtensionSet; use crate::{ hugr::views::HugrView, - ops::{OpTag, OpTrait}, + ops::{FuncDefn, OpTag, OpTrait, OpType}, types::EdgeKind, Direction, Node, }; @@ -186,6 +186,14 @@ struct UnificationContext { fresh_name: u32, } +fn make_constraint(delta: ExtensionSet, other: Meta) -> Constraint { + if delta.is_empty() { + Constraint::Equal(other) + } else { + Constraint::Plus(delta, other) + } +} + /// Invariant: Constraint::Plus always points to a fresh metavariable impl UnificationContext { /// Create a new unification context, and populate it with constraints from @@ -283,15 +291,16 @@ impl UnificationContext { // op_signature, so the Incoming and Outgoing ports will // have equal extension requirements. // The function that it contains, however, may have an - // extension delta, so its output shouldn't be equal to the - // FuncDefn's output. - // - // TODO: Add a constraint that the extensions of the output - // node of a FuncDefn should be those of the input node plus - // the extension delta specified in the function signature. - if node_type.tag() != OpTag::FuncDefn { - self.add_constraint(m_output_node, Constraint::Equal(m_output)); - } + // extension delta - and the Input/Output nodes relate to that instead. + self.add_constraint( + m_output_node, + match node_type.op() { + OpType::FuncDefn(FuncDefn { signature, .. }) => { + make_constraint(signature.body().extension_reqs.clone(), m_output) + } + _ => Constraint::Equal(m_output), + }, + ); } } @@ -314,24 +323,13 @@ impl UnificationContext { self.add_constraint(m_output, Constraint::Equal(m_exit)); } - match node_type.io_extensions() { - // Input extensions are open - None => { - let delta = node_type.op().extension_delta(); - let c = if delta.is_empty() { - Constraint::Equal(m_input) - } else { - Constraint::Plus(delta, m_input) - }; - self.add_constraint(m_output, c); - } - // We have a solution for everything! - Some((input_exts, output_exts)) => { - self.add_solution(m_input, input_exts.clone()); - self.add_solution(m_output, output_exts); - } + let delta = node_type.op().extension_delta(); + self.add_constraint(m_output, make_constraint(delta, m_input)); + if let Some(input_exts) = node_type.input_extensions() { + self.add_solution(m_input, input_exts.clone()); } } + // Separate loop so that we can assume that a metavariable has been // added for every (Node, Direction) in the graph already. for tgt_node in hugr.nodes() { diff --git a/src/extension/infer/test.rs b/src/extension/infer/test.rs index 5714cd4e8..50b964ed6 100644 --- a/src/extension/infer/test.rs +++ b/src/extension/infer/test.rs @@ -3,7 +3,8 @@ use std::error::Error; use super::*; use crate::builder::test::closed_dfg_root_hugr; use crate::builder::{ - Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, ModuleBuilder, + BuildError, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, + FunctionBuilder, HugrBuilder, ModuleBuilder, }; use crate::extension::prelude::QB_T; use crate::extension::ExtensionId; @@ -154,36 +155,25 @@ fn plus() -> Result<(), InferExtensionError> { } #[test] -// This generates a solution that causes validation to fail -// because of a missing lift node +// We can't infer a solution here because of a missing lift node between +// Input and Output fn missing_lift_node() -> Result<(), Box> { - let mut hugr = Hugr::new(NodeType::new_pure(ops::DFG { - signature: FunctionType::new(type_row![NAT], type_row![NAT]) - .with_extension_delta(&ExtensionSet::singleton(&A)), - })); - - let input = hugr.add_node_with_parent( - hugr.root(), - NodeType::new_pure(ops::Input { - types: type_row![NAT], - }), - )?; - - let output = hugr.add_node_with_parent( - hugr.root(), - NodeType::new_pure(ops::Output { - types: type_row![NAT], - }), + let main = FunctionBuilder::new( + "main", + FunctionType::new_endo(type_row![NAT]) + .with_extension_delta(&ExtensionSet::singleton(&A)) + .into(), )?; - hugr.connect(input, 0, output, 0)?; + let [inps] = main.input_wires_arr(); + let result = main.finish_prelude_hugr_with_outputs([inps]); - // Fail to catch the actual error because it's a difference between I/O + // Don't try to match the actual error because it's a difference between I/O // nodes and their parents and `report_mismatch` isn't yet smart enough // to handle that. assert_matches!( - hugr.update_validate(&PRELUDE_REGISTRY), - Err(ValidationError::CantInfer(_)) + result, + Err(BuildError::InvalidHUGR(ValidationError::CantInfer(_))) ); Ok(()) } @@ -996,6 +986,7 @@ fn funcdefn_signature_mismatch() -> Result<(), Box> { result, Err(ValidationError::CantInfer( InferExtensionError::MismatchedConcreteWithLocations { .. } + | InferExtensionError::EdgeMismatch(ExtensionError::SrcExceedsTgtExtensions { .. }) )) ); Ok(()) diff --git a/src/hugr/rewrite/replace.rs b/src/hugr/rewrite/replace.rs index 17b41a153..99eb83701 100644 --- a/src/hugr/rewrite/replace.rs +++ b/src/hugr/rewrite/replace.rs @@ -689,7 +689,7 @@ mod test { let case1 = case1.finish_with_outputs(foo.outputs())?.node(); let mut case2 = cond.case_builder(1)?; let bar = case2.add_dataflow_op(mk_op("bar"), case2.input_wires())?; - let mut baz_dfg = case2.dfg_builder(utou.clone(), None, bar.outputs())?; + let mut baz_dfg = case2.dfg_builder(utou.clone(), bar.outputs())?; let baz = baz_dfg.add_dataflow_op(mk_op("baz"), baz_dfg.input_wires())?; let baz_dfg = baz_dfg.finish_with_outputs(baz.outputs())?; let case2 = case2.finish_with_outputs(baz_dfg.outputs())?.node(); diff --git a/src/hugr/rewrite/simple_replace.rs b/src/hugr/rewrite/simple_replace.rs index 469017d86..0cf232400 100644 --- a/src/hugr/rewrite/simple_replace.rs +++ b/src/hugr/rewrite/simple_replace.rs @@ -263,7 +263,6 @@ pub(in crate::hugr::rewrite) mod test { let mut inner_builder = func_builder.dfg_builder( FunctionType::new_endo(type_row![QB, QB]).with_extension_delta(&delta), - None, [qb0, qb1], )?; let inner_graph = { diff --git a/src/hugr/validate/test.rs b/src/hugr/validate/test.rs index c20c8a767..f31083478 100644 --- a/src/hugr/validate/test.rs +++ b/src/hugr/validate/test.rs @@ -4,7 +4,6 @@ use super::*; use crate::builder::test::closed_dfg_root_hugr; use crate::builder::{ BuildError, Container, Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder, - ModuleBuilder, }; use crate::extension::prelude::{BOOL_T, PRELUDE, USIZE_T}; use crate::extension::{ @@ -14,7 +13,7 @@ use crate::hugr::hugrmut::sealed::HugrMutInternals; use crate::hugr::{HugrError, HugrMut, NodeType}; use crate::macros::const_extension_ids; use crate::ops::dataflow::IOTrait; -use crate::ops::{self, Const, LeafOp, OpType}; +use crate::ops::{self, Const, Input, LeafOp, OpType, Output}; use crate::std_extensions::logic::test::{and_op, or_op}; use crate::std_extensions::logic::{self, NotOp, EXTENSION_ID}; use crate::types::type_param::{TypeArg, TypeArgError, TypeParam}; @@ -447,30 +446,35 @@ fn test_local_const() -> Result<(), HugrError> { #[test] /// A wire with no extension requirements is wired into a node which has -/// [A,BOOL_T] extensions required on its inputs and outputs. This could be fixed +/// [A,B] extensions required on its inputs and outputs. This could be fixed /// by adding a lift node, but for validation this is an error. fn missing_lift_node() -> Result<(), BuildError> { - let mut module_builder = ModuleBuilder::new(); - let mut main = module_builder.define_function( - "main", - FunctionType::new(type_row![NAT], type_row![NAT]).into(), + let exset = ExtensionSet::from_iter([XA, XB]); + let mut main = Hugr::new(NodeType::new_pure(FuncDefn { + name: "main".into(), + signature: FunctionType::new_endo(type_row![NAT]) + .with_extension_delta(&exset) + .into(), + })); + let inp = main.add_node_with_parent( + main.root(), + NodeType::new_pure(Input { + types: type_row![NAT], + }), )?; - let [main_input] = main.input_wires_arr(); - - let f_builder = main.dfg_builder( - FunctionType::new(type_row![NAT], type_row![NAT]), - // Inner DFG has extension requirements that the wire wont satisfy - Some(ExtensionSet::from_iter([XA, XB])), - [main_input], + let out = main.add_node_with_parent( + main.root(), + NodeType::new( + Output { + types: type_row![NAT], + }, + exset, + ), )?; - let f_inputs = f_builder.input_wires(); - let f_handle = f_builder.finish_with_outputs(f_inputs)?; - let [f_output] = f_handle.outputs_arr(); - main.finish_with_outputs([f_output])?; - let handle = module_builder.hugr().validate(&PRELUDE_REGISTRY); + main.connect(inp, 0, out, 0)?; assert_matches!( - handle, + main.validate(&PRELUDE_REGISTRY), Err(ValidationError::ExtensionError( ExtensionError::TgtExceedsSrcExtensionsAtPort { .. } )) @@ -484,28 +488,41 @@ fn missing_lift_node() -> Result<(), BuildError> { /// unification, so don't allow open extension variables on the function /// signature, so this fails. fn too_many_extension() -> Result<(), BuildError> { - let mut module_builder = ModuleBuilder::new(); - - let main_sig = FunctionType::new(type_row![NAT], type_row![NAT]).into(); - - let mut main = module_builder.define_function("main", main_sig)?; - let [main_input] = main.input_wires_arr(); + let mut main = Hugr::new(NodeType::new_pure(FuncDefn { + name: "main".into(), + signature: FunctionType::new_endo(type_row![NAT]).into(), + })); + // Explicitly specify all input extensions so there is nothing to infer + let inp = main.add_node_with_parent( + main.root(), + NodeType::new_pure(Input { + types: type_row![NAT], + }), + )?; + let out = main.add_node_with_parent( + main.root(), + NodeType::new_pure(Output { + types: type_row![NAT], + }), + )?; + let lift = main.add_node_with_parent( + main.root(), + NodeType::new_pure(LeafOp::Lift { + type_row: type_row![NAT], + new_extension: XA, + }), + )?; - let inner_sig = FunctionType::new(type_row![NAT], type_row![NAT]) - .with_extension_delta(&ExtensionSet::singleton(&XA)); + main.connect(inp, 0, lift, 0)?; + main.connect(lift, 0, out, 0)?; - let f_builder = main.dfg_builder(inner_sig, Some(ExtensionSet::new()), [main_input])?; - let f_inputs = f_builder.input_wires(); - let f_handle = f_builder.finish_with_outputs(f_inputs)?; - let [f_output] = f_handle.outputs_arr(); - main.finish_with_outputs([f_output])?; - let handle = module_builder.hugr().validate(&PRELUDE_REGISTRY); assert_matches!( - handle, + main.validate(&PRELUDE_REGISTRY), Err(ValidationError::ExtensionError( ExtensionError::SrcExceedsTgtExtensionsAtPort { .. } )) ); + Ok(()) } @@ -515,44 +532,44 @@ fn too_many_extension() -> Result<(), BuildError> { /// requirements `[A,BOOL_T]`. A slightly more complex test of the error from /// `missing_lift_node`. fn extensions_mismatch() -> Result<(), BuildError> { - let mut module_builder = ModuleBuilder::new(); - let all_rs = ExtensionSet::from_iter([XA, XB]); - let main_sig = FunctionType::new(type_row![], type_row![NAT]) + let main_sig = FunctionType::new_endo(type_row![NAT, NAT]) .with_extension_delta(&all_rs) .into(); - let mut main = module_builder.define_function("main", main_sig)?; + let mut main = FunctionBuilder::new("main", main_sig)?; + let [left_wire, right_wire] = main.input_wires_arr(); let [left_wire] = main - .dfg_builder( - FunctionType::new(type_row![], type_row![NAT]), - Some(ExtensionSet::singleton(&XA)), - [], + .add_dataflow_node( + NodeType::new_pure(LeafOp::Lift { + type_row: type_row![NAT], + new_extension: XA, + }), + [left_wire], )? - .finish_with_outputs([])? .outputs_arr(); let [right_wire] = main - .dfg_builder( - FunctionType::new(type_row![], type_row![NAT]), - Some(ExtensionSet::singleton(&XB)), - [], + .add_dataflow_node( + NodeType::new_pure(LeafOp::Lift { + type_row: type_row![NAT], + new_extension: XB, + }), + [right_wire], )? - .finish_with_outputs([])? .outputs_arr(); - - let builder = main.dfg_builder( - FunctionType::new(type_row![NAT, NAT], type_row![NAT]), - Some(all_rs), - [left_wire, right_wire], - )?; - let [_left, _right] = builder.input_wires_arr(); - let [output] = builder.finish_with_outputs([])?.outputs_arr(); - - main.finish_with_outputs([output])?; - let handle = module_builder.hugr().validate(&PRELUDE_REGISTRY); + main.set_outputs([left_wire, right_wire])?; + + // Avoid needing inference (which cannot succeed) by manually setting extensionsets + let mut hugr = main.hugr().clone(); + let [inp, out] = hugr.get_io(hugr.root()).unwrap(); + assert_eq!(hugr.get_nodetype(inp).input_extensions, None); + hugr.replace_op(inp, NodeType::new_pure(hugr.get_optype(inp).clone()))?; + assert_eq!(hugr.get_nodetype(out).input_extensions, None); + hugr.replace_op(out, NodeType::new(hugr.get_optype(out).clone(), all_rs))?; + let handle = hugr.validate(&PRELUDE_REGISTRY); assert_matches!( handle, Err(ValidationError::ExtensionError( diff --git a/src/hugr/views/descendants.rs b/src/hugr/views/descendants.rs index 8b89e9836..5743ab2cf 100644 --- a/src/hugr/views/descendants.rs +++ b/src/hugr/views/descendants.rs @@ -242,11 +242,8 @@ pub(super) mod test { .outputs_arr(); let inner_id = { - let inner_builder = func_builder.dfg_builder( - FunctionType::new_endo(type_row![NAT]), - None, - [int], - )?; + let inner_builder = + func_builder.dfg_builder(FunctionType::new_endo(type_row![NAT]), [int])?; let w = inner_builder.input_wires(); inner_builder.finish_with_outputs(w) }?; diff --git a/src/hugr/views/sibling.rs b/src/hugr/views/sibling.rs index 4bf6c7ac2..1c1b55bc4 100644 --- a/src/hugr/views/sibling.rs +++ b/src/hugr/views/sibling.rs @@ -402,7 +402,7 @@ mod test { let mut module_builder = ModuleBuilder::new(); let fty = FunctionType::new(type_row![NAT], type_row![NAT]); let mut fbuild = module_builder.define_function("main", fty.clone().into())?; - let dfg = fbuild.dfg_builder(fty, None, fbuild.input_wires())?; + let dfg = fbuild.dfg_builder(fty, fbuild.input_wires())?; let ins = dfg.input_wires(); let sub_dfg = dfg.finish_with_outputs(ins)?; let fun = fbuild.finish_with_outputs(sub_dfg.outputs())?; diff --git a/src/hugr/views/sibling_subgraph.rs b/src/hugr/views/sibling_subgraph.rs index 116b429bf..716872f19 100644 --- a/src/hugr/views/sibling_subgraph.rs +++ b/src/hugr/views/sibling_subgraph.rs @@ -690,6 +690,7 @@ mod tests { use crate::extension::{ExtensionSet, PRELUDE_REGISTRY}; use crate::ops::LeafOp; + use crate::std_extensions::logic; use crate::utils::test_quantum_extension::{cx_gate, EXTENSION_ID}; use crate::{ builder::{ @@ -762,7 +763,12 @@ mod tests { fn build_3not_hugr() -> Result<(Hugr, Node), BuildError> { let mut mod_builder = ModuleBuilder::new(); - let func = mod_builder.declare("test", FunctionType::new_endo(type_row![BOOL_T]).into())?; + let func = mod_builder.declare( + "test", + FunctionType::new_endo(type_row![BOOL_T]) + .with_extension_delta(&ExtensionSet::singleton(&logic::EXTENSION_ID)) + .into(), + )?; let func_id = { let mut dfg = mod_builder.define_declaration(&func)?; let outs1 = dfg.add_dataflow_op(NotOp, dfg.input_wires())?; @@ -781,7 +787,9 @@ mod tests { let mut mod_builder = ModuleBuilder::new(); let func = mod_builder.declare( "test", - FunctionType::new(type_row![BOOL_T], type_row![BOOL_T]).into(), + FunctionType::new_endo(type_row![BOOL_T]) + .with_extension_delta(&ExtensionSet::singleton(&logic::EXTENSION_ID)) + .into(), )?; let func_id = { let mut dfg = mod_builder.define_declaration(&func)?; @@ -987,7 +995,12 @@ mod tests { SiblingGraph::try_new(&hugr, func_root).unwrap(); let func = SiblingSubgraph::try_new_dataflow_subgraph(&func_graph).unwrap(); let func_defn = hugr.get_optype(func_root).as_func_defn().unwrap(); - assert_eq!(func_defn.signature, func.signature(&func_graph).into()); + assert_eq!( + func_defn.signature, + func.signature(&func_graph) + .with_extension_delta(&func_defn.signature.body().extension_reqs) + .into() + ); } #[test] @@ -996,10 +1009,10 @@ mod tests { let func_graph: SiblingGraph<'_, FuncID> = SiblingGraph::try_new(&hugr, func_root).unwrap(); let subgraph = SiblingSubgraph::try_new_dataflow_subgraph(&func_graph).unwrap(); - let extracted = + let mut extracted = subgraph.extract_subgraph(&hugr, "region", &ExtensionSet::singleton(&EXTENSION_ID))?; - extracted.validate(&PRELUDE_REGISTRY).unwrap(); + extracted.update_validate(&PRELUDE_REGISTRY).unwrap(); Ok(()) } diff --git a/src/hugr/views/tests.rs b/src/hugr/views/tests.rs index 43ccfbaee..07869a0b4 100644 --- a/src/hugr/views/tests.rs +++ b/src/hugr/views/tests.rs @@ -203,13 +203,10 @@ fn test_dataflow_ports_only() { assert_eq!(nt.input_extensions, Some(ExtensionSet::new())); nt.input_extensions = Some(ExtensionSet::singleton(&EXTENSION_ID)); } - // Note that presently the builder sets too many input-exts that could be - // left to the inference (https://github.com/CQCL/hugr/issues/702) hence we - // must manually change these too, although we can let inference deal with them + // Just (sanity-)check that no input-extensions have been set by the builder for node in dfg.hugr().get_io(local_and.node()).unwrap() { - let nt = dfg.hugr_mut().op_types.get_mut(node.pg_index()); - assert_eq!(nt.input_extensions, Some(ExtensionSet::new())); - nt.input_extensions = None; + let nt = dfg.hugr().op_types.get(node.pg_index()); + assert_eq!(nt.input_extensions, None); } let h = dfg