Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Infer extension constraints on FuncDefn; remove many ExtensionSets from builder #739

Closed
wants to merge 14 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions src/builder/build_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,7 @@ pub trait Container {
ExtensionSet::new(),
))?;

let db =
DFGBuilder::create_with_io(self.hugr_mut(), f_node, body, Some(ExtensionSet::new()))?;
let db = DFGBuilder::create_with_io(self.hugr_mut(), f_node, body)?;
Ok(FunctionBuilder::from_dfg_builder(db))
}

Expand Down Expand Up @@ -297,16 +296,15 @@ pub trait Dataflow: Container {
fn dfg_builder(
&mut self,
signature: FunctionType,
input_extensions: Option<ExtensionSet>,
input_wires: impl IntoIterator<Item = Wire>,
) -> Result<DFGBuilder<&mut Hugr>, BuildError> {
let op = ops::DFG {
signature: signature.clone(),
};
let nodetype = NodeType::new(op, input_extensions.clone());
let nodetype = NodeType::new_auto(op);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The choice of new_auto here is about the only non-obvious thing in the PR. ("pure" for FuncDefn but "open" for any other DFG.) I think the next PR might even remove new_auto and that could be a good point to say #702 is finished

let (dfg_n, _) = add_node_with_wires(self, nodetype, input_wires.into_iter().collect())?;

DFGBuilder::create_with_io(self.hugr_mut(), dfg_n, signature, input_extensions)
DFGBuilder::create_with_io(self.hugr_mut(), dfg_n, signature)
}

/// Return a builder for a [`crate::ops::CFG`] node,
Expand Down
7 changes: 1 addition & 6 deletions src/builder/cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -260,12 +260,7 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> BlockBuilder<B> {
let mut node_outputs = vec![tuple_sum_type];
node_outputs.extend_from_slice(&other_outputs);
let signature = FunctionType::new(inputs, TypeRow::from(node_outputs));
let inp_ex = base
.as_ref()
.get_nodetype(block_n)
.input_extensions()
.cloned();
let db = DFGBuilder::create_with_io(base, block_n, signature, inp_ex)?;
let db = DFGBuilder::create_with_io(base, block_n, signature)?;
Ok(BlockBuilder::from_dfg_builder(db))
}

Expand Down
3 changes: 1 addition & 2 deletions src/builder/conditional.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,6 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> ConditionalBuilder<B> {
self.hugr_mut(),
case_node,
FunctionType::new(inputs, outputs).with_extension_delta(&extension_delta),
None,
)?;

Ok(CaseBuilder::from_dfg_builder(dfg_builder))
Expand Down Expand Up @@ -197,7 +196,7 @@ impl CaseBuilder<Hugr> {
};
let base = Hugr::new(NodeType::new_open(op));
let root = base.root();
let dfg_builder = DFGBuilder::create_with_io(base, root, signature, None)?;
let dfg_builder = DFGBuilder::create_with_io(base, root, signature)?;

Ok(CaseBuilder::from_dfg_builder(dfg_builder))
}
Expand Down
47 changes: 17 additions & 30 deletions src/builder/dataflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::ops;

use crate::types::{FunctionType, PolyFuncType};

use crate::extension::{ExtensionRegistry, ExtensionSet};
use crate::extension::ExtensionRegistry;
use crate::Node;
use crate::{hugr::HugrMut, Hugr};

Expand All @@ -27,7 +27,6 @@ impl<T: AsMut<Hugr> + AsRef<Hugr>> DFGBuilder<T> {
mut base: T,
parent: Node,
signature: FunctionType,
input_extensions: Option<ExtensionSet>,
) -> Result<Self, BuildError> {
let num_in_wires = signature.input().len();
let num_out_wires = signature.output().len();
Expand All @@ -49,15 +48,8 @@ impl<T: AsMut<Hugr> + AsRef<Hugr>> DFGBuilder<T> {
let output = ops::Output {
types: signature.output().clone(),
};
base.as_mut()
.add_node_with_parent(parent, NodeType::new(input, input_extensions.clone()))?;
base.as_mut().add_node_with_parent(
parent,
NodeType::new(
output,
input_extensions.map(|inp| inp.union(&signature.extension_reqs)),
),
)?;
base.as_mut().add_node_with_parent(parent, input)?;
base.as_mut().add_node_with_parent(parent, output)?;

Ok(Self {
base,
Expand All @@ -81,7 +73,7 @@ impl DFGBuilder<Hugr> {
};
let base = Hugr::new(NodeType::new_open(dfg_op));
let root = base.root();
DFGBuilder::create_with_io(base, root, signature, None)
DFGBuilder::create_with_io(base, root, signature)
}
}

Expand Down Expand Up @@ -156,7 +148,7 @@ impl FunctionBuilder<Hugr> {
let base = Hugr::new(NodeType::new_pure(op));
let root = base.root();

let db = DFGBuilder::create_with_io(base, root, body, Some(ExtensionSet::new()))?;
let db = DFGBuilder::create_with_io(base, root, body)?;
Ok(Self::from_dfg_builder(db))
}
}
Expand Down Expand Up @@ -254,11 +246,8 @@ pub(crate) mod test {
[int],
)?
.outputs_arr();
let inner_builder = func_builder.dfg_builder(
FunctionType::new(type_row![NAT], type_row![NAT]),
None,
[int],
)?;
let inner_builder = func_builder
.dfg_builder(FunctionType::new(type_row![NAT], type_row![NAT]), [int])?;
let inner_id = n_identity(inner_builder)?;

func_builder.finish_with_outputs(inner_id.outputs().chain(q_out.outputs()))?
Expand Down Expand Up @@ -380,7 +369,7 @@ pub(crate) mod test {
let i1 = noop.out_wire(0);

let mut nested =
f_build.dfg_builder(FunctionType::new(type_row![], type_row![BIT]), None, [])?;
f_build.dfg_builder(FunctionType::new(type_row![], type_row![BIT]), [])?;

let id = nested.add_dataflow_op(LeafOp::Noop { ty: BIT }, [i1])?;

Expand All @@ -403,8 +392,7 @@ pub(crate) mod test {
let noop = f_build.add_dataflow_op(LeafOp::Noop { ty: QB }, [i1])?;
let i1 = noop.out_wire(0);

let mut nested =
f_build.dfg_builder(FunctionType::new(type_row![], type_row![QB]), None, [])?;
let mut nested = f_build.dfg_builder(FunctionType::new(type_row![], type_row![QB]), [])?;

let id_res = nested.add_dataflow_op(LeafOp::Noop { ty: QB }, [i1]);

Expand Down Expand Up @@ -482,7 +470,7 @@ pub(crate) mod test {
FunctionType::new(type_row![BIT], type_row![BIT]).with_extension_delta(&ab_extensions);

// A box which adds extensions A and B, via child Lift nodes
let mut add_ab = parent.dfg_builder(add_ab_sig, Some(ExtensionSet::new()), [w])?;
let mut add_ab = parent.dfg_builder(add_ab_sig, [w])?;
let [w] = add_ab.input_wires_arr();

let lift_a = add_ab.add_dataflow_op(
Expand Down Expand Up @@ -511,7 +499,7 @@ pub(crate) mod test {

// Add another node (a sibling to add_ab) which adds extension C
// via a child lift node
let mut add_c = parent.dfg_builder(add_c_sig, Some(ab_extensions.clone()), [w])?;
let mut add_c = parent.dfg_builder(add_c_sig, [w])?;
let [w] = add_c.input_wires_arr();
let lift_c = add_c.add_dataflow_node(
NodeType::new(
Expand All @@ -536,10 +524,10 @@ pub(crate) mod test {
fn non_cfg_ancestor() -> Result<(), BuildError> {
let unit_sig = FunctionType::new(type_row![Type::UNIT], type_row![Type::UNIT]);
let mut b = DFGBuilder::new(unit_sig.clone())?;
let b_child = b.dfg_builder(unit_sig.clone(), None, [b.input().out_wire(0)])?;
let b_child = b.dfg_builder(unit_sig.clone(), [b.input().out_wire(0)])?;
let b_child_in_wire = b_child.input().out_wire(0);
b_child.finish_with_outputs([])?;
let b_child_2 = b.dfg_builder(unit_sig.clone(), None, [])?;
let b_child_2 = b.dfg_builder(unit_sig.clone(), [])?;

// DFG block has edge coming a sibling block, which is only valid for
// CFGs
Expand All @@ -560,17 +548,16 @@ pub(crate) mod test {
fn no_relation_edge() -> Result<(), BuildError> {
let unit_sig = FunctionType::new(type_row![Type::UNIT], type_row![Type::UNIT]);
let mut b = DFGBuilder::new(unit_sig.clone())?;
let mut b_child = b.dfg_builder(unit_sig.clone(), None, [b.input().out_wire(0)])?;
let b_child_child =
b_child.dfg_builder(unit_sig.clone(), None, [b_child.input().out_wire(0)])?;
let mut b_child = b.dfg_builder(unit_sig.clone(), [b.input().out_wire(0)])?;
let b_child_child = b_child.dfg_builder(unit_sig.clone(), [b_child.input().out_wire(0)])?;
let b_child_child_in_wire = b_child_child.input().out_wire(0);

b_child_child.finish_with_outputs([])?;
b_child.finish_with_outputs([])?;

let mut b_child_2 = b.dfg_builder(unit_sig.clone(), None, [])?;
let mut b_child_2 = b.dfg_builder(unit_sig.clone(), [])?;
let b_child_2_child =
b_child_2.dfg_builder(unit_sig.clone(), None, [b_child_2.input().out_wire(0)])?;
b_child_2.dfg_builder(unit_sig.clone(), [b_child_2.input().out_wire(0)])?;

let res = b_child_2_child.finish_with_outputs([b_child_child_in_wire]);

Expand Down
2 changes: 1 addition & 1 deletion src/builder/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ impl<T: AsMut<Hugr> + AsRef<Hugr>> ModuleBuilder<T> {
NodeType::new_pure(ops::FuncDefn { name, signature }),
)?;

let db = DFGBuilder::create_with_io(self.hugr_mut(), f_node, body, None)?;
let db = DFGBuilder::create_with_io(self.hugr_mut(), f_node, body)?;
Ok(FunctionBuilder::from_dfg_builder(db))
}

Expand Down
2 changes: 1 addition & 1 deletion src/builder/tail_loop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> TailLoopBuilder<B> {
tail_loop: &ops::TailLoop,
) -> Result<Self, BuildError> {
let signature = FunctionType::new(tail_loop.body_input_row(), tail_loop.body_output_row());
let dfg_build = DFGBuilder::create_with_io(base, loop_node, signature, None)?;
let dfg_build = DFGBuilder::create_with_io(base, loop_node, signature)?;

Ok(TailLoopBuilder::from_dfg_builder(dfg_build))
}
Expand Down
50 changes: 24 additions & 26 deletions src/extension/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
use super::ExtensionSet;
use crate::{
hugr::views::HugrView,
ops::{OpTag, OpTrait},
ops::{FuncDefn, OpTag, OpTrait, OpType},
types::EdgeKind,
Direction, Node,
};
Expand Down Expand Up @@ -186,6 +186,14 @@ struct UnificationContext {
fresh_name: u32,
}

fn make_constraint(delta: ExtensionSet, other: Meta) -> Constraint {
if delta.is_empty() {
Constraint::Equal(other)
} else {
Constraint::Plus(delta, other)
}
}

/// Invariant: Constraint::Plus always points to a fresh metavariable
impl UnificationContext {
/// Create a new unification context, and populate it with constraints from
Expand Down Expand Up @@ -283,15 +291,16 @@ impl UnificationContext {
// op_signature, so the Incoming and Outgoing ports will
// have equal extension requirements.
// The function that it contains, however, may have an
// extension delta, so its output shouldn't be equal to the
// FuncDefn's output.
//
// TODO: Add a constraint that the extensions of the output
// node of a FuncDefn should be those of the input node plus
// the extension delta specified in the function signature.
if node_type.tag() != OpTag::FuncDefn {
self.add_constraint(m_output_node, Constraint::Equal(m_output));
}
// extension delta - and the Input/Output nodes relate to that instead.
self.add_constraint(
m_output_node,
match node_type.op() {
OpType::FuncDefn(FuncDefn { signature, .. }) => {
make_constraint(signature.body().extension_reqs.clone(), m_output)
}
_ => Constraint::Equal(m_output),
},
);
}
}

Expand All @@ -314,24 +323,13 @@ impl UnificationContext {
self.add_constraint(m_output, Constraint::Equal(m_exit));
}

match node_type.io_extensions() {
// Input extensions are open
None => {
let delta = node_type.op().extension_delta();
let c = if delta.is_empty() {
Constraint::Equal(m_input)
} else {
Constraint::Plus(delta, m_input)
};
self.add_constraint(m_output, c);
}
// We have a solution for everything!
Some((input_exts, output_exts)) => {
self.add_solution(m_input, input_exts.clone());
self.add_solution(m_output, output_exts);
}
let delta = node_type.op().extension_delta();
self.add_constraint(m_output, make_constraint(delta, m_input));
if let Some(input_exts) = node_type.input_extensions() {
self.add_solution(m_input, input_exts.clone());
}
}

// Separate loop so that we can assume that a metavariable has been
// added for every (Node, Direction) in the graph already.
for tgt_node in hugr.nodes() {
Expand Down
39 changes: 15 additions & 24 deletions src/extension/infer/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ use std::error::Error;
use super::*;
use crate::builder::test::closed_dfg_root_hugr;
use crate::builder::{
Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, ModuleBuilder,
BuildError, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer,
FunctionBuilder, HugrBuilder, ModuleBuilder,
};
use crate::extension::prelude::QB_T;
use crate::extension::ExtensionId;
Expand Down Expand Up @@ -154,36 +155,25 @@ fn plus() -> Result<(), InferExtensionError> {
}

#[test]
// This generates a solution that causes validation to fail
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment was incorrect (it doesn't generate a solution - the error is CantInfer). But the body/subject of this test has largely switched with the identically-named test in validation.

// because of a missing lift node
// We can't infer a solution here because of a missing lift node between
// Input and Output
fn missing_lift_node() -> Result<(), Box<dyn Error>> {
let mut hugr = Hugr::new(NodeType::new_pure(ops::DFG {
signature: FunctionType::new(type_row![NAT], type_row![NAT])
.with_extension_delta(&ExtensionSet::singleton(&A)),
}));

let input = hugr.add_node_with_parent(
hugr.root(),
NodeType::new_pure(ops::Input {
types: type_row![NAT],
}),
)?;

let output = hugr.add_node_with_parent(
hugr.root(),
NodeType::new_pure(ops::Output {
types: type_row![NAT],
}),
let main = FunctionBuilder::new(
"main",
FunctionType::new_endo(type_row![NAT])
.with_extension_delta(&ExtensionSet::singleton(&A))
.into(),
)?;

hugr.connect(input, 0, output, 0)?;
let [inps] = main.input_wires_arr();
let result = main.finish_prelude_hugr_with_outputs([inps]);

// Fail to catch the actual error because it's a difference between I/O
// Don't try to match the actual error because it's a difference between I/O
// nodes and their parents and `report_mismatch` isn't yet smart enough
// to handle that.
assert_matches!(
hugr.update_validate(&PRELUDE_REGISTRY),
Err(ValidationError::CantInfer(_))
result,
Err(BuildError::InvalidHUGR(ValidationError::CantInfer(_)))
);
Ok(())
}
Expand Down Expand Up @@ -996,6 +986,7 @@ fn funcdefn_signature_mismatch() -> Result<(), Box<dyn Error>> {
result,
Err(ValidationError::CantInfer(
InferExtensionError::MismatchedConcreteWithLocations { .. }
| InferExtensionError::EdgeMismatch(ExtensionError::SrcExceedsTgtExtensions { .. })
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This one is a bit nondeterministic as to which error is reported

))
);
Ok(())
Expand Down
2 changes: 1 addition & 1 deletion src/hugr/rewrite/replace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -689,7 +689,7 @@ mod test {
let case1 = case1.finish_with_outputs(foo.outputs())?.node();
let mut case2 = cond.case_builder(1)?;
let bar = case2.add_dataflow_op(mk_op("bar"), case2.input_wires())?;
let mut baz_dfg = case2.dfg_builder(utou.clone(), None, bar.outputs())?;
let mut baz_dfg = case2.dfg_builder(utou.clone(), bar.outputs())?;
let baz = baz_dfg.add_dataflow_op(mk_op("baz"), baz_dfg.input_wires())?;
let baz_dfg = baz_dfg.finish_with_outputs(baz.outputs())?;
let case2 = case2.finish_with_outputs(baz_dfg.outputs())?.node();
Expand Down
1 change: 0 additions & 1 deletion src/hugr/rewrite/simple_replace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,6 @@ pub(in crate::hugr::rewrite) mod test {

let mut inner_builder = func_builder.dfg_builder(
FunctionType::new_endo(type_row![QB, QB]).with_extension_delta(&delta),
None,
[qb0, qb1],
)?;
let inner_graph = {
Expand Down
Loading