Skip to content

Commit

Permalink
fix: FuncDefns don't require that their extensions match their childr…
Browse files Browse the repository at this point in the history
…en (#688)

Resolves #673
  • Loading branch information
croyzor authored Nov 16, 2023
1 parent 762839d commit 41e15da
Show file tree
Hide file tree
Showing 4 changed files with 156 additions and 8 deletions.
14 changes: 13 additions & 1 deletion src/extension/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,19 @@ impl UnificationContext {
let m_input_node = self.make_or_get_meta(input, dir);
self.add_constraint(m_input_node, Constraint::Equal(m_input));
let m_output_node = self.make_or_get_meta(output, dir);
self.add_constraint(m_output_node, Constraint::Equal(m_output));
// If the parent node is a FuncDefn, it will have no
// op_signature, so the Incoming and Outgoing ports will
// have equal extension requirements.
// The function that it contains, however, may have an
// extension delta, so its output shouldn't be equal to the
// FuncDefn's output.
//
// TODO: Add a constraint that the extensions of the output
// node of a FuncDefn should be those of the input node plus
// the extension delta specified in the function signature.
if node_type.tag() != OpTag::FuncDefn {
self.add_constraint(m_output_node, Constraint::Equal(m_output));
}
}
}

Expand Down
84 changes: 83 additions & 1 deletion src/extension/infer/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ use std::error::Error;

use super::*;
use crate::builder::test::closed_dfg_root_hugr;
use crate::builder::{DFGBuilder, Dataflow, DataflowHugr};
use crate::builder::{
Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, ModuleBuilder,
};
use crate::extension::prelude::QB_T;
use crate::extension::ExtensionId;
use crate::extension::{prelude::PRELUDE_REGISTRY, ExtensionSet};
Expand Down Expand Up @@ -940,3 +942,83 @@ fn sccs() {
Some(&ExtensionSet::from_iter([A, B, C, UNKNOWN_EXTENSION]))
);
}

#[test]
/// Note: This test is relying on the builder's `define_function` doing the
/// right thing: it takes input resources via a [`Signature`], which it passes
/// to `create_with_io`, creating concrete resource sets.
/// Inference can still fail for a valid FuncDefn hugr created without using
/// the builder API.
fn simple_funcdefn() -> Result<(), Box<dyn Error>> {
let mut builder = ModuleBuilder::new();
let mut func_builder = builder.define_function(
"F",
FunctionType::new(vec![NAT], vec![NAT])
.with_extension_delta(&ExtensionSet::singleton(&A))
.pure(),
)?;

let [w] = func_builder.input_wires_arr();
let lift = func_builder.add_dataflow_op(
ops::LeafOp::Lift {
type_row: type_row![NAT],
new_extension: A,
},
[w],
)?;
let [w] = lift.outputs_arr();
func_builder.finish_with_outputs([w])?;
builder.finish_prelude_hugr()?;
Ok(())
}

#[test]
fn funcdefn_signature_mismatch() -> Result<(), Box<dyn Error>> {
let mut builder = ModuleBuilder::new();
let mut func_builder = builder.define_function(
"F",
FunctionType::new(vec![NAT], vec![NAT])
.with_extension_delta(&ExtensionSet::singleton(&A))
.pure(),
)?;

let [w] = func_builder.input_wires_arr();
let lift = func_builder.add_dataflow_op(
ops::LeafOp::Lift {
type_row: type_row![NAT],
new_extension: B,
},
[w],
)?;
let [w] = lift.outputs_arr();
func_builder.finish_with_outputs([w])?;
let result = builder.finish_prelude_hugr();
assert_matches!(
result,
Err(ValidationError::CantInfer(
InferExtensionError::MismatchedConcreteWithLocations { .. }
))
);
Ok(())
}

#[test]
// Test that the difference between a FuncDefn's input and output nodes is being
// constrained to be the same as the extension delta in the FuncDefn signature.
// The FuncDefn here is declared to add resource "A", but its body just wires
// the input to the output.
fn funcdefn_signature_mismatch2() -> Result<(), Box<dyn Error>> {
let mut builder = ModuleBuilder::new();
let func_builder = builder.define_function(
"F",
FunctionType::new(vec![NAT], vec![NAT])
.with_extension_delta(&ExtensionSet::singleton(&A))
.pure(),
)?;

let [w] = func_builder.input_wires_arr();
func_builder.finish_with_outputs([w])?;
let result = builder.finish_prelude_hugr();
assert_matches!(result, Err(ValidationError::CantInfer(..)));
Ok(())
}
14 changes: 9 additions & 5 deletions src/hugr/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,11 +200,15 @@ impl<'a, 'b> ValidationContext<'a, 'b> {
// Secondly that the node has correct children
self.validate_children(node, node_type)?;

// If this is a container with I/O nodes, check that the extension they
// define match the extensions of the container.
if let Some([input, output]) = self.hugr.get_io(node) {
self.extension_validator
.validate_io_extensions(node, input, output)?;
// FuncDefns have no resources since they're static nodes, but the
// functions they define can have any extension delta.
if node_type.tag() != OpTag::FuncDefn {
// If this is a container with I/O nodes, check that the extension they
// define match the extensions of the container.
if let Some([input, output]) = self.hugr.get_io(node) {
self.extension_validator
.validate_io_extensions(node, input, output)?;
}
}

Ok(())
Expand Down
52 changes: 51 additions & 1 deletion src/hugr/validate/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,7 @@ fn extensions_mismatch() -> Result<(), BuildError> {
assert_matches!(
handle,
Err(ValidationError::ExtensionError(
ExtensionError::ParentIOExtensionMismatch { .. }
ExtensionError::TgtExceedsSrcExtensionsAtPort { .. }
))
);
Ok(())
Expand Down Expand Up @@ -752,3 +752,53 @@ fn invalid_types() {
SignatureError::TypeArgMismatch(TypeArgError::WrongNumberArgs(2, 1))
);
}

#[test]
fn parent_io_mismatch() {
// The DFG node declares that it has an empty extension delta,
// but it's child graph adds extension "XB", causing a mismatch.
let mut hugr = Hugr::new(NodeType::new_pure(ops::DFG {
signature: FunctionType::new(type_row![USIZE_T], type_row![USIZE_T]),
}));

let input = hugr
.add_node_with_parent(
hugr.root(),
NodeType::new_pure(ops::Input {
types: type_row![USIZE_T],
}),
)
.unwrap();
let output = hugr
.add_node_with_parent(
hugr.root(),
NodeType::new(
ops::Output {
types: type_row![USIZE_T],
},
ExtensionSet::singleton(&XB),
),
)
.unwrap();

let lift = hugr
.add_node_with_parent(
hugr.root(),
NodeType::new_pure(ops::LeafOp::Lift {
type_row: type_row![USIZE_T],
new_extension: XB,
}),
)
.unwrap();

hugr.connect(input, 0, lift, 0).unwrap();
hugr.connect(lift, 0, output, 0).unwrap();

let result = hugr.validate(&PRELUDE_REGISTRY);
assert_matches!(
result,
Err(ValidationError::ExtensionError(
ExtensionError::ParentIOExtensionMismatch { .. }
))
);
}

0 comments on commit 41e15da

Please sign in to comment.