@@ -107,9 +107,9 @@ impl NodeTemplate {
107107 }
108108 }
109109
110- fn replace ( & self , hugr : & mut impl HugrMut < Node = Node > , n : Node ) -> Result < ( ) , BuildError > {
110+ fn replace ( self , hugr : & mut impl HugrMut < Node = Node > , n : Node ) -> Result < ( ) , BuildError > {
111111 assert_eq ! ( hugr. children( n) . count( ) , 0 ) ;
112- let new_optype = match self . clone ( ) {
112+ let new_optype = match self {
113113 NodeTemplate :: SingleOp ( op_type) => op_type,
114114 NodeTemplate :: CompoundOp ( new_h) => {
115115 let new_entrypoint = hugr. insert_hugr ( n, * new_h) . inserted_entrypoint ;
@@ -171,6 +171,23 @@ fn call<H: HugrView<Node = Node>>(
171171 Ok ( Call :: try_new ( func_sig, type_args) ?)
172172}
173173
174+ /// Options for how the replacement for an op is processed. May be specified by
175+ /// [ReplaceTypes::replace_op_with] and [ReplaceTypes::replace_parametrized_op_with].
176+ /// Otherwise (the default), replacements are inserted as is (without further processing).
177+ #[ derive( Clone , Default , PartialEq , Eq ) ] // More derives might inhibit future extension
178+ pub struct ReplacementOptions {
179+ linearize : bool ,
180+ }
181+
182+ impl ReplacementOptions {
183+ /// Specifies that all operations within the replacement should have their
184+ /// output ports linearized.
185+ pub fn with_linearization ( mut self , lin : bool ) -> Self {
186+ self . linearize = lin;
187+ self
188+ }
189+ }
190+
174191/// A configuration of what types, ops, and constants should be replaced with what.
175192/// May be applied to a Hugr via [`Self::run`].
176193///
@@ -203,8 +220,14 @@ pub struct ReplaceTypes {
203220 type_map : HashMap < CustomType , Type > ,
204221 param_types : HashMap < ParametricType , Arc < dyn Fn ( & [ TypeArg ] ) -> Option < Type > > > ,
205222 linearize : DelegatingLinearizer ,
206- op_map : HashMap < OpHashWrapper , NodeTemplate > ,
207- param_ops : HashMap < ParametricOp , Arc < dyn Fn ( & [ TypeArg ] ) -> Option < NodeTemplate > > > ,
223+ op_map : HashMap < OpHashWrapper , ( NodeTemplate , ReplacementOptions ) > ,
224+ param_ops : HashMap <
225+ ParametricOp ,
226+ (
227+ Arc < dyn Fn ( & [ TypeArg ] ) -> Option < NodeTemplate > > ,
228+ ReplacementOptions ,
229+ ) ,
230+ > ,
208231 consts : HashMap <
209232 CustomType ,
210233 Arc < dyn Fn ( & OpaqueValue , & ReplaceTypes ) -> Result < Value , ReplaceTypesError > > ,
@@ -337,13 +360,36 @@ impl ReplaceTypes {
337360 }
338361
339362 /// Configures this instance to change occurrences of `src` to `dest`.
363+ /// Equivalent to [Self::replace_op_with] with default [ReplacementOptions].
364+ pub fn replace_op ( & mut self , src : & ExtensionOp , dest : NodeTemplate ) {
365+ self . replace_op_with ( src, dest, ReplacementOptions :: default ( ) )
366+ }
367+
368+ /// Configures this instance to change occurrences of `src` to `dest`.
369+ ///
340370 /// Note that if `src` is an instance of a *parametrized* [`OpDef`], this takes
341371 /// precedence over [`Self::replace_parametrized_op`] where the `src`s overlap. Thus,
342372 /// this should only be used on already-*[monomorphize](super::monomorphize())d*
343373 /// Hugrs, as substitution (parametric polymorphism) happening later will not respect
344374 /// this replacement.
345- pub fn replace_op ( & mut self , src : & ExtensionOp , dest : NodeTemplate ) {
346- self . op_map . insert ( OpHashWrapper :: from ( src) , dest) ;
375+ pub fn replace_op_with (
376+ & mut self ,
377+ src : & ExtensionOp ,
378+ dest : NodeTemplate ,
379+ opts : ReplacementOptions ,
380+ ) {
381+ self . op_map . insert ( OpHashWrapper :: from ( src) , ( dest, opts) ) ;
382+ }
383+
384+ /// Configures this instance to change occurrences of a parametrized op `src`
385+ /// via a callback that builds the replacement type given the [`TypeArg`]s.
386+ /// Equivalent to [Self::replace_parametrized_op_with] with default [ReplacementOptions].
387+ pub fn replace_parametrized_op (
388+ & mut self ,
389+ src : & OpDef ,
390+ dest_fn : impl Fn ( & [ TypeArg ] ) -> Option < NodeTemplate > + ' static ,
391+ ) {
392+ self . replace_parametrized_op_with ( src, dest_fn, ReplacementOptions :: default ( ) )
347393 }
348394
349395 /// Configures this instance to change occurrences of a parametrized op `src`
@@ -352,12 +398,13 @@ impl ReplaceTypes {
352398 /// fit the bounds of the original op).
353399 ///
354400 /// If the Callback returns None, the new typeargs will be applied to the original op.
355- pub fn replace_parametrized_op (
401+ pub fn replace_parametrized_op_with (
356402 & mut self ,
357403 src : & OpDef ,
358404 dest_fn : impl Fn ( & [ TypeArg ] ) -> Option < NodeTemplate > + ' static ,
405+ opts : ReplacementOptions ,
359406 ) {
360- self . param_ops . insert ( src. into ( ) , Arc :: new ( dest_fn) ) ;
407+ self . param_ops . insert ( src. into ( ) , ( Arc :: new ( dest_fn) , opts ) ) ;
361408 }
362409
363410 /// Configures this instance to change [Const]s of type `src_ty`, using
@@ -447,34 +494,40 @@ impl ReplaceTypes {
447494 | rest. transform ( self ) ?) ,
448495
449496 OpType :: Const ( Const { value, .. } ) => self . change_value ( value) ,
450- OpType :: ExtensionOp ( ext_op) => Ok (
451- // Copy/discard insertion done by caller
452- if let Some ( replacement) = self . op_map . get ( & OpHashWrapper :: from ( & * ext_op) ) {
497+ OpType :: ExtensionOp ( ext_op) => Ok ( {
498+ let def = ext_op. def_arc ( ) ;
499+ let mut changed = false ;
500+ let replacement = match self . op_map . get ( & OpHashWrapper :: from ( & * ext_op) ) {
501+ r @ Some ( _) => r. cloned ( ) ,
502+ None => {
503+ let mut args = ext_op. args ( ) . to_vec ( ) ;
504+ changed = args. transform ( self ) ?;
505+ let r2 = self
506+ . param_ops
507+ . get ( & def. as_ref ( ) . into ( ) )
508+ . and_then ( |( rep_fn, opts) | rep_fn ( & args) . map ( |nt| ( nt, opts. clone ( ) ) ) ) ;
509+ if r2. is_none ( ) && changed {
510+ * ext_op = ExtensionOp :: new ( def. clone ( ) , args) ?;
511+ }
512+ r2
513+ }
514+ } ;
515+ if let Some ( ( replacement, opts) ) = replacement {
453516 replacement
454517 . replace ( hugr, n)
455518 . map_err ( |e| ReplaceTypesError :: AddTemplateError ( n, Box :: new ( e) ) ) ?;
456- true
457- } else {
458- let def = ext_op. def_arc ( ) ;
459- let mut args = ext_op. args ( ) . to_vec ( ) ;
460- let ch = args. transform ( self ) ?;
461- if let Some ( replacement) = self
462- . param_ops
463- . get ( & def. as_ref ( ) . into ( ) )
464- . and_then ( |rep_fn| rep_fn ( & args) )
465- {
466- replacement
467- . replace ( hugr, n)
468- . map_err ( |e| ReplaceTypesError :: AddTemplateError ( n, Box :: new ( e) ) ) ?;
469- true
470- } else {
471- if ch {
472- * ext_op = ExtensionOp :: new ( def. clone ( ) , args) ?;
519+ if opts. linearize {
520+ for d in hugr. descendants ( n) . collect :: < Vec < _ > > ( ) {
521+ if d != n {
522+ self . linearize_outputs ( hugr, d) ?;
523+ }
473524 }
474- ch
475525 }
476- } ,
477- ) ,
526+ true
527+ } else {
528+ changed
529+ }
530+ } ) ,
478531
479532 OpType :: OpaqueOp ( _) => panic ! ( "OpaqueOp should not be in a Hugr" ) ,
480533
@@ -518,6 +571,27 @@ impl ReplaceTypes {
518571 Value :: Function { hugr } => self . run ( & mut * * hugr) ,
519572 }
520573 }
574+
575+ fn linearize_outputs < H : HugrMut < Node = Node > > (
576+ & self ,
577+ hugr : & mut H ,
578+ n : H :: Node ,
579+ ) -> Result < ( ) , LinearizeError > {
580+ if let Some ( new_sig) = hugr. get_optype ( n) . dataflow_signature ( ) {
581+ let new_sig = new_sig. into_owned ( ) ;
582+ for outp in new_sig. output_ports ( ) {
583+ if !new_sig. out_port_type ( outp) . unwrap ( ) . copyable ( ) {
584+ let targets = hugr. linked_inputs ( n, outp) . collect :: < Vec < _ > > ( ) ;
585+ if targets. len ( ) != 1 {
586+ hugr. disconnect ( n, outp) ;
587+ let src = Wire :: new ( n, outp) ;
588+ self . linearize . insert_copy_discard ( hugr, src, & targets) ?;
589+ }
590+ }
591+ }
592+ }
593+ Ok ( ( ) )
594+ }
521595}
522596
523597impl < H : HugrMut < Node = Node > > ComposablePass < H > for ReplaceTypes {
@@ -528,21 +602,8 @@ impl<H: HugrMut<Node = Node>> ComposablePass<H> for ReplaceTypes {
528602 let mut changed = false ;
529603 for n in hugr. entry_descendants ( ) . collect :: < Vec < _ > > ( ) {
530604 changed |= self . change_node ( hugr, n) ?;
531- let new_dfsig = hugr. get_optype ( n) . dataflow_signature ( ) ;
532- if let Some ( new_sig) = new_dfsig
533- . filter ( |_| changed && n != hugr. entrypoint ( ) )
534- . map ( Cow :: into_owned)
535- {
536- for outp in new_sig. output_ports ( ) {
537- if !new_sig. out_port_type ( outp) . unwrap ( ) . copyable ( ) {
538- let targets = hugr. linked_inputs ( n, outp) . collect :: < Vec < _ > > ( ) ;
539- if targets. len ( ) != 1 {
540- hugr. disconnect ( n, outp) ;
541- let src = Wire :: new ( n, outp) ;
542- self . linearize . insert_copy_discard ( hugr, src, & targets) ?;
543- }
544- }
545- }
605+ if n != hugr. entrypoint ( ) && changed {
606+ self . linearize_outputs ( hugr, n) ?;
546607 }
547608 }
548609 Ok ( changed)
0 commit comments