Skip to content

Commit

Permalink
refactor: Replace NodeType::signature() with io_extensions() (#700)
Browse files Browse the repository at this point in the history
The signature() method on NodeType basically only existed so the caller
could get both the input and output extensions (Signature providing the
`.output_extensions()` method). Instead, provide
`nodetype.io_extensions()` returning the same information.

This does leave Signature itself, which is a bit of a carbuncle.
  • Loading branch information
acl-cqc authored Nov 16, 2023
1 parent e7473f2 commit 762839d
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 74 deletions.
10 changes: 5 additions & 5 deletions src/extension/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ impl UnificationContext {
where
T: HugrView,
{
if hugr.root_type().signature().is_none() {
if hugr.root_type().input_extensions().is_none() {
let m_input = self.make_or_get_meta(hugr.root(), Direction::Incoming);
self.variables.insert(m_input);
}
Expand Down Expand Up @@ -302,7 +302,7 @@ impl UnificationContext {
self.add_constraint(m_output, Constraint::Equal(m_exit));
}

match node_type.signature() {
match node_type.io_extensions() {
// Input extensions are open
None => {
let c = if let Some(sig) = node_type.op_signature() {
Expand All @@ -318,9 +318,9 @@ impl UnificationContext {
self.add_constraint(m_output, c);
}
// We have a solution for everything!
Some(sig) => {
self.add_solution(m_output, sig.output_extensions());
self.add_solution(m_input, sig.input_extensions);
Some((input_exts, output_exts)) => {
self.add_solution(m_input, input_exts.clone());
self.add_solution(m_output, output_exts);
}
}
}
Expand Down
36 changes: 7 additions & 29 deletions src/extension/infer/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -253,26 +253,10 @@ fn dangling_src() -> Result<(), Box<dyn Error>> {

let closure = hugr.infer_extensions()?;
assert!(closure.is_empty());
assert_eq!(hugr.get_nodetype(src.node()).io_extensions().unwrap().1, rs);
assert_eq!(
hugr.get_nodetype(src.node())
.signature()
.unwrap()
.output_extensions(),
rs
);
assert_eq!(
hugr.get_nodetype(mult.node())
.signature()
.unwrap()
.input_extensions,
rs
);
assert_eq!(
hugr.get_nodetype(mult.node())
.signature()
.unwrap()
.output_extensions(),
rs
hugr.get_nodetype(mult.node()).io_extensions().unwrap(),
(&rs.clone(), rs)
);
Ok(())
}
Expand Down Expand Up @@ -385,18 +369,12 @@ fn test_conditional_inference() -> Result<(), Box<dyn Error>> {

for node in [case0_node, case1_node, conditional_node] {
assert_eq!(
hugr.get_nodetype(node)
.signature()
.unwrap()
.input_extensions,
ExtensionSet::new()
hugr.get_nodetype(node).io_extensions().unwrap().0,
&ExtensionSet::new()
);
assert_eq!(
hugr.get_nodetype(node)
.signature()
.unwrap()
.input_extensions,
ExtensionSet::new()
hugr.get_nodetype(node).io_extensions().unwrap().0,
&ExtensionSet::new()
);
}
Ok(())
Expand Down
16 changes: 9 additions & 7 deletions src/extension/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,15 @@ impl ExtensionValidator {
/// extension requirements for all of its input and output edges, then put
/// those requirements in the extension validation context.
fn gather_extensions(&mut self, node: &Node, node_type: &NodeType) {
if let Some(sig) = node_type.signature() {
for dir in Direction::BOTH {
assert!(self
.extensions
.insert((*node, dir), sig.get_extension(&dir))
.is_none());
}
if let Some((input_exts, output_exts)) = node_type.io_extensions() {
let prev_i = self
.extensions
.insert((*node, Direction::Incoming), input_exts.clone());
assert!(prev_i.is_none());
let prev_o = self
.extensions
.insert((*node, Direction::Outgoing), output_exts);
assert!(prev_o.is_none());
}
}

Expand Down
44 changes: 21 additions & 23 deletions src/hugr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ use crate::extension::{
};
use crate::ops::custom::resolve_extension_ops;
use crate::ops::{OpTag, OpTrait, OpType, DEFAULT_OPTYPE};
use crate::types::{FunctionType, Signature};
use crate::types::FunctionType;
use crate::{Direction, Node};

use delegate::delegate;
Expand Down Expand Up @@ -109,16 +109,6 @@ impl NodeType {
}
}

/// Use the input extensions to calculate the concrete signature of the node
pub fn signature(&self) -> Option<Signature> {
self.input_extensions.as_ref().map(|rs| {
self.op
.dataflow_signature()
.unwrap_or_default()
.with_input_extensions(rs.clone())
})
}

/// Get the function type from the embedded op
pub fn op_signature(&self) -> Option<FunctionType> {
self.op.dataflow_signature()
Expand All @@ -134,6 +124,23 @@ impl NodeType {
self.input_extensions.as_ref()
}

/// The input and output extensions for this node, if set.
///
/// `None`` if the [Self::input_extensions] is `None`.
/// Otherwise, will return Some, with the output extensions computed from the node's delta
pub fn io_extensions(&self) -> Option<(&ExtensionSet, ExtensionSet)> {
self.input_extensions.as_ref().map(|e| {
(
e,
self.op
.dataflow_signature()
.map(|ft| ft.extension_reqs)
.unwrap_or_default()
.union(e),
)
})
}

/// Gets the underlying [OpType] i.e. without any [input_extensions]
///
/// [input_extensions]: NodeType::input_extensions
Expand Down Expand Up @@ -411,19 +418,10 @@ mod test {
hugr.infer_extensions()?;

assert_eq!(
hugr.get_nodetype(lift)
.signature()
.unwrap()
.input_extensions,
ExtensionSet::new()
);
assert_eq!(
hugr.get_nodetype(output)
.signature()
.unwrap()
.input_extensions,
r
hugr.get_nodetype(lift).input_extensions().unwrap(),
&ExtensionSet::new()
);
assert_eq!(hugr.get_nodetype(output).input_extensions().unwrap(), &r);
Ok(())
}
}
11 changes: 1 addition & 10 deletions src/types/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ pub struct FunctionType {
}

#[derive(Clone, Default, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
/// A concrete signature, which has been instantiated with a set of input extensions
/// A combination of a FunctionType and a set of input extensions, used for declaring functions
pub struct Signature {
/// The underlying signature
pub signature: FunctionType,
Expand Down Expand Up @@ -244,15 +244,6 @@ impl FunctionType {
}

impl Signature {
/// Returns a reference to the extension set for the ports of the
/// signature in a given direction
pub fn get_extension(&self, dir: &Direction) -> ExtensionSet {
match dir {
Direction::Incoming => self.input_extensions.clone(),
Direction::Outgoing => self.output_extensions(),
}
}

delegate! {
to self.signature {
/// Inputs of the function type
Expand Down

0 comments on commit 762839d

Please sign in to comment.