Skip to content

Commit 2e11266

Browse files
authored
feat: ReplaceTypes allows linearizing inside Op replacements (#2435)
* Add `replace_parametrized_op_with` and `replace_op_with` which allow specifying a ReplacementOptions struct * For now, ReplacementOptions has only one option, to allow running the linearizer on all outports of the replacement * Test demonstrates - of course the replacement Hugr is invalid (unconnected port) before linearization.
1 parent d17b245 commit 2e11266

File tree

2 files changed

+179
-56
lines changed

2 files changed

+179
-56
lines changed

hugr-passes/src/replace_types.rs

Lines changed: 107 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -107,9 +107,9 @@ impl NodeTemplate {
107107
}
108108
}
109109

110-
fn replace(&self, hugr: &mut impl HugrMut<Node = Node>, n: Node) -> Result<(), BuildError> {
110+
fn replace(self, hugr: &mut impl HugrMut<Node = Node>, n: Node) -> Result<(), BuildError> {
111111
assert_eq!(hugr.children(n).count(), 0);
112-
let new_optype = match self.clone() {
112+
let new_optype = match self {
113113
NodeTemplate::SingleOp(op_type) => op_type,
114114
NodeTemplate::CompoundOp(new_h) => {
115115
let new_entrypoint = hugr.insert_hugr(n, *new_h).inserted_entrypoint;
@@ -171,6 +171,23 @@ fn call<H: HugrView<Node = Node>>(
171171
Ok(Call::try_new(func_sig, type_args)?)
172172
}
173173

174+
/// Options for how the replacement for an op is processed. May be specified by
175+
/// [ReplaceTypes::replace_op_with] and [ReplaceTypes::replace_parametrized_op_with].
176+
/// Otherwise (the default), replacements are inserted as is (without further processing).
177+
#[derive(Clone, Default, PartialEq, Eq)] // More derives might inhibit future extension
178+
pub struct ReplacementOptions {
179+
linearize: bool,
180+
}
181+
182+
impl ReplacementOptions {
183+
/// Specifies that all operations within the replacement should have their
184+
/// output ports linearized.
185+
pub fn with_linearization(mut self, lin: bool) -> Self {
186+
self.linearize = lin;
187+
self
188+
}
189+
}
190+
174191
/// A configuration of what types, ops, and constants should be replaced with what.
175192
/// May be applied to a Hugr via [`Self::run`].
176193
///
@@ -203,8 +220,14 @@ pub struct ReplaceTypes {
203220
type_map: HashMap<CustomType, Type>,
204221
param_types: HashMap<ParametricType, Arc<dyn Fn(&[TypeArg]) -> Option<Type>>>,
205222
linearize: DelegatingLinearizer,
206-
op_map: HashMap<OpHashWrapper, NodeTemplate>,
207-
param_ops: HashMap<ParametricOp, Arc<dyn Fn(&[TypeArg]) -> Option<NodeTemplate>>>,
223+
op_map: HashMap<OpHashWrapper, (NodeTemplate, ReplacementOptions)>,
224+
param_ops: HashMap<
225+
ParametricOp,
226+
(
227+
Arc<dyn Fn(&[TypeArg]) -> Option<NodeTemplate>>,
228+
ReplacementOptions,
229+
),
230+
>,
208231
consts: HashMap<
209232
CustomType,
210233
Arc<dyn Fn(&OpaqueValue, &ReplaceTypes) -> Result<Value, ReplaceTypesError>>,
@@ -337,13 +360,36 @@ impl ReplaceTypes {
337360
}
338361

339362
/// Configures this instance to change occurrences of `src` to `dest`.
363+
/// Equivalent to [Self::replace_op_with] with default [ReplacementOptions].
364+
pub fn replace_op(&mut self, src: &ExtensionOp, dest: NodeTemplate) {
365+
self.replace_op_with(src, dest, ReplacementOptions::default())
366+
}
367+
368+
/// Configures this instance to change occurrences of `src` to `dest`.
369+
///
340370
/// Note that if `src` is an instance of a *parametrized* [`OpDef`], this takes
341371
/// precedence over [`Self::replace_parametrized_op`] where the `src`s overlap. Thus,
342372
/// this should only be used on already-*[monomorphize](super::monomorphize())d*
343373
/// Hugrs, as substitution (parametric polymorphism) happening later will not respect
344374
/// this replacement.
345-
pub fn replace_op(&mut self, src: &ExtensionOp, dest: NodeTemplate) {
346-
self.op_map.insert(OpHashWrapper::from(src), dest);
375+
pub fn replace_op_with(
376+
&mut self,
377+
src: &ExtensionOp,
378+
dest: NodeTemplate,
379+
opts: ReplacementOptions,
380+
) {
381+
self.op_map.insert(OpHashWrapper::from(src), (dest, opts));
382+
}
383+
384+
/// Configures this instance to change occurrences of a parametrized op `src`
385+
/// via a callback that builds the replacement type given the [`TypeArg`]s.
386+
/// Equivalent to [Self::replace_parametrized_op_with] with default [ReplacementOptions].
387+
pub fn replace_parametrized_op(
388+
&mut self,
389+
src: &OpDef,
390+
dest_fn: impl Fn(&[TypeArg]) -> Option<NodeTemplate> + 'static,
391+
) {
392+
self.replace_parametrized_op_with(src, dest_fn, ReplacementOptions::default())
347393
}
348394

349395
/// Configures this instance to change occurrences of a parametrized op `src`
@@ -352,12 +398,13 @@ impl ReplaceTypes {
352398
/// fit the bounds of the original op).
353399
///
354400
/// If the Callback returns None, the new typeargs will be applied to the original op.
355-
pub fn replace_parametrized_op(
401+
pub fn replace_parametrized_op_with(
356402
&mut self,
357403
src: &OpDef,
358404
dest_fn: impl Fn(&[TypeArg]) -> Option<NodeTemplate> + 'static,
405+
opts: ReplacementOptions,
359406
) {
360-
self.param_ops.insert(src.into(), Arc::new(dest_fn));
407+
self.param_ops.insert(src.into(), (Arc::new(dest_fn), opts));
361408
}
362409

363410
/// Configures this instance to change [Const]s of type `src_ty`, using
@@ -447,34 +494,40 @@ impl ReplaceTypes {
447494
| rest.transform(self)?),
448495

449496
OpType::Const(Const { value, .. }) => self.change_value(value),
450-
OpType::ExtensionOp(ext_op) => Ok(
451-
// Copy/discard insertion done by caller
452-
if let Some(replacement) = self.op_map.get(&OpHashWrapper::from(&*ext_op)) {
497+
OpType::ExtensionOp(ext_op) => Ok({
498+
let def = ext_op.def_arc();
499+
let mut changed = false;
500+
let replacement = match self.op_map.get(&OpHashWrapper::from(&*ext_op)) {
501+
r @ Some(_) => r.cloned(),
502+
None => {
503+
let mut args = ext_op.args().to_vec();
504+
changed = args.transform(self)?;
505+
let r2 = self
506+
.param_ops
507+
.get(&def.as_ref().into())
508+
.and_then(|(rep_fn, opts)| rep_fn(&args).map(|nt| (nt, opts.clone())));
509+
if r2.is_none() && changed {
510+
*ext_op = ExtensionOp::new(def.clone(), args)?;
511+
}
512+
r2
513+
}
514+
};
515+
if let Some((replacement, opts)) = replacement {
453516
replacement
454517
.replace(hugr, n)
455518
.map_err(|e| ReplaceTypesError::AddTemplateError(n, Box::new(e)))?;
456-
true
457-
} else {
458-
let def = ext_op.def_arc();
459-
let mut args = ext_op.args().to_vec();
460-
let ch = args.transform(self)?;
461-
if let Some(replacement) = self
462-
.param_ops
463-
.get(&def.as_ref().into())
464-
.and_then(|rep_fn| rep_fn(&args))
465-
{
466-
replacement
467-
.replace(hugr, n)
468-
.map_err(|e| ReplaceTypesError::AddTemplateError(n, Box::new(e)))?;
469-
true
470-
} else {
471-
if ch {
472-
*ext_op = ExtensionOp::new(def.clone(), args)?;
519+
if opts.linearize {
520+
for d in hugr.descendants(n).collect::<Vec<_>>() {
521+
if d != n {
522+
self.linearize_outputs(hugr, d)?;
523+
}
473524
}
474-
ch
475525
}
476-
},
477-
),
526+
true
527+
} else {
528+
changed
529+
}
530+
}),
478531

479532
OpType::OpaqueOp(_) => panic!("OpaqueOp should not be in a Hugr"),
480533

@@ -518,6 +571,27 @@ impl ReplaceTypes {
518571
Value::Function { hugr } => self.run(&mut **hugr),
519572
}
520573
}
574+
575+
fn linearize_outputs<H: HugrMut<Node = Node>>(
576+
&self,
577+
hugr: &mut H,
578+
n: H::Node,
579+
) -> Result<(), LinearizeError> {
580+
if let Some(new_sig) = hugr.get_optype(n).dataflow_signature() {
581+
let new_sig = new_sig.into_owned();
582+
for outp in new_sig.output_ports() {
583+
if !new_sig.out_port_type(outp).unwrap().copyable() {
584+
let targets = hugr.linked_inputs(n, outp).collect::<Vec<_>>();
585+
if targets.len() != 1 {
586+
hugr.disconnect(n, outp);
587+
let src = Wire::new(n, outp);
588+
self.linearize.insert_copy_discard(hugr, src, &targets)?;
589+
}
590+
}
591+
}
592+
}
593+
Ok(())
594+
}
521595
}
522596

523597
impl<H: HugrMut<Node = Node>> ComposablePass<H> for ReplaceTypes {
@@ -528,21 +602,8 @@ impl<H: HugrMut<Node = Node>> ComposablePass<H> for ReplaceTypes {
528602
let mut changed = false;
529603
for n in hugr.entry_descendants().collect::<Vec<_>>() {
530604
changed |= self.change_node(hugr, n)?;
531-
let new_dfsig = hugr.get_optype(n).dataflow_signature();
532-
if let Some(new_sig) = new_dfsig
533-
.filter(|_| changed && n != hugr.entrypoint())
534-
.map(Cow::into_owned)
535-
{
536-
for outp in new_sig.output_ports() {
537-
if !new_sig.out_port_type(outp).unwrap().copyable() {
538-
let targets = hugr.linked_inputs(n, outp).collect::<Vec<_>>();
539-
if targets.len() != 1 {
540-
hugr.disconnect(n, outp);
541-
let src = Wire::new(n, outp);
542-
self.linearize.insert_copy_discard(hugr, src, &targets)?;
543-
}
544-
}
545-
}
605+
if n != hugr.entrypoint() && changed {
606+
self.linearize_outputs(hugr, n)?;
546607
}
547608
}
548609
Ok(changed)

hugr-passes/src/replace_types/linearize.rs

Lines changed: 72 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,10 @@ pub trait Linearizer {
5252
src: Wire,
5353
targets: &[(Node, IncomingPort)],
5454
) -> Result<(), LinearizeError> {
55-
let sig = hugr.signature(src.node()).unwrap();
56-
let typ = sig.port_type(src.source()).unwrap();
5755
let (tgt_node, tgt_inport) = if targets.len() == 1 {
5856
*targets.first().unwrap()
5957
} else {
60-
// Fail fast if the edges are nonlocal. (TODO transform to local edges!)
58+
// Fail fast if the edges are nonlocal.
6159
let src_parent = hugr
6260
.get_parent(src.node())
6361
.expect("Root node cannot have out edges");
@@ -74,7 +72,8 @@ pub trait Linearizer {
7472
tgt_parent,
7573
});
7674
}
77-
let typ = typ.clone(); // Stop borrowing hugr in order to add_hugr to it
75+
let sig = hugr.signature(src.node()).unwrap();
76+
let typ = sig.port_type(src.source()).unwrap().clone();
7877
let copy_discard_op = self
7978
.copy_discard_op(&typ, targets.len())?
8079
.add_hugr(hugr, src_parent)
@@ -148,7 +147,8 @@ pub enum LinearizeError {
148147
sig: Option<Box<Signature>>,
149148
},
150149
#[error(
151-
"Cannot add nonlocal edge for linear type from {src} (with parent {src_parent}) to {tgt} (with parent {tgt_parent})"
150+
"Cannot add nonlocal edge for linear type from {src} (with parent {src_parent}) to {tgt} (with parent {tgt_parent}).
151+
Try using LocalizeEdges pass first."
152152
)]
153153
NoLinearNonLocalEdges {
154154
src: Node,
@@ -367,11 +367,11 @@ mod test {
367367
use std::sync::Arc;
368368

369369
use hugr_core::builder::{
370-
BuildError, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder,
371-
inout_sig,
370+
BuildError, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer,
371+
HugrBuilder, inout_sig,
372372
};
373373

374-
use hugr_core::extension::prelude::{option_type, usize_t};
374+
use hugr_core::extension::prelude::{option_type, qb_t, usize_t};
375375
use hugr_core::extension::simple_op::MakeExtensionOp;
376376
use hugr_core::extension::{
377377
CustomSignatureFunc, OpDef, SignatureError, SignatureFunc, TypeDefBound, Version,
@@ -385,14 +385,16 @@ mod test {
385385
};
386386
use hugr_core::types::type_param::TypeParam;
387387
use hugr_core::types::{
388-
FuncValueType, PolyFuncTypeRV, Signature, Type, TypeArg, TypeEnum, TypeRow,
388+
FuncValueType, PolyFuncTypeRV, Signature, Type, TypeArg, TypeBound, TypeEnum, TypeRow,
389389
};
390390
use hugr_core::{Extension, Hugr, HugrView, Node, hugr::IdentList, type_row};
391391
use itertools::Itertools;
392392
use rstest::rstest;
393393

394394
use crate::replace_types::handlers::linearize_value_array;
395-
use crate::replace_types::{LinearizeError, NodeTemplate, ReplaceTypesError};
395+
use crate::replace_types::{
396+
LinearizeError, NodeTemplate, ReplaceTypesError, ReplacementOptions,
397+
};
396398
use crate::{ComposablePass, ReplaceTypes};
397399

398400
const LIN_T: &str = "Lin";
@@ -855,4 +857,64 @@ mod test {
855857
panic!("Expected error");
856858
}
857859
}
860+
861+
#[test]
862+
fn use_in_op_callback() {
863+
let (e, mut lowerer) = ext_lowerer();
864+
let drop_ext = Extension::new_arc(
865+
IdentList::new_unchecked("DropExt"),
866+
Version::new(0, 0, 0),
867+
|e, w| {
868+
e.add_op(
869+
"drop".into(),
870+
String::new(),
871+
PolyFuncTypeRV::new(
872+
[TypeBound::Linear.into()], // It won't *lower* for any type tho!
873+
Signature::new(Type::new_var_use(0, TypeBound::Linear), vec![]),
874+
),
875+
w,
876+
)
877+
.unwrap();
878+
},
879+
);
880+
let drop_op = drop_ext.get_op("drop").unwrap();
881+
lowerer.replace_parametrized_op_with(
882+
drop_op,
883+
|args| {
884+
let [TypeArg::Runtime(ty)] = args else {
885+
panic!("Expected just one type")
886+
};
887+
// The Hugr here is invalid, so we have to pull it out manually
888+
let mut dfb = DFGBuilder::new(Signature::new(ty.clone(), vec![])).unwrap();
889+
let h = std::mem::take(dfb.hugr_mut());
890+
Some(NodeTemplate::CompoundOp(Box::new(h)))
891+
},
892+
ReplacementOptions::default().with_linearization(true),
893+
);
894+
895+
let build_hugr = |ty: Type| {
896+
let mut dfb = DFGBuilder::new(Signature::new(ty.clone(), vec![])).unwrap();
897+
let [inp] = dfb.input_wires_arr();
898+
let drop_op = drop_ext
899+
.instantiate_extension_op("drop", [ty.into()])
900+
.unwrap();
901+
dfb.add_dataflow_op(drop_op, [inp]).unwrap();
902+
dfb.finish_hugr().unwrap()
903+
};
904+
// We can drop a tuple of 2* lin_t
905+
let lin_t = Type::from(e.get_type(LIN_T).unwrap().instantiate([]).unwrap());
906+
let mut h = build_hugr(Type::new_tuple(vec![lin_t; 2]));
907+
lowerer.run(&mut h).unwrap();
908+
h.validate().unwrap();
909+
let mut exts = h.nodes().filter_map(|n| h.get_optype(n).as_extension_op());
910+
assert_eq!(exts.clone().count(), 2);
911+
assert!(exts.all(|eo| eo.qualified_id() == "TestExt.discard"));
912+
913+
// We cannot drop a qubit
914+
let mut h = build_hugr(qb_t());
915+
assert_eq!(
916+
lowerer.run(&mut h).unwrap_err(),
917+
ReplaceTypesError::LinearizeError(LinearizeError::NeedCopyDiscard(Box::new(qb_t())))
918+
);
919+
}
858920
}

0 commit comments

Comments
 (0)