diff --git a/hugr/src/algorithm/const_fold.rs b/hugr/src/algorithm/const_fold.rs
index b32d54a8c..f2b0dab24 100644
--- a/hugr/src/algorithm/const_fold.rs
+++ b/hugr/src/algorithm/const_fold.rs
@@ -3,8 +3,11 @@
 use std::collections::{BTreeSet, HashMap};
 
 use itertools::Itertools;
+use thiserror::Error;
 
+use crate::hugr::{SimpleReplacementError, ValidationError};
 use crate::types::SumType;
+use crate::Direction;
 use crate::{
     builder::{DFGBuilder, Dataflow, DataflowHugr},
     extension::{ConstFoldResult, ExtensionRegistry},
@@ -19,6 +22,19 @@ use crate::{
     Hugr, HugrView, IncomingPort, Node, SimpleReplacement,
 };
 
+#[derive(Error, Debug)]
+#[allow(missing_docs)]
+pub enum ConstFoldError {
+    #[error("Failed to verify {label} HUGR: {err}")]
+    VerifyError {
+        label: String,
+        #[source]
+        err: ValidationError,
+    },
+    #[error(transparent)]
+    SimpleReplaceError(#[from] SimpleReplacementError),
+}
+
 /// Tag some output constants with [`OutgoingPort`] inferred from the ordering.
 fn out_row(consts: impl IntoIterator<Item = Value>) -> ConstFoldResult {
     let vec = consts
@@ -43,9 +59,10 @@ pub(crate) fn sorted_consts(consts: &[(IncomingPort, Value)]) -> Vec<&Value> {
         .map(|(_, c)| c)
         .collect()
 }
+
 /// For a given op and consts, attempt to evaluate the op.
 pub fn fold_leaf_op(op: &OpType, consts: &[(IncomingPort, Value)]) -> ConstFoldResult {
-    match op {
+    let fold_result = match op {
         OpType::Noop { .. } => out_row([consts.first()?.1.clone()]),
         OpType::MakeTuple { .. } => {
             out_row([Value::tuple(sorted_consts(consts).into_iter().cloned())])
         }
@@ -69,7 +86,10 @@ pub fn fold_leaf_op(op: &OpType, consts: &[(IncomingPort, Value)]) -> ConstFoldR
             ext_op.constant_fold(consts)
         }
         _ => None,
-    }
+    };
+    debug_assert!(fold_result.as_ref().map_or(true, |x| x.len()
+        == op.value_port_count(Direction::Outgoing)));
+    fold_result
 }
 
 /// Generate a graph that loads and outputs `consts` in order, validating
@@ -140,18 +160,16 @@ fn fold_op(
         })
         .unzip();
     // attempt to evaluate op
-    let folded = fold_leaf_op(neighbour_op, &in_consts)?;
-    let (op_outs, consts): (Vec<_>, Vec<_>) = folded.into_iter().unzip();
-    let nu_out = op_outs
+    let (nu_out, consts): (HashMap<_, _>, Vec<_>) = fold_leaf_op(neighbour_op, &in_consts)?
         .into_iter()
         .enumerate()
-        .filter_map(|(i, out)| {
-            // map from the ports the op was linked to, to the output ports of
-            // the replacement.
-            hugr.single_linked_input(op_node, out)
-                .map(|np| (np, i.into()))
+        .filter_map(|(i, (op_out, konst))| {
+            // for each used port of the op give the nu_out entry and the
+            // corresponding Value
+            hugr.single_linked_input(op_node, op_out)
+                .map(|np| ((np, i.into()), konst))
         })
-        .collect();
+        .unzip();
     let replacement = const_graph(consts, reg);
     let sibling_graph = SiblingSubgraph::try_from_nodes([op_node], hugr)
         .expect("Operation should form valid subgraph.");
@@ -172,11 +190,8 @@ fn get_const(hugr: &impl HugrView, op_node: Node, in_p: IncomingPort) -> Option<
     let (load_n, _) = hugr.single_linked_output(op_node, in_p)?;
     let load_op = hugr.get_optype(load_n).as_load_constant()?;
     let const_node = hugr
-        .linked_outputs(load_n, load_op.constant_port())
-        .exactly_one()
-        .ok()?
+        .single_linked_output(load_n, load_op.constant_port())?
         .0;
-
     let const_op = hugr.get_optype(const_node).as_const()?;
 
     // TODO avoid const clone here
@@ -184,27 +199,45 @@
 }
 
 /// Exhaustively apply constant folding to a HUGR.
-pub fn constant_fold_pass(h: &mut impl HugrMut, reg: &ExtensionRegistry) {
+pub fn constant_fold_pass<H: HugrMut>(h: &mut H, reg: &ExtensionRegistry) {
+    #[cfg(test)]
+    let verify = |label, h: &H| {
+        h.validate_no_extensions(reg).unwrap_or_else(|err| {
+            panic!(
+                "constant_fold_pass: failed to verify {label} HUGR: {err}\n{}",
+                h.mermaid_string()
+            )
+        })
+    };
+    #[cfg(test)]
+    verify("input", h);
     loop {
-        // would be preferable if the candidates were updated to be just the
-        // neighbouring nodes of those added.
-        let rewrites = find_consts(h, h.nodes(), reg).collect_vec();
-        if rewrites.is_empty() {
+        // We can only safely apply a single replacement. Applying a
+        // replacement removes nodes and edges which may be referenced by
+        // further replacements returned by find_consts. Even worse, if we
+        // attempted to apply those replacements, expecting them to fail if
+        // the nodes and edges they reference had been deleted, they may
+        // succeed because new nodes and edges reused the ids.
+        //
+        // We could be a lot smarter here, keeping track of `LoadConstant`
+        // nodes and only looking at their out neighbours.
+        let Some((replace, removes)) = find_consts(h, h.nodes(), reg).next() else {
             break;
-        }
-        for (replace, removes) in rewrites {
-            h.apply_rewrite(replace).unwrap();
-            for rem in removes {
-                if let Ok(const_node) = h.apply_rewrite(rem) {
-                    // if the LoadConst was removed, try removing the Const too.
-                    if h.apply_rewrite(RemoveConst(const_node)).is_err() {
-                        // const cannot be removed - no problem
-                        continue;
-                    }
-                }
+        };
+        h.apply_rewrite(replace).unwrap();
+        for rem in removes {
+            // We are optimistically applying these [RemoveLoadConstant] and
+            // [RemoveConst] rewrites without checking whether the nodes
+            // they attempt to remove have remaining uses. If they do, then
+            // the rewrite fails and we move on.
+            if let Ok(const_node) = h.apply_rewrite(rem) {
+                // if the LoadConst was removed, try removing the Const too.
+                let _ = h.apply_rewrite(RemoveConst(const_node));
             }
         }
     }
+    #[cfg(test)]
+    verify("output", h);
 }
 
 #[cfg(test)]
@@ -395,4 +428,88 @@ mod test {
         let expected = Value::false_val();
         assert_fully_folded(&h, &expected);
     }
+
+    #[test]
+    fn orphan_output() {
+        // pseudocode:
+        // x0 := bool(true)
+        // x1 := not(x0)
+        // x2 := or(x0,x1)
+        // output x2 == true;
+        //
+        // We arrange things so that the `or` folds away first, leaving the Not
+        // with no outputs.
+        use crate::hugr::NodeType;
+        use crate::ops::handle::NodeHandle;
+
+        let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap();
+        let true_wire = build.add_load_value(Value::true_val());
+        // this Not will be manually replaced
+        let orig_not = build.add_dataflow_op(NotOp, [true_wire]).unwrap();
+        let r = build
+            .add_dataflow_op(
+                NaryLogic::Or.with_n_inputs(2),
+                [true_wire, orig_not.out_wire(0)],
+            )
+            .unwrap();
+        let or_node = r.node();
+        let parent = build.dfg_node;
+        let reg =
+            ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap();
+        let mut h = build.finish_hugr_with_outputs(r.outputs(), &reg).unwrap();
+
+        // we delete the original Not and create a new one. This means it will be
+        // traversed by `constant_fold_pass` after the Or.
+        let new_not = h.add_node_with_parent(parent, NodeType::new_auto(NotOp));
+        h.connect(true_wire.node(), true_wire.source(), new_not, 0);
+        h.disconnect(or_node, IncomingPort::from(1));
+        h.connect(new_not, 0, or_node, 1);
+        h.remove_node(orig_not.node());
+        constant_fold_pass(&mut h, &reg);
+        assert_fully_folded(&h, &Value::true_val())
+    }
+
+    #[test]
+    fn test_folding_pass_issue_996() {
+        // pseudocode:
+        //
+        // x0 := 3.0
+        // x1 := 4.0
+        // x2 := fne(x0, x1); // true
+        // x3 := flt(x0, x1); // true
+        // x4 := and(x2, x3); // true
+        // x5 := -10.0
+        // x6 := flt(x0, x5) // false
+        // x7 := or(x4, x6) // true
+        // output x7
+        let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap();
+        let x0 = build.add_load_const(Value::extension(ConstF64::new(3.0)));
+        let x1 = build.add_load_const(Value::extension(ConstF64::new(4.0)));
+        let x2 = build.add_dataflow_op(FloatOps::fne, [x0, x1]).unwrap();
+        let x3 = build.add_dataflow_op(FloatOps::flt, [x0, x1]).unwrap();
+        let x4 = build
+            .add_dataflow_op(
+                NaryLogic::And.with_n_inputs(2),
+                x2.outputs().chain(x3.outputs()),
+            )
+            .unwrap();
+        let x5 = build.add_load_const(Value::extension(ConstF64::new(-10.0)));
+        let x6 = build.add_dataflow_op(FloatOps::flt, [x0, x5]).unwrap();
+        let x7 = build
+            .add_dataflow_op(
+                NaryLogic::Or.with_n_inputs(2),
+                x4.outputs().chain(x6.outputs()),
+            )
+            .unwrap();
+        let reg = ExtensionRegistry::try_new([
+            PRELUDE.to_owned(),
+            logic::EXTENSION.to_owned(),
+            arithmetic::float_types::EXTENSION.to_owned(),
+        ])
+        .unwrap();
+        let mut h = build.finish_hugr_with_outputs(x7.outputs(), &reg).unwrap();
+        constant_fold_pass(&mut h, &reg);
+        let expected = Value::true_val();
+        assert_fully_folded(&h, &expected);
+    }
 }
diff --git a/hugr/src/hugr/views.rs b/hugr/src/hugr/views.rs
index 5b3354dd9..a22cd309c 100644
--- a/hugr/src/hugr/views.rs
+++ b/hugr/src/hugr/views.rs
@@ -24,7 +24,10 @@
 use itertools::{Itertools, MapInto};
 use portgraph::render::{DotFormat, MermaidFormat};
 use portgraph::{multiportgraph, LinkView, MultiPortGraph, PortView};
 
-use super::{Hugr, HugrError, NodeMetadata, NodeMetadataMap, NodeType, DEFAULT_NODETYPE};
+use super::{
+    Hugr, HugrError, NodeMetadata, NodeMetadataMap, NodeType, ValidationError, DEFAULT_NODETYPE,
+};
+use crate::extension::ExtensionRegistry;
 use crate::ops::handle::NodeHandle;
 use crate::ops::{OpParent, OpTag, OpTrait, OpType};
@@ -460,6 +463,18 @@ pub trait HugrView: sealed::HugrInternals {
         self.value_types(node, Direction::Outgoing)
             .map(|(p, t)| (p.as_outgoing().unwrap(), t))
     }
+
+    /// Check the validity of the underlying HUGR.
+    fn validate(&self, reg: &ExtensionRegistry) -> Result<(), ValidationError> {
+        self.base_hugr().validate(reg)
+    }
+
+    /// Check the validity of the underlying HUGR, but don't check consistency
+    /// of extension requirements between connected nodes or between parents and
+    /// children.
+    fn validate_no_extensions(&self, reg: &ExtensionRegistry) -> Result<(), ValidationError> {
+        self.base_hugr().validate_no_extensions(reg)
+    }
 }
 
 /// Wraps an iterator over [Port]s that are known to be [OutgoingPort]s
diff --git a/hugr/src/std_extensions/arithmetic/int_ops/const_fold.rs b/hugr/src/std_extensions/arithmetic/int_ops/const_fold.rs
index 0915a4737..8738e1872 100644
--- a/hugr/src/std_extensions/arithmetic/int_ops/const_fold.rs
+++ b/hugr/src/std_extensions/arithmetic/int_ops/const_fold.rs
@@ -16,6 +16,16 @@ use crate::{
 
 use super::IntOpDef;
 
+use lazy_static::lazy_static;
+
+lazy_static! {
+    static ref INARROW_ERROR_VALUE: Value = ConstError {
+        signal: 0,
+        message: "Integer too large to narrow".to_string(),
+    }
+    .into();
+}
+
 fn bitmask_from_width(width: u64) -> u64 {
     debug_assert!(width <= 64);
     if width == 64 {
@@ -111,28 +121,22 @@ pub(super) fn set_fold(op: &IntOpDef, def: &mut OpDef) {
                 let logwidth0: u8 = get_log_width(arg0).ok()?;
                 let logwidth1: u8 = get_log_width(arg1).ok()?;
                 let n0: &ConstInt = get_single_input_value(consts)?;
+                (logwidth0 >= logwidth1 && n0.log_width() == logwidth0).then_some(())?;
                 let int_out_type = INT_TYPES[logwidth1 as usize].to_owned();
                 let sum_type = sum_with_error(int_out_type.clone());
-                let err_value = || {
-                    let err_val = ConstError {
-                        signal: 0,
-                        message: "Integer too large to narrow".to_string(),
-                    };
-                    Value::sum(1, [err_val.into()], sum_type.clone())
+
+                let mk_out_const = |i, mb_v: Result<Value, ConstTypeError>| {
+                    mb_v.and_then(|v| Value::sum(i, [v], sum_type))
                         .unwrap_or_else(|e| panic!("Invalid computed sum, {}", e))
                 };
 
                 let n0val: u64 = n0.value_u();
                 let out_const: Value = if n0val >> (1 << logwidth1) != 0 {
-                    err_value()
+                    mk_out_const(1, Ok(INARROW_ERROR_VALUE.clone()))
                 } else {
-                    Value::extension(ConstInt::new_u(logwidth1, n0val).unwrap())
+                    mk_out_const(0, ConstInt::new_u(logwidth1, n0val).map(Into::into))
                 };
-                if logwidth0 < logwidth1 || n0.log_width() != logwidth0 {
-                    None
-                } else {
-                    Some(vec![(0.into(), out_const)])
-                }
+                Some(vec![(0.into(), out_const)])
             },
         ),
     },
@@ -145,29 +149,22 @@ pub(super) fn set_fold(op: &IntOpDef, def: &mut OpDef) {
                 let logwidth0: u8 = get_log_width(arg0).ok()?;
                 let logwidth1: u8 = get_log_width(arg1).ok()?;
                 let n0: &ConstInt = get_single_input_value(consts)?;
+                (logwidth0 >= logwidth1 && n0.log_width() == logwidth0).then_some(())?;
                 let int_out_type = INT_TYPES[logwidth1 as usize].to_owned();
                 let sum_type = sum_with_error(int_out_type.clone());
-                let err_value = || {
-                    let err_val = ConstError {
-                        signal: 0,
-                        message: "Integer too large to narrow".to_string(),
-                    };
-                    Value::sum(1, [err_val.into()], sum_type.clone())
+                let mk_out_const = |i, mb_v: Result<Value, ConstTypeError>| {
+                    mb_v.and_then(|v| Value::sum(i, [v], sum_type))
                         .unwrap_or_else(|e| panic!("Invalid computed sum, {}", e))
                 };
 
                 let n0val: i64 = n0.value_s();
                 let ub = 1i64 << ((1 << logwidth1) - 1);
                 let out_const: Value = if n0val >= ub || n0val < -ub {
-                    err_value()
+                    mk_out_const(1, Ok(INARROW_ERROR_VALUE.clone()))
                 } else {
-                    Value::extension(ConstInt::new_s(logwidth1, n0val).unwrap())
+                    mk_out_const(0, ConstInt::new_s(logwidth1, n0val).map(Into::into))
                 };
-                if logwidth0 < logwidth1 || n0.log_width() != logwidth0 {
-                    None
-                } else {
-                    Some(vec![(0.into(), out_const)])
-                }
+                Some(vec![(0.into(), out_const)])
             },
         ),
     },
diff --git a/hugr/src/std_extensions/arithmetic/int_ops/const_fold/test.rs b/hugr/src/std_extensions/arithmetic/int_ops/const_fold/test.rs
index 5241bf230..959240a51 100644
--- a/hugr/src/std_extensions/arithmetic/int_ops/const_fold/test.rs
+++ b/hugr/src/std_extensions/arithmetic/int_ops/const_fold/test.rs
@@ -61,50 +61,46 @@ fn test_fold_iwiden_s() {
     assert_fully_folded(&h, &expected);
 }
 
-#[test]
-fn test_fold_inarrow_u() {
-    // pseudocode:
-    //
-    // x0 := int_u<5>(13);
-    // x1 := inarrow_u<5, 4>(x0);
-    // output x1 == int_u<4>(13);
-    let sum_type = sum_with_error(INT_TYPES[4].to_owned());
-    let mut build = DFGBuilder::new(FunctionType::new(
-        type_row![],
-        vec![sum_type.clone().into()],
-    ))
-    .unwrap();
-    let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 13).unwrap()));
-    let x1 = build
-        .add_dataflow_op(IntOpDef::inarrow_u.with_two_log_widths(5, 4), [x0])
-        .unwrap();
-    let reg = ExtensionRegistry::try_new([
-        PRELUDE.to_owned(),
-        arithmetic::int_types::EXTENSION.to_owned(),
-    ])
-    .unwrap();
-    let mut h = build.finish_hugr_with_outputs(x1.outputs(), &reg).unwrap();
-    constant_fold_pass(&mut h, &reg);
-    let expected = Value::extension(ConstInt::new_u(4, 13).unwrap());
-    assert_fully_folded(&h, &expected);
-}
-
-#[test]
-fn test_fold_inarrow_s() {
-    // pseudocode:
+#[rstest]
+#[case(ConstInt::new_s, IntOpDef::inarrow_s, 5, 4, -3, true)]
+#[case(ConstInt::new_s, IntOpDef::inarrow_s, 5, 5, -3, true)]
+#[case(ConstInt::new_s, IntOpDef::inarrow_s, 5, 1, -3, false)]
+#[case(ConstInt::new_u, IntOpDef::inarrow_u, 5, 4, 13, true)]
+#[case(ConstInt::new_u, IntOpDef::inarrow_u, 5, 5, 13, true)]
+#[case(ConstInt::new_u, IntOpDef::inarrow_u, 5, 0, 3, false)]
+fn test_fold_inarrow<I: Copy, C: Into<Value>, E: std::fmt::Debug>(
+    #[case] mk_const: impl Fn(u8, I) -> Result<C, E>,
+    #[case] op_def: IntOpDef,
+    #[case] from_log_width: u8,
+    #[case] to_log_width: u8,
+    #[case] val: I,
+    #[case] succeeds: bool,
+) {
+    // For the first case, pseudocode:
     //
     // x0 := int_s<5>(-3);
     // x1 := inarrow_s<5, 4>(x0);
-    // output x1 == int_s<4>(-3);
-    let sum_type = sum_with_error(INT_TYPES[4].to_owned());
+    // output x1 == sum(0, [int_s<4>(-3)]);
+    //
+    // Other cases vary by:
+    // (mk_const, op_def) => create signed or unsigned constants, create
+    // inarrow_s or inarrow_u ops;
+    // (from_log_width, to_log_width) => the args to use to create the op;
+    // val => the value to pass to the op
+    // succeeds => whether to expect an int variant or an error
+    // variant.
+    let sum_type = sum_with_error(INT_TYPES[to_log_width as usize].to_owned());
     let mut build = DFGBuilder::new(FunctionType::new(
         type_row![],
         vec![sum_type.clone().into()],
     ))
     .unwrap();
-    let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -3).unwrap()));
+    let x0 = build.add_load_const(mk_const(from_log_width, val).unwrap().into());
     let x1 = build
-        .add_dataflow_op(IntOpDef::inarrow_s.with_two_log_widths(5, 4), [x0])
+        .add_dataflow_op(
+            op_def.with_two_log_widths(from_log_width, to_log_width),
+            [x0],
+        )
         .unwrap();
     let reg = ExtensionRegistry::try_new([
         PRELUDE.to_owned(),
@@ -113,7 +109,11 @@
     .unwrap();
     let mut h = build.finish_hugr_with_outputs(x1.outputs(), &reg).unwrap();
    constant_fold_pass(&mut h, &reg);
-    let expected = Value::extension(ConstInt::new_s(4, -3).unwrap());
+    let expected = if succeeds {
+        Value::sum(0, [mk_const(to_log_width, val).unwrap().into()], sum_type).unwrap()
+    } else {
+        Value::sum(1, [super::INARROW_ERROR_VALUE.clone()], sum_type).unwrap()
+    };
     assert_fully_folded(&h, &expected);
 }
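For orientation only, not part of the patch: a minimal sketch of how the reworked `constant_fold_pass` is driven, written as if it sat alongside the existing tests in `const_fold.rs` so that the test module's imports (`DFGBuilder`, `NotOp`, `PRELUDE`, `logic::EXTENSION`, `assert_fully_folded`, ...) are in scope. The tiny not(true) circuit is illustrative; only APIs that already appear in this diff are used.

    // Build a DFG with no inputs and one BOOL_T output computing not(true).
    let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap();
    let t = build.add_load_value(Value::true_val());
    let not = build.add_dataflow_op(NotOp, [t]).unwrap();
    let reg =
        ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap();
    let mut h = build.finish_hugr_with_outputs(not.outputs(), &reg).unwrap();

    // The pass now applies one replacement per find_consts query and re-scans
    // the HUGR, and in test builds it validates the HUGR before and after folding.
    constant_fold_pass(&mut h, &reg);
    assert_fully_folded(&h, &Value::false_val());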