From c50bf6e316bc236ce2d8f12e86669147baee7320 Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Mon, 23 Sep 2024 14:00:49 +0100 Subject: [PATCH 1/2] fix: inline_constant_functions not connecting LoadFunction --- src/utils/inline_constant_functions.rs | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/utils/inline_constant_functions.rs b/src/utils/inline_constant_functions.rs index 56524b3..0909185 100644 --- a/src/utils/inline_constant_functions.rs +++ b/src/utils/inline_constant_functions.rs @@ -1,9 +1,5 @@ use hugr::{ - extension::ExtensionRegistry, - hugr::hugrmut::HugrMut, - ops::{FuncDefn, LoadFunction, Value}, - types::PolyFuncType, - HugrView, Node, NodeIndex as _, + extension::ExtensionRegistry, hugr::hugrmut::HugrMut, ops::{FuncDefn, LoadFunction, Value}, types::PolyFuncType, HugrView, IncomingPort, Node, NodeIndex as _ }; use anyhow::{anyhow, bail, Result}; @@ -75,10 +71,12 @@ fn inline_constant_functions_impl( hugr.insert_hugr(func_node, func_hugr); for lcn in load_constant_ns { + hugr.disconnect(lcn, IncomingPort::from(0)); hugr.replace_op( lcn, LoadFunction::try_new(polysignature.clone(), [], registry)?, )?; + hugr.connect(func_node, 0, lcn, 0); } any_changes = true; } From 60f608e7375f306eeddc5f608d52c14394e4b03a Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Mon, 23 Sep 2024 14:38:19 +0100 Subject: [PATCH 2/2] fix --- src/utils/inline_constant_functions.rs | 44 ++++++++++++++------------ 1 file changed, 24 insertions(+), 20 deletions(-) diff --git a/src/utils/inline_constant_functions.rs b/src/utils/inline_constant_functions.rs index 0909185..2915f42 100644 --- a/src/utils/inline_constant_functions.rs +++ b/src/utils/inline_constant_functions.rs @@ -1,5 +1,10 @@ use hugr::{ - extension::ExtensionRegistry, hugr::hugrmut::HugrMut, ops::{FuncDefn, LoadFunction, Value}, types::PolyFuncType, HugrView, IncomingPort, Node, NodeIndex as _ + builder::{Dataflow, DataflowHugr, FunctionBuilder}, + extension::ExtensionRegistry, + hugr::hugrmut::HugrMut, + ops::{DataflowParent, FuncDefn, LoadFunction, Value}, + types::{PolyFuncType, Signature}, + HugrView, IncomingPort, Node, NodeIndex as _, }; use anyhow::{anyhow, bail, Result}; @@ -23,7 +28,7 @@ fn inline_constant_functions_impl( let mut const_funs = vec![]; for n in hugr.nodes() { - let konst_hugr = { + let (konst_hugr, sig) = { let Some(konst) = hugr.get_optype(n).as_const() else { continue; }; @@ -31,13 +36,24 @@ fn inline_constant_functions_impl( continue; }; let optype = hugr.get_optype(hugr.root()); - if !optype.is_dfg() && !optype.is_func_defn() { + if let Some(func) = optype.as_func_defn() { + (hugr.as_ref().clone(), func.inner_signature()) + } else if let Some(dfg) = optype.as_dfg() { + let signature: Signature = dfg.inner_signature(); + let mut builder = FunctionBuilder::new(const_fn_name(n), signature.clone())?; + let outputs = builder + .add_hugr_view_with_wires(hugr, builder.input_wires())? + .outputs(); + ( + builder.finish_hugr_with_outputs(outputs, registry)?, + signature, + ) + } else { bail!( "Constant function has unsupported root: {:?}", hugr.get_optype(hugr.root()) ) } - hugr.clone() }; let mut lcs = vec![]; for load_constant in hugr.output_neighbours(n) { @@ -49,32 +65,20 @@ fn inline_constant_functions_impl( } lcs.push(load_constant); } - const_funs.push((n, konst_hugr.as_ref().clone(), lcs)); + const_funs.push((n, konst_hugr.as_ref().clone(), sig, lcs)); } let mut any_changes = false; - for (konst_n, func_hugr, load_constant_ns) in const_funs { + for (konst_n, func_hugr, sig, load_constant_ns) in const_funs { if !load_constant_ns.is_empty() { - let polysignature: PolyFuncType = func_hugr - .inner_function_type() - .ok_or(anyhow!( - "Constant function hugr has no inner_func_type: {}", - konst_n.index() - ))? - .into(); - let func_defn = FuncDefn { - name: const_fn_name(konst_n), - signature: polysignature.clone(), - }; - let func_node = hugr.add_node_with_parent(hugr.root(), func_defn); - hugr.insert_hugr(func_node, func_hugr); + let func_node = hugr.insert_hugr(hugr.root(), func_hugr).new_root; for lcn in load_constant_ns { hugr.disconnect(lcn, IncomingPort::from(0)); hugr.replace_op( lcn, - LoadFunction::try_new(polysignature.clone(), [], registry)?, + LoadFunction::try_new(sig.clone().into(), [], registry)?, )?; hugr.connect(func_node, 0, lcn, 0); }