diff --git a/hugr-core/src/hugr/linking.rs b/hugr-core/src/hugr/linking.rs index f20c006e61..9963bbfdc4 100644 --- a/hugr-core/src/hugr/linking.rs +++ b/hugr-core/src/hugr/linking.rs @@ -2,12 +2,14 @@ use std::{collections::HashMap, fmt::Display}; -use itertools::Either; +use itertools::{Either, Itertools}; use crate::{ - Hugr, HugrView, Node, + Hugr, HugrView, Node, Visibility, core::HugrNode, hugr::{HugrMut, hugrmut::InsertedForest, internal::HugrMutInternals}, + ops::OpType, + types::PolyFuncType, }; /// Methods that merge Hugrs, adding static edges between old and inserted nodes. @@ -105,6 +107,67 @@ pub trait HugrLinking: HugrMut { link_by_node(self, transfers, &mut inserted.node_map); Ok(inserted) } + + /// Insert module-children from another Hugr into this one according to a [NameLinkingPolicy]. + /// + /// All [Visibility::Public] module-children are inserted, or linked, according to the + /// specified policy; private children will also be inserted, at least including all those + /// used by the copied public children. + // Yes at present we copy all private children, i.e. a safe over-approximation! + /// + /// # Errors + /// + /// If [NameLinkingPolicy::on_signature_conflict] or [NameLinkingPolicy::on_multiple_impls] + /// are set to [NewFuncHandling::RaiseError], and the respective conflict occurs between + /// `self` and `other`. + /// + /// [Visibility::Public]: crate::Visibility::Public + /// [FuncDefn]: crate::ops::FuncDefn + fn link_module( + &mut self, + other: Hugr, + policy: &NameLinkingPolicy, + ) -> Result, NameLinkingError> { + let actions = policy.to_node_linking(self, &other)?; + let directives = actions + .into_iter() + .map(|(k, LinkAction::LinkNode(d))| (k, d)) + .collect(); + Ok(self + .insert_link_hugr_by_node(None, other, directives) + .expect("NodeLinkingPolicy was constructed to avoid any error")) + } + + /// Copy module-children from another Hugr into this one according to a [NameLinkingPolicy]. + /// + /// All [Visibility::Public] module-children are copied, or linked, according to the + /// specified policy; private children will also be copied, at least including all those + /// used by the copied public children. + // Yes at present we copy all private children, i.e. a safe over-approximation! + /// + /// # Errors + /// + /// If [NameLinkingPolicy::on_signature_conflict] or [NameLinkingPolicy::on_multiple_impls] + /// are set to [NewFuncHandling::RaiseError], and the respective conflict occurs between + /// `self` and `other`. + /// + /// [Visibility::Public]: crate::Visibility::Public + /// [FuncDefn]: crate::ops::FuncDefn + #[allow(clippy::type_complexity)] + fn link_module_view( + &mut self, + other: &H, + policy: &NameLinkingPolicy, + ) -> Result, NameLinkingError> { + let actions = policy.to_node_linking(self, &other)?; + let directives = actions + .into_iter() + .map(|(k, LinkAction::LinkNode(d))| (k, d)) + .collect(); + Ok(self + .insert_link_view_by_node(None, other, directives) + .expect("NodeLinkingPolicy was constructed to avoid any error")) + } } impl HugrLinking for T {} @@ -185,12 +248,296 @@ impl NodeLinkingDirective { } } +/// Describes ways to link a "Source" Hugr being inserted into a target Hugr. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct NameLinkingPolicy { + // TODO: consider pub-funcs-to-add? (With others, optionally filtered by callgraph, made private) + // copy_private_funcs: bool, // TODO: allow filtering private funcs to only those reachable in callgraph + sig_conflict: NewFuncHandling, + // TODO consider Set of names where to prefer new? Or optional map from name? + multi_impls: MultipleImplHandling, + // TODO Renames to apply to public functions in the inserted Hugr. These take effect + // before [error_on_conflicting_sig] or [take_existing_and_new_impls]. + // rename_map: HashMap +} + +/// Specifies what to do with a function in some situation - used in +/// * [NameLinkingPolicy::on_signature_conflict] +/// * [MultipleImplHandling::NewFunc] +/// +/// [FuncDefn]: crate::ops::FuncDefn +/// [Visibility::Public]: crate::Visibility::Public +#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)] +#[non_exhaustive] // could consider e.g. disconnections +pub enum NewFuncHandling { + /// Do not link the Hugrs together; fail with a [NameLinkingError] instead. + RaiseError, + /// Add the new function alongside the existing one in the target Hugr, + /// preserving (separately) uses of both. (The Hugr will be invalid because + /// of [duplicate names](crate::hugr::ValidationError::DuplicateExport).) + Add, +} + +/// What to do when both target and inserted Hugr +/// have a [Visibility::Public] FuncDefn with the same name and signature. +/// +/// [Visibility::Public]: crate::Visibility::Public +#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, derive_more::From)] +#[non_exhaustive] // could consider e.g. disconnections +pub enum MultipleImplHandling { + /// Keep the implementation already in the target Hugr. (Edges in the source + /// Hugr will be redirected to use the function from the target.) + UseExisting, + /// Keep the implementation in the source Hugr. (Edges in the target Hugr + /// will be redirected to use the function from the source; the previously-existing + /// function in the target Hugr will be removed.) + UseNew, + /// Proceed as per the specified [NewFuncHandling]. + NewFunc(#[from] NewFuncHandling), +} + +/// An error in using names to determine how to link functions in source and target Hugrs. +/// (SN = Source Node, TN = Target Node) +#[derive(Clone, Debug, thiserror::Error, PartialEq)] +pub enum NameLinkingError { + /// Both source and target contained a [FuncDefn] (public and with same name + /// and signature). + /// + /// [FuncDefn]: crate::ops::FuncDefn + #[error("Source ({_1}) and target ({_2}) both contained FuncDefn with same public name {_0}")] + MultipleImpls(String, SN, TN), + /// Source and target containing public functions with conflicting signatures + // TODO ALAN Should we indicate which were decls or defns? via an extra enum? + #[error( + "Conflicting signatures for name {name} - Source ({src_node}) has {src_sig}, Target ({tgt_node}) has ({tgt_sig})" + )] + #[allow(missing_docs)] + SignatureConflict { + name: String, + src_node: SN, + src_sig: Box, + tgt_node: TN, + tgt_sig: Box, + }, + /// A [Visibility::Public] function in the source, whose body is being added + /// to the target, contained the entrypoint (which needs to be added + /// in a different place). + /// + /// [Visibility::Public]: crate::Visibility::Public + #[error("The entrypoint is contained within function {_0} which will be added as {_1:?}")] + AddFunctionContainingEntrypoint(SN, NodeLinkingDirective), +} + +impl NameLinkingPolicy { + /// Makes a new instance that specifies to handle + /// [signature conflicts](Self::on_signature_conflict) by failing with an error and + /// multiple [FuncDefn]s according to `multi_impls`. + /// + /// [FuncDefn]: crate::ops::FuncDefn + pub fn err_on_conflict(multi_impls: impl Into) -> Self { + Self { + multi_impls: multi_impls.into(), + sig_conflict: NewFuncHandling::RaiseError, + } + } + + /// Makes a new instance that specifies to keep both decls/defns when (for the same name) + /// they have different signatures or when both are defns. Thus, an error is never raised; + /// a (potentially-invalid) Hugr is always produced. + pub fn keep_both_invalid() -> Self { + Self { + multi_impls: MultipleImplHandling::NewFunc(NewFuncHandling::Add), + sig_conflict: NewFuncHandling::Add, + } + } + + /// Sets how to behave when both target and inserted Hugr have a + /// [Public] function with the same name but different signatures. + /// + /// [Public]: crate::Visibility::Public + pub fn on_signature_conflict(mut self, sc: NewFuncHandling) -> Self { + self.sig_conflict = sc; + self + } + + /// Tells how to behave when both target and inserted Hugr have a + /// [Public] function with the same name but different signatures. + /// + /// [Public]: crate::Visibility::Public + pub fn get_signature_conflict(&self) -> NewFuncHandling { + self.sig_conflict + } + + /// Sets how to behave when both target and inserted Hugr have a + /// [FuncDefn](crate::ops::FuncDefn) with the same name and signature. + pub fn on_multiple_impls(mut self, mih: MultipleImplHandling) -> Self { + self.multi_impls = mih; + self + } + + /// Tells how to behave when both target and inserted Hugr have a + /// [FuncDefn](crate::ops::FuncDefn) with the same name and signature. + pub fn get_multiple_impls(&self) -> MultipleImplHandling { + self.multi_impls + } + + /// Computes how this policy will act on a specified source (inserted) and target + /// (host) Hugr. + #[allow(clippy::type_complexity)] + pub fn to_node_linking( + &self, + target: &T, + source: &S, + ) -> Result, NameLinkingError> { + let existing = gather_existing(target); + let mut res = LinkActions::new(); + + let NameLinkingPolicy { + sig_conflict, + multi_impls, + } = self; + for n in source.children(source.module_root()) { + let dirv = match link_sig(source, n) { + None => continue, + Some(LinkSig::Private) => NodeLinkingDirective::add(), + Some(LinkSig::Public { name, is_defn, sig }) => { + if let Some((ex_ns, ex_sig)) = existing.get(name) { + match *sig_conflict { + _ if sig == *ex_sig => directive(name, n, is_defn, ex_ns, multi_impls)?, + NewFuncHandling::RaiseError => { + return Err(NameLinkingError::SignatureConflict { + name: name.clone(), + src_node: n, + src_sig: Box::new(sig.clone()), + tgt_node: *ex_ns.as_ref().left_or_else(|(n, _)| n), + tgt_sig: Box::new((*ex_sig).clone()), + }); + } + NewFuncHandling::Add => NodeLinkingDirective::add(), + } + } else { + NodeLinkingDirective::add() + } + } + }; + res.insert(n, dirv.into()); + } + + Ok(res) + } +} + +impl Default for NameLinkingPolicy { + fn default() -> Self { + Self::err_on_conflict(NewFuncHandling::RaiseError) + } +} + +fn directive( + name: &str, + new_n: SN, + new_defn: bool, + ex_ns: &Either)>, + multi_impls: &MultipleImplHandling, +) -> Result, NameLinkingError> { + Ok(match (new_defn, ex_ns) { + (false, Either::Right(_)) => NodeLinkingDirective::add(), // another alias + (false, Either::Left(defn)) => NodeLinkingDirective::UseExisting(*defn), // resolve decl + (true, Either::Right((decl, decls))) => { + NodeLinkingDirective::replace(std::iter::once(decl).chain(decls).cloned()) + } + (true, &Either::Left(defn)) => match multi_impls { + MultipleImplHandling::UseExisting => NodeLinkingDirective::UseExisting(defn), + MultipleImplHandling::UseNew => NodeLinkingDirective::replace([defn]), + MultipleImplHandling::NewFunc(NewFuncHandling::RaiseError) => { + return Err(NameLinkingError::MultipleImpls( + name.to_owned(), + new_n, + defn, + )); + } + MultipleImplHandling::NewFunc(NewFuncHandling::Add) => NodeLinkingDirective::add(), + }, + }) +} + +type PubFuncs<'a, N> = (Either)>, &'a PolyFuncType); + +enum LinkSig<'a> { + Private, + Public { + name: &'a String, + is_defn: bool, + sig: &'a PolyFuncType, + }, +} + +fn link_sig(h: &H, n: H::Node) -> Option> { + let (name, is_defn, vis, sig) = match h.get_optype(n) { + OpType::FuncDecl(fd) => (fd.func_name(), false, fd.visibility(), fd.signature()), + OpType::FuncDefn(fd) => (fd.func_name(), true, fd.visibility(), fd.signature()), + OpType::Const(_) => return Some(LinkSig::Private), + _ => return None, + }; + Some(match vis { + Visibility::Public => LinkSig::Public { name, is_defn, sig }, + Visibility::Private => LinkSig::Private, + }) +} + +fn gather_existing<'a, H: HugrView + ?Sized>( + h: &'a H, +) -> HashMap<&'a String, PubFuncs<'a, H::Node>> { + let left_if = |b| if b { Either::Left } else { Either::Right }; + h.children(h.module_root()) + .filter_map(|n| { + link_sig(h, n).and_then(|link_sig| match link_sig { + LinkSig::Public { name, is_defn, sig } => Some((name, (left_if(is_defn)(n), sig))), + LinkSig::Private => None, + }) + }) + .into_grouping_map() + .aggregate(|acc: Option>, name, (new, sig2)| { + let Some((mut acc, sig1)) = acc else { + return Some((new.map_right(|n| (n, vec![])), sig2)); + }; + assert_eq!(sig1, sig2, "Invalid Hugr: different signatures for {name}"); + let (Either::Right((_, decls)), Either::Right(ndecl)) = (&mut acc, &new) else { + let err = match acc.is_left() && new.is_left() { + true => "Multiple FuncDefns", + false => "FuncDefn and FuncDecl(s)", + }; + panic!("Invalid Hugr: {err} for {name}"); + }; + decls.push(*ndecl); + Some((acc, sig2)) + }) +} + /// Details, node-by-node, how module-children of a source Hugr should be inserted into a /// target Hugr. /// /// For use with [HugrLinking::insert_link_hugr_by_node] and [HugrLinking::insert_link_view_by_node]. pub type NodeLinkingDirectives = HashMap>; +/// Details a concrete action to link a specific node from source Hugr into a specific target Hugr. +/// +/// A separate enum from [NodeLinkingDirective] to allow [NameLinkingPolicy::to_node_linking] +/// to specify a greater range of actions than that supported by +/// [HugrLinking::insert_link_hugr_by_node] and [HugrLinking::insert_link_view_by_node]. +#[derive(Clone, Debug, Hash, PartialEq, Eq, derive_more::From)] +#[non_exhaustive] +pub enum LinkAction { + /// Just apply the specified [NodeLinkingDirective]. + LinkNode(#[from] NodeLinkingDirective), +} + +/// Details the concrete actions to implement a specific source Hugr into a specific target Hugr. +/// +/// Computed from a [NameLinkingPolicy] and contains all actions required to implement +/// that policy (for those specific Hugrs). +pub type LinkActions = HashMap>; + /// Invariant: no SourceNode can be in both maps (by type of [NodeLinkingDirective]) /// TargetNodes can be (in RHS of multiple directives) struct Transfers { @@ -287,12 +634,24 @@ mod test { use cool_asserts::assert_matches; use itertools::Itertools; + use rstest::rstest; use super::{HugrLinking, NodeLinkingDirective, NodeLinkingError}; use crate::builder::test::{dfg_calling_defn_decl, simple_dfg_hugr}; + use crate::builder::{ + Container, Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder, HugrBuilder, + ModuleBuilder, + }; + use crate::extension::prelude::{ConstUsize, usize_t}; use crate::hugr::hugrmut::test::check_calls_defn_decl; - use crate::ops::{FuncDecl, OpTag, OpTrait, handle::NodeHandle}; - use crate::{HugrView, hugr::HugrMut, types::Signature}; + use crate::hugr::linking::{ + MultipleImplHandling, NameLinkingError, NameLinkingPolicy, NewFuncHandling, + }; + use crate::hugr::{ValidationError, hugrmut::HugrMut}; + use crate::ops::{FuncDecl, OpTag, OpTrait, OpType, Value, handle::NodeHandle}; + use crate::std_extensions::arithmetic::int_ops::IntOpDef; + use crate::std_extensions::arithmetic::int_types::{ConstInt, INT_TYPES}; + use crate::{Hugr, HugrView, Visibility, types::Signature}; #[test] fn test_insert_link_nodes_add() { @@ -522,4 +881,219 @@ mod test { assert_eq!(h.static_source(call), Some(defn)); } } + + #[allow(clippy::type_complexity)] + fn list_decls_defns(h: &H) -> (HashMap, HashMap) { + let mut decls = HashMap::new(); + let mut defns = HashMap::new(); + for n in h.children(h.module_root()) { + match h.get_optype(n) { + OpType::FuncDecl(fd) => decls.insert(n, fd.func_name().as_str()), + OpType::FuncDefn(fd) => defns.insert(n, fd.func_name().as_str()), + _ => None, + }; + } + (decls, defns) + } + + fn call_targets(h: &H) -> HashMap { + h.nodes() + .filter(|n| h.get_optype(*n).is_call()) + .map(|n| (n, h.static_source(n).unwrap())) + .collect() + } + + #[rstest] + fn combines_decls_defn( + #[values(NewFuncHandling::RaiseError, NewFuncHandling::Add)] sig_conflict: NewFuncHandling, + #[values( + NewFuncHandling::RaiseError.into(), + MultipleImplHandling::UseNew, + MultipleImplHandling::UseExisting, + NewFuncHandling::Add.into() + )] + multi_impls: MultipleImplHandling, + ) { + let i64_t = || INT_TYPES[6].to_owned(); + let foo_sig = Signature::new_endo(i64_t()); + let bar_sig = Signature::new(vec![i64_t(); 2], i64_t()); + let mut target = { + let mut fb = + FunctionBuilder::new_vis("foo", foo_sig.clone(), Visibility::Public).unwrap(); + let mut mb = fb.module_root_builder(); + let bar1 = mb.declare("bar", bar_sig.clone().into()).unwrap(); + let bar2 = mb.declare("bar", bar_sig.clone().into()).unwrap(); // alias + let [i] = fb.input_wires_arr(); + let [c] = fb.call(&bar1, &[], [i, i]).unwrap().outputs_arr(); + let r = fb.call(&bar2, &[], [i, c]).unwrap(); + let h = fb.finish_hugr_with_outputs(r.outputs()).unwrap(); + assert_eq!( + list_decls_defns(&h), + ( + HashMap::from([(bar1.node(), "bar"), (bar2.node(), "bar")]), + HashMap::from([(h.entrypoint(), "foo")]) + ) + ); + h + }; + + let inserted = { + let mut main_b = FunctionBuilder::new("main", Signature::new(vec![], i64_t())).unwrap(); + let mut mb = main_b.module_root_builder(); + let foo1 = mb.declare("foo", foo_sig.clone().into()).unwrap(); + let foo2 = mb.declare("foo", foo_sig.clone().into()).unwrap(); + let mut bar = mb + .define_function_vis("bar", bar_sig.clone(), Visibility::Public) + .unwrap(); + let res = bar + .add_dataflow_op(IntOpDef::iadd.with_log_width(6), bar.input_wires()) + .unwrap(); + let bar = bar.finish_with_outputs(res.outputs()).unwrap(); + let i = main_b.add_load_value(ConstInt::new_u(6, 257).unwrap()); + let c = main_b.call(&foo1, &[], [i]).unwrap(); + let r = main_b.call(&foo2, &[], c.outputs()).unwrap(); + let h = main_b.finish_hugr_with_outputs(r.outputs()).unwrap(); + assert_eq!( + list_decls_defns(&h), + ( + HashMap::from([(foo1.node(), "foo"), (foo2.node(), "foo")]), + HashMap::from([(h.entrypoint(), "main"), (bar.node(), "bar")]) + ) + ); + h + }; + + let pol = NameLinkingPolicy { + sig_conflict, + multi_impls, + }; + let mut target2 = target.clone(); + + target.link_module_view(&inserted, &pol).unwrap(); + target2.link_module(inserted, &pol).unwrap(); + for tgt in [target, target2] { + tgt.validate().unwrap(); + let (decls, defns) = list_decls_defns(&tgt); + assert_eq!(decls, HashMap::new()); + assert_eq!( + defns.values().copied().sorted().collect_vec(), + ["bar", "foo", "main"] + ); + let call_tgts = call_targets(&tgt); + for (defn, name) in defns { + if name != "main" { + // Defns now have two calls each (was one to each alias) + assert_eq!(call_tgts.values().filter(|tgt| **tgt == defn).count(), 2); + } + } + } + } + + #[rstest] + fn sig_conflict( + #[values(false, true)] host_defn: bool, + #[values(false, true)] inserted_defn: bool, + ) { + let mk_def_or_decl = |n, sig: Signature, defn| { + let mut mb = ModuleBuilder::new(); + let node = if defn { + let fb = mb.define_function_vis(n, sig, Visibility::Public).unwrap(); + let ins = fb.input_wires(); + fb.finish_with_outputs(ins).unwrap().node() + } else { + mb.declare(n, sig.into()).unwrap().node() + }; + (mb.finish_hugr().unwrap(), node) + }; + + let old_sig = Signature::new_endo(usize_t()); + let (orig_host, orig_fn) = mk_def_or_decl("foo", old_sig.clone(), host_defn); + let new_sig = Signature::new_endo(INT_TYPES[3].clone()); + let (inserted, inserted_fn) = mk_def_or_decl("foo", new_sig.clone(), inserted_defn); + + let pol = NameLinkingPolicy::err_on_conflict(NewFuncHandling::RaiseError); + let mut host = orig_host.clone(); + let res = host.link_module_view(&inserted, &pol); + assert_eq!(host, orig_host); // Did nothing + assert_eq!( + res.err(), + Some(NameLinkingError::SignatureConflict { + name: "foo".to_string(), + src_node: inserted_fn, + src_sig: Box::new(new_sig.into()), + tgt_node: orig_fn, + tgt_sig: Box::new(old_sig.into()) + }) + ); + + let pol = pol.on_signature_conflict(NewFuncHandling::Add); + let node_map = host.link_module(inserted, &pol).unwrap().node_map; + assert_eq!( + host.validate(), + Err(ValidationError::DuplicateExport { + link_name: "foo".to_string(), + children: [orig_fn, node_map[&inserted_fn]] + }) + ); + } + + #[rstest] + #[case(MultipleImplHandling::UseNew, vec![11])] + #[case(MultipleImplHandling::UseExisting, vec![5])] + #[case(NewFuncHandling::Add.into(), vec![5, 11])] + #[case(NewFuncHandling::RaiseError.into(), vec![])] + fn impl_conflict(#[case] multi_impls: MultipleImplHandling, #[case] expected: Vec) { + fn build_hugr(cst: u64) -> Hugr { + let mut mb = ModuleBuilder::new(); + let cst = mb.add_constant(Value::from(ConstUsize::new(cst))); + let mut fb = mb + .define_function_vis("foo", Signature::new(vec![], usize_t()), Visibility::Public) + .unwrap(); + let c = fb.load_const(&cst); + fb.finish_with_outputs([c]).unwrap(); + mb.finish_hugr().unwrap() + } + let backup = build_hugr(5); + let mut host = backup.clone(); + let inserted = build_hugr(11); + + let pol = NameLinkingPolicy::keep_both_invalid().on_multiple_impls(multi_impls); + let res = host.link_module(inserted, &pol); + if multi_impls == NewFuncHandling::RaiseError.into() { + assert!(matches!(res, Err(NameLinkingError::MultipleImpls(n, _, _)) if n == "foo")); + assert_eq!(host, backup); + return; + } + res.unwrap(); + let val_res = host.validate(); + if multi_impls == NewFuncHandling::Add.into() { + assert!( + matches!(val_res, Err(ValidationError::DuplicateExport { link_name, .. }) if link_name == "foo") + ); + } else { + val_res.unwrap(); + } + let func_consts = host + .children(host.module_root()) + .filter(|n| host.get_optype(*n).is_func_defn()) + .map(|n| { + host.children(n) + .filter_map(|ch| host.static_source(ch)) // LoadConstant's + .map(|c| host.get_optype(c).as_const().unwrap()) + .map(|c| c.get_custom_value::().unwrap().value()) + .exactly_one() + .ok() + .unwrap() + }) + .collect_vec(); + assert_eq!(func_consts, expected); + // At the moment we copy all the constants regardless of whether they are used: + let all_consts: Vec<_> = host + .children(host.module_root()) + .filter_map(|ch| host.get_optype(ch).as_const()) + .map(|c| c.get_custom_value::().unwrap().value()) + .sorted() + .collect(); + assert_eq!(all_consts, [5, 11]); + } }