diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index 6acd91c13f..0b5cca8f6a 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -107,9 +107,9 @@ impl NodeTemplate { } } - fn replace(&self, hugr: &mut impl HugrMut, n: Node) -> Result<(), BuildError> { + fn replace(self, hugr: &mut impl HugrMut, n: Node) -> Result<(), BuildError> { assert_eq!(hugr.children(n).count(), 0); - let new_optype = match self.clone() { + let new_optype = match self { NodeTemplate::SingleOp(op_type) => op_type, NodeTemplate::CompoundOp(new_h) => { let new_entrypoint = hugr.insert_hugr(n, *new_h).inserted_entrypoint; @@ -171,6 +171,23 @@ fn call>( Ok(Call::try_new(func_sig, type_args)?) } +/// Options for how the replacement for an op is processed. May be specified by +/// [ReplaceTypes::replace_op_with] and [ReplaceTypes::replace_parametrized_op_with]. +/// Otherwise (the default), replacements are inserted as is (without further processing). +#[derive(Clone, Default, PartialEq, Eq)] // More derives might inhibit future extension +pub struct ReplacementOptions { + linearize: bool, +} + +impl ReplacementOptions { + /// Specifies that all operations within the replacement should have their + /// output ports linearized. + pub fn with_linearization(mut self, lin: bool) -> Self { + self.linearize = lin; + self + } +} + /// A configuration of what types, ops, and constants should be replaced with what. /// May be applied to a Hugr via [`Self::run`]. /// @@ -203,8 +220,14 @@ pub struct ReplaceTypes { type_map: HashMap, param_types: HashMap Option>>, linearize: DelegatingLinearizer, - op_map: HashMap, - param_ops: HashMap Option>>, + op_map: HashMap, + param_ops: HashMap< + ParametricOp, + ( + Arc Option>, + ReplacementOptions, + ), + >, consts: HashMap< CustomType, Arc Result>, @@ -337,13 +360,36 @@ impl ReplaceTypes { } /// Configures this instance to change occurrences of `src` to `dest`. + /// Equivalent to [Self::replace_op_with] with default [ReplacementOptions]. + pub fn replace_op(&mut self, src: &ExtensionOp, dest: NodeTemplate) { + self.replace_op_with(src, dest, ReplacementOptions::default()) + } + + /// Configures this instance to change occurrences of `src` to `dest`. + /// /// Note that if `src` is an instance of a *parametrized* [`OpDef`], this takes /// precedence over [`Self::replace_parametrized_op`] where the `src`s overlap. Thus, /// this should only be used on already-*[monomorphize](super::monomorphize())d* /// Hugrs, as substitution (parametric polymorphism) happening later will not respect /// this replacement. - pub fn replace_op(&mut self, src: &ExtensionOp, dest: NodeTemplate) { - self.op_map.insert(OpHashWrapper::from(src), dest); + pub fn replace_op_with( + &mut self, + src: &ExtensionOp, + dest: NodeTemplate, + opts: ReplacementOptions, + ) { + self.op_map.insert(OpHashWrapper::from(src), (dest, opts)); + } + + /// Configures this instance to change occurrences of a parametrized op `src` + /// via a callback that builds the replacement type given the [`TypeArg`]s. + /// Equivalent to [Self::replace_parametrized_op_with] with default [ReplacementOptions]. + pub fn replace_parametrized_op( + &mut self, + src: &OpDef, + dest_fn: impl Fn(&[TypeArg]) -> Option + 'static, + ) { + self.replace_parametrized_op_with(src, dest_fn, ReplacementOptions::default()) } /// Configures this instance to change occurrences of a parametrized op `src` @@ -352,12 +398,13 @@ impl ReplaceTypes { /// fit the bounds of the original op). /// /// If the Callback returns None, the new typeargs will be applied to the original op. - pub fn replace_parametrized_op( + pub fn replace_parametrized_op_with( &mut self, src: &OpDef, dest_fn: impl Fn(&[TypeArg]) -> Option + 'static, + opts: ReplacementOptions, ) { - self.param_ops.insert(src.into(), Arc::new(dest_fn)); + self.param_ops.insert(src.into(), (Arc::new(dest_fn), opts)); } /// Configures this instance to change [Const]s of type `src_ty`, using @@ -447,34 +494,40 @@ impl ReplaceTypes { | rest.transform(self)?), OpType::Const(Const { value, .. }) => self.change_value(value), - OpType::ExtensionOp(ext_op) => Ok( - // Copy/discard insertion done by caller - if let Some(replacement) = self.op_map.get(&OpHashWrapper::from(&*ext_op)) { + OpType::ExtensionOp(ext_op) => Ok({ + let def = ext_op.def_arc(); + let mut changed = false; + let replacement = match self.op_map.get(&OpHashWrapper::from(&*ext_op)) { + r @ Some(_) => r.cloned(), + None => { + let mut args = ext_op.args().to_vec(); + changed = args.transform(self)?; + let r2 = self + .param_ops + .get(&def.as_ref().into()) + .and_then(|(rep_fn, opts)| rep_fn(&args).map(|nt| (nt, opts.clone()))); + if r2.is_none() && changed { + *ext_op = ExtensionOp::new(def.clone(), args)?; + } + r2 + } + }; + if let Some((replacement, opts)) = replacement { replacement .replace(hugr, n) .map_err(|e| ReplaceTypesError::AddTemplateError(n, Box::new(e)))?; - true - } else { - let def = ext_op.def_arc(); - let mut args = ext_op.args().to_vec(); - let ch = args.transform(self)?; - if let Some(replacement) = self - .param_ops - .get(&def.as_ref().into()) - .and_then(|rep_fn| rep_fn(&args)) - { - replacement - .replace(hugr, n) - .map_err(|e| ReplaceTypesError::AddTemplateError(n, Box::new(e)))?; - true - } else { - if ch { - *ext_op = ExtensionOp::new(def.clone(), args)?; + if opts.linearize { + for d in hugr.descendants(n).collect::>() { + if d != n { + self.linearize_outputs(hugr, d)?; + } } - ch } - }, - ), + true + } else { + changed + } + }), OpType::OpaqueOp(_) => panic!("OpaqueOp should not be in a Hugr"), @@ -518,6 +571,27 @@ impl ReplaceTypes { Value::Function { hugr } => self.run(&mut **hugr), } } + + fn linearize_outputs>( + &self, + hugr: &mut H, + n: H::Node, + ) -> Result<(), LinearizeError> { + if let Some(new_sig) = hugr.get_optype(n).dataflow_signature() { + let new_sig = new_sig.into_owned(); + for outp in new_sig.output_ports() { + if !new_sig.out_port_type(outp).unwrap().copyable() { + let targets = hugr.linked_inputs(n, outp).collect::>(); + if targets.len() != 1 { + hugr.disconnect(n, outp); + let src = Wire::new(n, outp); + self.linearize.insert_copy_discard(hugr, src, &targets)?; + } + } + } + } + Ok(()) + } } impl> ComposablePass for ReplaceTypes { @@ -528,21 +602,8 @@ impl> ComposablePass for ReplaceTypes { let mut changed = false; for n in hugr.entry_descendants().collect::>() { changed |= self.change_node(hugr, n)?; - let new_dfsig = hugr.get_optype(n).dataflow_signature(); - if let Some(new_sig) = new_dfsig - .filter(|_| changed && n != hugr.entrypoint()) - .map(Cow::into_owned) - { - for outp in new_sig.output_ports() { - if !new_sig.out_port_type(outp).unwrap().copyable() { - let targets = hugr.linked_inputs(n, outp).collect::>(); - if targets.len() != 1 { - hugr.disconnect(n, outp); - let src = Wire::new(n, outp); - self.linearize.insert_copy_discard(hugr, src, &targets)?; - } - } - } + if n != hugr.entrypoint() && changed { + self.linearize_outputs(hugr, n)?; } } Ok(changed) diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index 5cc8a64b66..0029683f47 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -52,12 +52,10 @@ pub trait Linearizer { src: Wire, targets: &[(Node, IncomingPort)], ) -> Result<(), LinearizeError> { - let sig = hugr.signature(src.node()).unwrap(); - let typ = sig.port_type(src.source()).unwrap(); let (tgt_node, tgt_inport) = if targets.len() == 1 { *targets.first().unwrap() } else { - // Fail fast if the edges are nonlocal. (TODO transform to local edges!) + // Fail fast if the edges are nonlocal. let src_parent = hugr .get_parent(src.node()) .expect("Root node cannot have out edges"); @@ -74,7 +72,8 @@ pub trait Linearizer { tgt_parent, }); } - let typ = typ.clone(); // Stop borrowing hugr in order to add_hugr to it + let sig = hugr.signature(src.node()).unwrap(); + let typ = sig.port_type(src.source()).unwrap().clone(); let copy_discard_op = self .copy_discard_op(&typ, targets.len())? .add_hugr(hugr, src_parent) @@ -148,7 +147,8 @@ pub enum LinearizeError { sig: Option>, }, #[error( - "Cannot add nonlocal edge for linear type from {src} (with parent {src_parent}) to {tgt} (with parent {tgt_parent})" + "Cannot add nonlocal edge for linear type from {src} (with parent {src_parent}) to {tgt} (with parent {tgt_parent}). + Try using LocalizeEdges pass first." )] NoLinearNonLocalEdges { src: Node, @@ -367,11 +367,11 @@ mod test { use std::sync::Arc; use hugr_core::builder::{ - BuildError, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, - inout_sig, + BuildError, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, + HugrBuilder, inout_sig, }; - use hugr_core::extension::prelude::{option_type, usize_t}; + use hugr_core::extension::prelude::{option_type, qb_t, usize_t}; use hugr_core::extension::simple_op::MakeExtensionOp; use hugr_core::extension::{ CustomSignatureFunc, OpDef, SignatureError, SignatureFunc, TypeDefBound, Version, @@ -385,14 +385,16 @@ mod test { }; use hugr_core::types::type_param::TypeParam; use hugr_core::types::{ - FuncValueType, PolyFuncTypeRV, Signature, Type, TypeArg, TypeEnum, TypeRow, + FuncValueType, PolyFuncTypeRV, Signature, Type, TypeArg, TypeBound, TypeEnum, TypeRow, }; use hugr_core::{Extension, Hugr, HugrView, Node, hugr::IdentList, type_row}; use itertools::Itertools; use rstest::rstest; use crate::replace_types::handlers::linearize_value_array; - use crate::replace_types::{LinearizeError, NodeTemplate, ReplaceTypesError}; + use crate::replace_types::{ + LinearizeError, NodeTemplate, ReplaceTypesError, ReplacementOptions, + }; use crate::{ComposablePass, ReplaceTypes}; const LIN_T: &str = "Lin"; @@ -855,4 +857,64 @@ mod test { panic!("Expected error"); } } + + #[test] + fn use_in_op_callback() { + let (e, mut lowerer) = ext_lowerer(); + let drop_ext = Extension::new_arc( + IdentList::new_unchecked("DropExt"), + Version::new(0, 0, 0), + |e, w| { + e.add_op( + "drop".into(), + String::new(), + PolyFuncTypeRV::new( + [TypeBound::Linear.into()], // It won't *lower* for any type tho! + Signature::new(Type::new_var_use(0, TypeBound::Linear), vec![]), + ), + w, + ) + .unwrap(); + }, + ); + let drop_op = drop_ext.get_op("drop").unwrap(); + lowerer.replace_parametrized_op_with( + drop_op, + |args| { + let [TypeArg::Runtime(ty)] = args else { + panic!("Expected just one type") + }; + // The Hugr here is invalid, so we have to pull it out manually + let mut dfb = DFGBuilder::new(Signature::new(ty.clone(), vec![])).unwrap(); + let h = std::mem::take(dfb.hugr_mut()); + Some(NodeTemplate::CompoundOp(Box::new(h))) + }, + ReplacementOptions::default().with_linearization(true), + ); + + let build_hugr = |ty: Type| { + let mut dfb = DFGBuilder::new(Signature::new(ty.clone(), vec![])).unwrap(); + let [inp] = dfb.input_wires_arr(); + let drop_op = drop_ext + .instantiate_extension_op("drop", [ty.into()]) + .unwrap(); + dfb.add_dataflow_op(drop_op, [inp]).unwrap(); + dfb.finish_hugr().unwrap() + }; + // We can drop a tuple of 2* lin_t + let lin_t = Type::from(e.get_type(LIN_T).unwrap().instantiate([]).unwrap()); + let mut h = build_hugr(Type::new_tuple(vec![lin_t; 2])); + lowerer.run(&mut h).unwrap(); + h.validate().unwrap(); + let mut exts = h.nodes().filter_map(|n| h.get_optype(n).as_extension_op()); + assert_eq!(exts.clone().count(), 2); + assert!(exts.all(|eo| eo.qualified_id() == "TestExt.discard")); + + // We cannot drop a qubit + let mut h = build_hugr(qb_t()); + assert_eq!( + lowerer.run(&mut h).unwrap_err(), + ReplaceTypesError::LinearizeError(LinearizeError::NeedCopyDiscard(Box::new(qb_t()))) + ); + } }