Skip to content

Commit

Permalink
feat: Helper functions for requesting inference, use with builder in …
Browse files Browse the repository at this point in the history
…tests (#1219)

* Add `builder::ft2` that takes 2 `Into<Typerow>`s, and builds a
FunctionType with extension delta TO_BE_INFERRED
* And `builder::ft1` that takes a single `Into<TypeRow>` and makes an
endomorphic FunctionType similarly
* Use these to update a bunch of tests made worse by earlier PRs (i.e.
these now use inference rather than manually specifying deltas)
* Correct some doc comments....
* ....and add a `dfg_builder_endo` method that was hinted at by one of
the incorrect doc comments, and that infers the delta.
  • Loading branch information
acl-cqc authored Jun 28, 2024
1 parent 2b05771 commit 320a9a7
Show file tree
Hide file tree
Showing 10 changed files with 70 additions and 94 deletions.
16 changes: 14 additions & 2 deletions hugr-core/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,12 @@
//! ```
use thiserror::Error;

use crate::extension::SignatureError;
use crate::extension::{SignatureError, TO_BE_INFERRED};
use crate::hugr::ValidationError;
use crate::ops::handle::{BasicBlockID, CfgID, ConditionalID, DfgID, FuncID, TailLoopID};
use crate::ops::{NamedOp, OpType};
use crate::types::ConstTypeError;
use crate::types::Type;
use crate::types::{ConstTypeError, FunctionType, TypeRow};
use crate::{Node, Port, Wire};

pub mod handle;
Expand Down Expand Up @@ -121,6 +121,18 @@ pub use conditional::{CaseBuilder, ConditionalBuilder};
mod circuit;
pub use circuit::{CircuitBuildError, CircuitBuilder};

/// Return a FunctionType with the same input and output types (specified)
/// whose extension delta, when used in a non-FuncDefn container, will be inferred.
pub fn ft1(types: impl Into<TypeRow>) -> FunctionType {
FunctionType::new_endo(types).with_extension_delta(TO_BE_INFERRED)
}

/// Return a FunctionType with the specified input and output types
/// whose extension delta, when used in a non-FuncDefn container, will be inferred.
pub fn ft2(inputs: impl Into<TypeRow>, outputs: impl Into<TypeRow>) -> FunctionType {
FunctionType::new(inputs, outputs).with_extension_delta(TO_BE_INFERRED)
}

#[derive(Debug, Clone, PartialEq, Error)]
#[non_exhaustive]
/// Error while building the HUGR.
Expand Down
26 changes: 21 additions & 5 deletions hugr-core/src/builder/build_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ use crate::{
types::EdgeKind,
};

use crate::extension::{ExtensionRegistry, ExtensionSet, SignatureError, PRELUDE_REGISTRY};
use crate::extension::{
ExtensionRegistry, ExtensionSet, SignatureError, PRELUDE_REGISTRY, TO_BE_INFERRED,
};
use crate::types::{FunctionType, PolyFuncType, Type, TypeArg, TypeRow};

use itertools::Itertools;
Expand Down Expand Up @@ -263,10 +265,9 @@ pub trait Dataflow: Container {
collect_array(self.input_wires())
}

/// Return a builder for a [`crate::ops::DFG`] node, i.e. a nested dataflow subgraph.
/// The `inputs` must be an iterable over pairs of the type of the input and
/// the corresponding wire.
/// The `output_types` are the types of the outputs.
/// Return a builder for a [`crate::ops::DFG`] node, i.e. a nested dataflow subgraph,
/// given a signature describing its input and output types and extension delta,
/// and the input wires (which must match the input types)
///
/// # Errors
///
Expand All @@ -286,6 +287,21 @@ pub trait Dataflow: Container {
DFGBuilder::create_with_io(self.hugr_mut(), dfg_n, signature)
}

/// Return a builder for a [`crate::ops::DFG`] node, i.e. a nested dataflow subgraph,
/// that is endomorphic (the output types are the same as the input types).
/// The `inputs` must be an iterable over pairs of the type of the input and
/// the corresponding wire.
fn dfg_builder_endo(
&mut self,
inputs: impl IntoIterator<Item = (Type, Wire)>,
) -> Result<DFGBuilder<&mut Hugr>, BuildError> {
let (types, input_wires): (Vec<Type>, Vec<Wire>) = inputs.into_iter().unzip();
self.dfg_builder(
FunctionType::new_endo(types).with_extension_delta(TO_BE_INFERRED),
input_wires,
)
}

/// Return a builder for a [`crate::ops::CFG`] node,
/// i.e. a nested controlflow subgraph.
/// The `inputs` must be an iterable over pairs of the type of the input and
Expand Down
26 changes: 7 additions & 19 deletions hugr-core/src/builder/dataflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ impl<T: AsMut<Hugr> + AsRef<Hugr>> DFGBuilder<T> {
}

impl DFGBuilder<Hugr> {
/// Begin building a new DFG rooted HUGR.
/// Input extensions default to being an open variable
/// Begin building a new DFG-rooted HUGR given its inputs, outputs,
/// and extension delta.
///
/// # Errors
///
Expand Down Expand Up @@ -203,11 +203,9 @@ pub(crate) mod test {
use serde_json::json;

use crate::builder::build_traits::DataflowHugr;
use crate::builder::{BuilderWiringError, DataflowSubContainer, ModuleBuilder};
use crate::builder::{ft1, BuilderWiringError, DataflowSubContainer, ModuleBuilder};
use crate::extension::prelude::{BOOL_T, USIZE_T};
use crate::extension::{
ExtensionId, ExtensionSet, SignatureError, EMPTY_REG, PRELUDE_REGISTRY,
};
use crate::extension::{ExtensionId, SignatureError, EMPTY_REG, PRELUDE_REGISTRY};
use crate::hugr::validate::InterGraphEdgeError;
use crate::ops::OpTrait;
use crate::ops::{handle::NodeHandle, Lift, Noop, OpTag};
Expand Down Expand Up @@ -421,23 +419,13 @@ pub(crate) mod test {
let xa: ExtensionId = "A".try_into().unwrap();
let xb: ExtensionId = "B".try_into().unwrap();
let xc: ExtensionId = "C".try_into().unwrap();
let ab_extensions = ExtensionSet::from_iter([xa.clone(), xb.clone()]);
let abc_extensions = ab_extensions.clone().union(xc.clone().into());

let parent_sig =
FunctionType::new(type_row![BIT], type_row![BIT]).with_extension_delta(abc_extensions);
let mut parent = DFGBuilder::new(parent_sig)?;

let add_c_sig =
FunctionType::new(type_row![BIT], type_row![BIT]).with_extension_delta(xc.clone());
let mut parent = DFGBuilder::new(ft1(BIT))?;

let [w] = parent.input_wires_arr();

let add_ab_sig = FunctionType::new(type_row![BIT], type_row![BIT])
.with_extension_delta(ab_extensions.clone());

// A box which adds extensions A and B, via child Lift nodes
let mut add_ab = parent.dfg_builder(add_ab_sig, [w])?;
let mut add_ab = parent.dfg_builder(ft1(BIT), [w])?;
let [w] = add_ab.input_wires_arr();

let lift_a = add_ab.add_dataflow_op(
Expand All @@ -463,7 +451,7 @@ pub(crate) mod test {

// Add another node (a sibling to add_ab) which adds extension C
// via a child lift node
let mut add_c = parent.dfg_builder(add_c_sig, [w])?;
let mut add_c = parent.dfg_builder(ft1(BIT), [w])?;
let [w] = add_c.input_wires_arr();
let lift_c = add_c.add_dataflow_op(
Lift {
Expand Down
15 changes: 4 additions & 11 deletions hugr-core/src/extension/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ impl CustomConst for ConstExternalSymbol {
#[cfg(test)]
mod test {
use crate::{
builder::{DFGBuilder, Dataflow, DataflowHugr},
builder::{ft1, DFGBuilder, Dataflow, DataflowHugr},
utils::test_quantum_extension::cx_gate,
Hugr, Wire,
};
Expand Down Expand Up @@ -452,9 +452,7 @@ mod test {
assert!(error_val.equal_consts(&ConstError::new(2, "my message")));
assert!(!error_val.equal_consts(&ConstError::new(3, "my message")));

let mut b =
DFGBuilder::new(FunctionType::new_endo(type_row![]).with_extension_delta(PRELUDE_ID))
.unwrap();
let mut b = DFGBuilder::new(ft1(type_row![])).unwrap();

let err = b.add_load_value(error_val);

Expand Down Expand Up @@ -488,10 +486,7 @@ mod test {
)
.unwrap();

let mut b = DFGBuilder::new(
FunctionType::new_endo(type_row![QB_T, QB_T]).with_extension_delta(PRELUDE_ID),
)
.unwrap();
let mut b = DFGBuilder::new(ft1(type_row![QB_T, QB_T])).unwrap();
let [q0, q1] = b.input_wires_arr();
let [q0, q1] = b
.add_dataflow_op(cx_gate(), [q0, q1])
Expand Down Expand Up @@ -529,9 +524,7 @@ mod test {
#[test]
/// Test print operation
fn test_print() {
let mut b: DFGBuilder<Hugr> =
DFGBuilder::new(FunctionType::new_endo(vec![]).with_extension_delta(PRELUDE_ID))
.unwrap();
let mut b: DFGBuilder<Hugr> = DFGBuilder::new(ft1(vec![])).unwrap();
let greeting: ConstString = ConstString::new("Hello, world!".into());
let greeting_out: Wire = b.add_load_value(greeting);
let print_op = PRELUDE
Expand Down
28 changes: 6 additions & 22 deletions hugr-core/src/hugr/rewrite/inline_dfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ mod test {
use rstest::rstest;

use crate::builder::{
Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, SubContainer,
ft1, ft2, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, SubContainer,
};
use crate::extension::prelude::QB_T;
use crate::extension::{ExtensionRegistry, ExtensionSet, PRELUDE};
Expand Down Expand Up @@ -166,7 +166,6 @@ mod test {
#[case(true)]
#[case(false)]
fn inline_add_load_const(#[case] nonlocal: bool) -> Result<(), Box<dyn std::error::Error>> {
let delta = ExtensionSet::from_iter([int_ops::EXTENSION_ID, int_types::EXTENSION_ID]);
let reg = ExtensionRegistry::try_new([
PRELUDE.to_owned(),
int_ops::EXTENSION.to_owned(),
Expand All @@ -175,10 +174,7 @@ mod test {
.unwrap();
let int_ty = &int_types::INT_TYPES[6];

let mut outer = DFGBuilder::new(
FunctionType::new(vec![int_ty.clone(); 2], vec![int_ty.clone()])
.with_extension_delta(delta.clone()),
)?;
let mut outer = DFGBuilder::new(ft2(vec![int_ty.clone(); 2], vec![int_ty.clone()]))?;
let [a, b] = outer.input_wires_arr();
fn make_const<T: AsMut<Hugr> + AsRef<Hugr>>(
d: &mut DFGBuilder<T>,
Expand All @@ -199,10 +195,7 @@ mod test {
}
let c1 = nonlocal.then(|| make_const(&mut outer));
let inner = {
let mut inner = outer.dfg_builder(
FunctionType::new_endo(vec![int_ty.clone()]).with_extension_delta(delta),
[a],
)?;
let mut inner = outer.dfg_builder_endo([(int_ty.clone(), a)])?;
let [a] = inner.input_wires_arr();
let c1 = c1.unwrap_or_else(|| make_const(&mut inner))?;
let a1 = inner.add_dataflow_op(IntOpDef::iadd.with_log_width(6), [a, c1])?;
Expand Down Expand Up @@ -251,10 +244,7 @@ mod test {

#[test]
fn permutation() -> Result<(), Box<dyn std::error::Error>> {
let mut h = DFGBuilder::new(
FunctionType::new_endo(type_row![QB_T, QB_T])
.with_extension_delta(test_quantum_extension::EXTENSION_ID),
)?;
let mut h = DFGBuilder::new(ft1(type_row![QB_T, QB_T]))?;
let [p, q] = h.input_wires_arr();
let [p_h] = h
.add_dataflow_op(test_quantum_extension::h_gate(), [p])?
Expand Down Expand Up @@ -349,17 +339,11 @@ mod test {
PRELUDE.to_owned(),
])
.unwrap();
let mut outer = DFGBuilder::new(
FunctionType::new_endo(type_row![QB_T, QB_T])
.with_extension_delta(float_types::EXTENSION_ID),
)?;
let mut outer = DFGBuilder::new(ft1(type_row![QB_T, QB_T]))?;
let [a, b] = outer.input_wires_arr();
let h_a = outer.add_dataflow_op(test_quantum_extension::h_gate(), [a])?;
let h_b = outer.add_dataflow_op(test_quantum_extension::h_gate(), [b])?;
let mut inner = outer.dfg_builder(
FunctionType::new_endo(type_row![QB_T]).with_extension_delta(float_types::EXTENSION_ID),
h_b.outputs(),
)?;
let mut inner = outer.dfg_builder(ft1(QB_T), h_b.outputs())?;
let [i] = inner.input_wires_arr();
let f = inner.add_load_value(float_types::ConstF64::new(1.0));
inner.add_other_wire(inner.input().node(), f.node());
Expand Down
10 changes: 3 additions & 7 deletions hugr-core/src/hugr/serialize/test.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::*;
use crate::builder::{
test::closed_dfg_root_hugr, Container, DFGBuilder, Dataflow, DataflowHugr,
ft2, test::closed_dfg_root_hugr, Container, DFGBuilder, Dataflow, DataflowHugr,
DataflowSubContainer, HugrBuilder, ModuleBuilder,
};
use crate::extension::prelude::{BOOL_T, PRELUDE_ID, QB_T, USIZE_T};
Expand All @@ -11,7 +11,7 @@ use crate::ops::custom::{ExtensionOp, OpaqueOp};
use crate::ops::{self, dataflow::IOTrait, Input, Module, Noop, Output, Value, DFG};
use crate::std_extensions::arithmetic::float_types::FLOAT64_TYPE;
use crate::std_extensions::arithmetic::int_ops::INT_OPS_REGISTRY;
use crate::std_extensions::arithmetic::int_types::{self, int_custom_type, ConstInt, INT_TYPES};
use crate::std_extensions::arithmetic::int_types::{int_custom_type, ConstInt, INT_TYPES};
use crate::std_extensions::logic::NotOp;
use crate::types::{
type_param::TypeParam, FunctionType, PolyFuncType, SumType, Type, TypeArg, TypeBound,
Expand Down Expand Up @@ -351,11 +351,7 @@ fn hierarchy_order() -> Result<(), Box<dyn std::error::Error>> {

#[test]
fn constants_roundtrip() -> Result<(), Box<dyn std::error::Error>> {
let mut builder = DFGBuilder::new(
FunctionType::new(vec![], vec![INT_TYPES[4].clone()])
.with_extension_delta(int_types::EXTENSION_ID),
)
.unwrap();
let mut builder = DFGBuilder::new(ft2(vec![], vec![INT_TYPES[4].clone()])).unwrap();
let w = builder.add_load_value(ConstInt::new_s(4, -2).unwrap());
let hugr = builder.finish_hugr_with_outputs([w], &INT_OPS_REGISTRY)?;

Expand Down
13 changes: 5 additions & 8 deletions hugr-core/src/hugr/validate/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use rstest::rstest;
use super::*;
use crate::builder::test::closed_dfg_root_hugr;
use crate::builder::{
BuildError, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer,
ft2, BuildError, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer,
FunctionBuilder, HugrBuilder, ModuleBuilder, SubContainer,
};
use crate::extension::prelude::{BOOL_T, PRELUDE, PRELUDE_ID, QB_T, USIZE_T};
Expand Down Expand Up @@ -769,13 +769,10 @@ fn test_polymorphic_call() -> Result<(), Box<dyn std::error::Error>> {

let int_pair = Type::new_tuple(type_row![USIZE_T; 2]);
// Root DFG: applies a function int--PRELUDE-->int to each element of a pair of two ints
let mut d = DFGBuilder::new(
FunctionType::new(
vec![utou(PRELUDE_ID), int_pair.clone()],
vec![int_pair.clone()],
)
.with_extension_delta(PRELUDE_ID),
)?;
let mut d = DFGBuilder::new(ft2(
vec![utou(PRELUDE_ID), int_pair.clone()],
vec![int_pair.clone()],
))?;
// ....by calling a function parametrized<extensions E> (int--e-->int, int_pair) -> int_pair
let f = {
let es = ExtensionSet::type_var(0);
Expand Down
9 changes: 3 additions & 6 deletions hugr-core/src/hugr/views/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use portgraph::PortOffset;
use rstest::{fixture, rstest};

use crate::{
builder::{BuildError, BuildHandle, Container, DFGBuilder, Dataflow, DataflowHugr},
builder::{ft2, BuildError, BuildHandle, Container, DFGBuilder, Dataflow, DataflowHugr},
extension::prelude::QB_T,
ops::{
handle::{DataflowOpID, NodeHandle},
Expand Down Expand Up @@ -150,12 +150,9 @@ fn value_types() {

#[test]
fn static_targets() {
use crate::extension::prelude::{ConstUsize, PRELUDE_ID, USIZE_T};
use crate::extension::prelude::{ConstUsize, USIZE_T};
use itertools::Itertools;
let mut dfg = DFGBuilder::new(
FunctionType::new(type_row![], type_row![USIZE_T]).with_extension_delta(PRELUDE_ID),
)
.unwrap();
let mut dfg = DFGBuilder::new(ft2(type_row![], type_row![USIZE_T])).unwrap();

let c = dfg.add_constant(Value::extension(ConstUsize::new(1)));

Expand Down
13 changes: 5 additions & 8 deletions hugr-core/src/ops/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -460,8 +460,8 @@ pub type ValueNameRef = str;
#[cfg(test)]
mod test {
use super::Value;
use crate::builder::ft2;
use crate::builder::test::simple_dfg_hugr;
use crate::extension::prelude::PRELUDE_ID;
use crate::std_extensions::arithmetic::int_types::ConstInt;
use crate::{
builder::{BuildError, DFGBuilder, Dataflow, DataflowHugr},
Expand Down Expand Up @@ -521,13 +521,10 @@ mod test {
let pred_rows = vec![type_row![USIZE_T, FLOAT64_TYPE], Type::EMPTY_TYPEROW];
let pred_ty = SumType::new(pred_rows.clone());

let mut b = DFGBuilder::new(
FunctionType::new(type_row![], TypeRow::from(vec![pred_ty.clone().into()]))
.with_extension_delta(ExtensionSet::from_iter([
float_types::EXTENSION_ID,
PRELUDE_ID,
])),
)?;
let mut b = DFGBuilder::new(ft2(
type_row![],
TypeRow::from(vec![pred_ty.clone().into()]),
))?;
let c = b.add_constant(Value::sum(
0,
[
Expand Down
8 changes: 2 additions & 6 deletions hugr-passes/src/const_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
use std::collections::{BTreeSet, HashMap};

use hugr_core::extension::ExtensionSet;
use hugr_core::builder::ft2;
use itertools::Itertools;
use thiserror::Error;

Expand All @@ -19,7 +19,6 @@ use hugr_core::{
},
ops::{OpType, Value},
type_row,
types::FunctionType,
utils::sorted_consts,
Hugr, HugrView, IncomingPort, Node, SimpleReplacement,
};
Expand Down Expand Up @@ -137,10 +136,7 @@ pub fn fold_leaf_op(op: &OpType, consts: &[(IncomingPort, Value)]) -> ConstFoldR
/// against `reg`.
fn const_graph(consts: Vec<Value>, reg: &ExtensionRegistry) -> Hugr {
let const_types = consts.iter().map(Value::get_type).collect_vec();
let exts = ExtensionSet::union_over(consts.iter().map(Value::extension_reqs));
let mut b =
DFGBuilder::new(FunctionType::new(type_row![], const_types).with_extension_delta(exts))
.unwrap();
let mut b = DFGBuilder::new(ft2(type_row![], const_types)).unwrap();

let outputs = consts
.into_iter()
Expand Down

0 comments on commit 320a9a7

Please sign in to comment.