Skip to content

Commit

Permalink
feat: Values (and hence Consts) know their extensions (#733)
Browse files Browse the repository at this point in the history
* Make each Value able to report its ExtensionSet needed at runtime (and
hence, also each CustomConst)
* Also give OpTrait a method for the extension-delta. This is the delta
of the FunctionType, *for dataflow ops*, but can be defined for other
optypes too - specifically (here), Const ops and also Case. Use this in
both inference and for `NodeType::io_extensions()`
* Input extensions of every Const node can now be the empty set (i.e.
the default-for-ModuleOp `pure`), as the relevant extension will be
added in the delta - thus, drop separate extension parameter in builder
(add_constant, add_load_const, etc.). LoadConstant ops can now also have
an empty delta and open input-extensions as these can be figured out
from the Const. This is thus a step towards fixing #702...
* Also note slight issue in replace::test::cfg, pending #388
* Add `ExtensionSet::union_over(impl IntoIterator<Item=Self>) -> Self`
utility method

This should ease the way towards solving the rest of #702 and moreover
to removing `new_auto` - we should be able to constrain the
input-extensions for any ModuleOp to `pure` in inference and thus
*every* node can be created open rather than a mix, but those are all
for follow-up PRs.
  • Loading branch information
acl-cqc authored Dec 12, 2023
1 parent c0d61c7 commit 679eefc
Show file tree
Hide file tree
Showing 21 changed files with 153 additions and 95 deletions.
12 changes: 6 additions & 6 deletions src/algorithm/nest_cfgs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -605,8 +605,8 @@ pub(crate) mod test {
// \-> right -/ \-<--<-/
let mut cfg_builder = CFGBuilder::new(FunctionType::new(type_row![NAT], type_row![NAT]))?;

let pred_const = cfg_builder.add_constant(Const::unit_sum(0, 2), ExtensionSet::new())?; // Nothing here cares which
let const_unit = cfg_builder.add_constant(Const::unary_unit_sum(), ExtensionSet::new())?;
let pred_const = cfg_builder.add_constant(Const::unit_sum(0, 2))?; // Nothing here cares which
let const_unit = cfg_builder.add_constant(Const::unary_unit_sum())?;

let entry = n_identity(
cfg_builder.simple_entry_builder(type_row![NAT], 1, ExtensionSet::new())?,
Expand Down Expand Up @@ -887,8 +887,8 @@ pub(crate) mod test {
separate: bool,
) -> Result<(Hugr, BasicBlockID, BasicBlockID), BuildError> {
let mut cfg_builder = CFGBuilder::new(FunctionType::new(type_row![NAT], type_row![NAT]))?;
let pred_const = cfg_builder.add_constant(Const::unit_sum(0, 2), ExtensionSet::new())?; // Nothing here cares which
let const_unit = cfg_builder.add_constant(Const::unary_unit_sum(), ExtensionSet::new())?;
let pred_const = cfg_builder.add_constant(Const::unit_sum(0, 2))?; // Nothing here cares which
let const_unit = cfg_builder.add_constant(Const::unary_unit_sum())?;

let entry = n_identity(
cfg_builder.simple_entry_builder(type_row![NAT], 2, ExtensionSet::new())?,
Expand Down Expand Up @@ -929,8 +929,8 @@ pub(crate) mod test {
cfg_builder: &mut CFGBuilder<T>,
separate_headers: bool,
) -> Result<(BasicBlockID, BasicBlockID), BuildError> {
let pred_const = cfg_builder.add_constant(Const::unit_sum(0, 2), ExtensionSet::new())?; // Nothing here cares which
let const_unit = cfg_builder.add_constant(Const::unary_unit_sum(), ExtensionSet::new())?;
let pred_const = cfg_builder.add_constant(Const::unit_sum(0, 2))?; // Nothing here cares which
let const_unit = cfg_builder.add_constant(Const::unary_unit_sum())?;

let entry = n_identity(
cfg_builder.simple_entry_builder(type_row![NAT], 1, ExtensionSet::new())?,
Expand Down
28 changes: 8 additions & 20 deletions src/builder/build_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,8 @@ pub trait Container {
///
/// This function will return an error if there is an error in adding the
/// [`OpType::Const`] node.
fn add_constant(
&mut self,
constant: ops::Const,
extensions: impl Into<Option<ExtensionSet>>,
) -> Result<ConstID, BuildError> {
let const_n = self.add_child_node(NodeType::new(constant, extensions.into()))?;
fn add_constant(&mut self, constant: ops::Const) -> Result<ConstID, BuildError> {
let const_n = self.add_child_node(NodeType::new(constant, ExtensionSet::new()))?;

Ok(const_n.into())
}
Expand Down Expand Up @@ -356,20 +352,16 @@ pub trait Dataflow: Container {
fn load_const(&mut self, cid: &ConstID) -> Result<Wire, BuildError> {
let const_node = cid.node();
let nodetype = self.hugr().get_nodetype(const_node);
let input_extensions = nodetype.input_extensions().cloned();
let op: ops::Const = nodetype
.op()
.clone()
.try_into()
.expect("ConstID does not refer to Const op.");

let load_n = self.add_dataflow_node(
NodeType::new(
ops::LoadConstant {
datatype: op.const_type().clone(),
},
input_extensions,
),
let load_n = self.add_dataflow_op(
ops::LoadConstant {
datatype: op.const_type().clone(),
},
// Constant wire from the constant value node
vec![Wire::new(const_node, OutgoingPort::from(0))],
)?;
Expand All @@ -382,12 +374,8 @@ pub trait Dataflow: Container {
/// # Errors
///
/// This function will return an error if there is an error when adding the node.
fn add_load_const(
&mut self,
constant: ops::Const,
extensions: ExtensionSet,
) -> Result<Wire, BuildError> {
let cid = self.add_constant(constant, extensions)?;
fn add_load_const(&mut self, constant: ops::Const) -> Result<Wire, BuildError> {
let cid = self.add_constant(constant)?;
self.load_const(&cid)
}

Expand Down
8 changes: 3 additions & 5 deletions src/builder/cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ mod test {
let mut middle_b = cfg_builder
.simple_block_builder(FunctionType::new(type_row![NAT], type_row![NAT]), 1)?;
let middle = {
let c = middle_b.add_load_const(ops::Const::unary_unit_sum(), ExtensionSet::new())?;
let c = middle_b.add_load_const(ops::Const::unary_unit_sum())?;
let [inw] = middle_b.input_wires_arr();
middle_b.finish_with_outputs(c, [inw])?
};
Expand All @@ -398,8 +398,7 @@ mod test {
#[test]
fn test_dom_edge() -> Result<(), BuildError> {
let mut cfg_builder = CFGBuilder::new(FunctionType::new(type_row![NAT], type_row![NAT]))?;
let sum_tuple_const =
cfg_builder.add_constant(ops::Const::unary_unit_sum(), ExtensionSet::new())?;
let sum_tuple_const = cfg_builder.add_constant(ops::Const::unary_unit_sum())?;
let sum_variants = vec![type_row![]];

let mut entry_b =
Expand Down Expand Up @@ -427,8 +426,7 @@ mod test {
#[test]
fn test_non_dom_edge() -> Result<(), BuildError> {
let mut cfg_builder = CFGBuilder::new(FunctionType::new(type_row![NAT], type_row![NAT]))?;
let sum_tuple_const =
cfg_builder.add_constant(ops::Const::unary_unit_sum(), ExtensionSet::new())?;
let sum_tuple_const = cfg_builder.add_constant(ops::Const::unary_unit_sum())?;
let sum_variants = vec![type_row![]];
let mut middle_b = cfg_builder
.simple_block_builder(FunctionType::new(type_row![NAT], type_row![NAT]), 1)?;
Expand Down
2 changes: 1 addition & 1 deletion src/builder/conditional.rs
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ mod test {
"main",
FunctionType::new(type_row![NAT], type_row![NAT]).into(),
)?;
let tru_const = fbuild.add_constant(Const::true_val(), ExtensionSet::new())?;
let tru_const = fbuild.add_constant(Const::true_val())?;
let _fdef = {
let const_wire = fbuild.load_const(&tru_const)?;
let [int] = fbuild.input_wires_arr();
Expand Down
13 changes: 3 additions & 10 deletions src/builder/tail_loop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,7 @@ mod test {
let build_result: Result<Hugr, ValidationError> = {
let mut loop_b = TailLoopBuilder::new(vec![], vec![BIT], vec![USIZE_T])?;
let [i1] = loop_b.input_wires_arr();
let const_wire = loop_b.add_load_const(
ConstUsize::new(1).into(),
ExtensionSet::singleton(&PRELUDE_ID),
)?;
let const_wire = loop_b.add_load_const(ConstUsize::new(1).into())?;

let break_wire = loop_b.make_break(loop_b.loop_signature()?.clone(), [const_wire])?;
loop_b.set_outputs(break_wire, [i1])?;
Expand Down Expand Up @@ -148,8 +145,7 @@ mod test {
fbuild.tail_loop_builder(vec![(BIT, b1)], vec![], type_row![NAT])?;
let signature = loop_b.loop_signature()?.clone();
let const_val = Const::true_val();
let const_wire =
loop_b.add_load_const(Const::true_val(), ExtensionSet::new())?;
let const_wire = loop_b.add_load_const(Const::true_val())?;
let lift_node = loop_b.add_dataflow_op(
ops::LeafOp::Lift {
type_row: vec![const_val.const_type().clone()].into(),
Expand Down Expand Up @@ -177,10 +173,7 @@ mod test {
let mut branch_1 = conditional_b.case_builder(1)?;
let [_b1] = branch_1.input_wires_arr();

let wire = branch_1.add_load_const(
ConstUsize::new(2).into(),
ExtensionSet::singleton(&PRELUDE_ID),
)?;
let wire = branch_1.add_load_const(ConstUsize::new(2).into())?;
let break_wire = branch_1.make_break(signature, [wire])?;
branch_1.finish_with_outputs([break_wire])?;

Expand Down
10 changes: 10 additions & 0 deletions src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,16 @@ impl ExtensionSet {
self
}

/// Returns the union of an arbitrary collection of [ExtensionSet]s
pub fn union_over(sets: impl IntoIterator<Item = Self>) -> Self {
// `union` clones the receiver, which we do not need to do here
let mut res = ExtensionSet::new();
for s in sets {
res.0.extend(s.0)
}
res
}

/// The things in other which are in not in self
pub fn missing_from(&self, other: &Self) -> Self {
ExtensionSet::from_iter(other.0.difference(&self.0).cloned())
Expand Down
16 changes: 6 additions & 10 deletions src/extension/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -317,15 +317,11 @@ impl UnificationContext {
match node_type.io_extensions() {
// Input extensions are open
None => {
let c = if let Some(sig) = node_type.op_signature() {
let delta = sig.extension_reqs;
if delta.is_empty() {
Constraint::Equal(m_input)
} else {
Constraint::Plus(delta, m_input)
}
} else {
let delta = node_type.op().extension_delta();
let c = if delta.is_empty() {
Constraint::Equal(m_input)
} else {
Constraint::Plus(delta, m_input)
};
self.add_constraint(m_output, c);
}
Expand Down Expand Up @@ -703,7 +699,7 @@ impl UnificationContext {
});

let (rs, other_ms): (Vec<_>, Vec<_>) = plus_constraints.unzip();
let solution = rs.iter().fold(ExtensionSet::new(), ExtensionSet::union);
let solution = ExtensionSet::union_over(rs);
let unresolved_metas = other_ms
.into_iter()
.filter(|other_m| m != *other_m)
Expand Down Expand Up @@ -731,7 +727,7 @@ impl UnificationContext {
Constraint::Plus(_, other_m) => solutions.get(&self.resolve(*other_m)),
Constraint::Equal(_) => None,
})
.fold(ExtensionSet::new(), |a, b| a.union(b));
.fold(ExtensionSet::new(), ExtensionSet::union);

for m in cc.iter() {
self.add_solution(*m, combined_solution.clone());
Expand Down
6 changes: 5 additions & 1 deletion src/extension/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use crate::{
Extension,
};

use super::{ExtensionRegistry, SignatureError, SignatureFromArgs};
use super::{ExtensionRegistry, ExtensionSet, SignatureError, SignatureFromArgs};
struct ArrayOpCustom;

const MAX: &[TypeParam; 1] = &[TypeParam::max_nat()];
Expand Down Expand Up @@ -181,6 +181,10 @@ impl CustomConst for ConstUsize {
fn equal_consts(&self, other: &dyn CustomConst) -> bool {
crate::values::downcast_equal_consts(self, other)
}

fn extension_reqs(&self) -> ExtensionSet {
ExtensionSet::singleton(&PRELUDE_ID)
}
}

impl KnownTypeConst for ConstUsize {
Expand Down
13 changes: 3 additions & 10 deletions src/hugr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,16 +125,9 @@ impl NodeType {
/// `None`` if the [Self::input_extensions] is `None`.
/// Otherwise, will return Some, with the output extensions computed from the node's delta
pub fn io_extensions(&self) -> Option<(&ExtensionSet, ExtensionSet)> {
self.input_extensions.as_ref().map(|e| {
(
e,
self.op
.dataflow_signature()
.map(|ft| ft.extension_reqs)
.unwrap_or_default()
.union(e),
)
})
self.input_extensions
.as_ref()
.map(|e| (e, self.op.extension_delta().union(e)))
}

/// Gets the underlying [OpType] i.e. without any [input_extensions]
Expand Down
2 changes: 1 addition & 1 deletion src/hugr/rewrite/outline_cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ impl Rewrite for OutlineCfg {
.unwrap();
let cfg = cfg.finish_sub_container().unwrap();
let unit_sum = new_block_bldr
.add_constant(ops::Const::unary_unit_sum(), ExtensionSet::new())
.add_constant(ops::Const::unary_unit_sum())
.unwrap();
let pred_wire = new_block_bldr.load_const(&unit_sum).unwrap();
new_block_bldr
Expand Down
7 changes: 4 additions & 3 deletions src/hugr/rewrite/replace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -477,14 +477,15 @@ mod test {
.unwrap()
.into();
let just_list = TypeRow::from(vec![listy.clone()]);
let exset = ExtensionSet::singleton(&collections::EXTENSION_NAME);
let intermed = TypeRow::from(vec![listy.clone(), USIZE_T]);

let mut cfg = CFGBuilder::new(
FunctionType::new_endo(just_list.clone()).with_extension_delta(&exset),
// One might expect an extension_delta of "collections" here, but push/pop
// have an empty delta themselves, pending https://github.com/CQCL/hugr/issues/388
FunctionType::new_endo(just_list.clone()),
)?;

let pred_const = cfg.add_constant(ops::Const::unary_unit_sum(), None)?;
let pred_const = cfg.add_constant(ops::Const::unary_unit_sum())?;

let entry = single_node_block(&mut cfg, pop, &pred_const, true)?;
let bb2 = single_node_block(&mut cfg, push, &pred_const, false)?;
Expand Down
2 changes: 1 addition & 1 deletion src/hugr/validate/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -888,7 +888,7 @@ fn no_polymorphic_consts() -> Result<(), Box<dyn std::error::Error>> {
let empty_list = Value::Extension {
c: (Box::new(collections::ListValue::new(vec![])),),
};
let cst = def.add_load_const(Const::new(empty_list, list_of_var)?, just_colns)?;
let cst = def.add_load_const(Const::new(empty_list, list_of_var)?)?;
let res = def.finish_hugr_with_outputs([cst], &reg);
assert_matches!(
res.unwrap_err(),
Expand Down
18 changes: 11 additions & 7 deletions src/hugr/views/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,18 +132,22 @@ fn value_types() {
#[rustversion::since(1.75)] // uses impl in return position
#[test]
fn static_targets() {
use crate::extension::prelude::{ConstUsize, USIZE_T};
use crate::extension::{
prelude::{ConstUsize, PRELUDE_ID, USIZE_T},
ExtensionSet,
};
use itertools::Itertools;
let mut dfg = DFGBuilder::new(
FunctionType::new(type_row![], type_row![USIZE_T])
.with_extension_delta(&ExtensionSet::singleton(&PRELUDE_ID)),
)
.unwrap();

let mut dfg = DFGBuilder::new(FunctionType::new(type_row![], type_row![USIZE_T])).unwrap();

let c = dfg.add_constant(ConstUsize::new(1).into(), None).unwrap();
let c = dfg.add_constant(ConstUsize::new(1).into()).unwrap();

let load = dfg.load_const(&c).unwrap();

let h = dfg
.finish_hugr_with_outputs([load], &crate::extension::PRELUDE_REGISTRY)
.unwrap();
let h = dfg.finish_prelude_hugr_with_outputs([load]).unwrap();

assert_eq!(h.static_source(load.node()), Some(c.node()));

Expand Down
8 changes: 8 additions & 0 deletions src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ pub mod leaf;
pub mod module;
pub mod tag;
pub mod validate;
use crate::extension::ExtensionSet;
use crate::types::{EdgeKind, FunctionType, Type};
use crate::{Direction, OutgoingPort, Port};
use crate::{IncomingPort, PortIndex};
Expand Down Expand Up @@ -278,6 +279,13 @@ pub trait OpTrait {
fn dataflow_signature(&self) -> Option<FunctionType> {
None
}

/// The delta between the input extensions specified for a node,
/// and the output extensions calculated for that node
fn extension_delta(&self) -> ExtensionSet {
ExtensionSet::new()
}

/// The edge kind for the non-dataflow or constant inputs of the operation,
/// not described by the signature.
///
Expand Down
Loading

0 comments on commit 679eefc

Please sign in to comment.