Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 107 additions & 46 deletions hugr-passes/src/replace_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,9 @@ impl NodeTemplate {
}
}

fn replace(&self, hugr: &mut impl HugrMut<Node = Node>, n: Node) -> Result<(), BuildError> {
fn replace(self, hugr: &mut impl HugrMut<Node = Node>, n: Node) -> Result<(), BuildError> {
assert_eq!(hugr.children(n).count(), 0);
let new_optype = match self.clone() {
let new_optype = match self {
NodeTemplate::SingleOp(op_type) => op_type,
NodeTemplate::CompoundOp(new_h) => {
let new_entrypoint = hugr.insert_hugr(n, *new_h).inserted_entrypoint;
Expand Down Expand Up @@ -171,6 +171,23 @@ fn call<H: HugrView<Node = Node>>(
Ok(Call::try_new(func_sig, type_args)?)
}

/// Options for how the replacement for an op is processed. May be specified by
/// [ReplaceTypes::replace_op_with] and [ReplaceTypes::replace_parametrized_op_with].
/// Otherwise (the default), replacements are inserted as is (without further processing).
#[derive(Clone, Default, PartialEq, Eq)] // More derives might inhibit future extension
pub struct ReplacementOptions {
linearize: bool,
}

impl ReplacementOptions {
/// Specifies that all operations within the replacement should have their
/// output ports linearized.
pub fn with_linearization(mut self, lin: bool) -> Self {
self.linearize = lin;
self
}
}

/// A configuration of what types, ops, and constants should be replaced with what.
/// May be applied to a Hugr via [`Self::run`].
///
Expand Down Expand Up @@ -203,8 +220,14 @@ pub struct ReplaceTypes {
type_map: HashMap<CustomType, Type>,
param_types: HashMap<ParametricType, Arc<dyn Fn(&[TypeArg]) -> Option<Type>>>,
linearize: DelegatingLinearizer,
op_map: HashMap<OpHashWrapper, NodeTemplate>,
param_ops: HashMap<ParametricOp, Arc<dyn Fn(&[TypeArg]) -> Option<NodeTemplate>>>,
op_map: HashMap<OpHashWrapper, (NodeTemplate, ReplacementOptions)>,
param_ops: HashMap<
ParametricOp,
(
Arc<dyn Fn(&[TypeArg]) -> Option<NodeTemplate>>,
ReplacementOptions,
),
>,
consts: HashMap<
CustomType,
Arc<dyn Fn(&OpaqueValue, &ReplaceTypes) -> Result<Value, ReplaceTypesError>>,
Expand Down Expand Up @@ -337,13 +360,36 @@ impl ReplaceTypes {
}

/// Configures this instance to change occurrences of `src` to `dest`.
/// Equivalent to [Self::replace_op_with] with default [ReplacementOptions].
pub fn replace_op(&mut self, src: &ExtensionOp, dest: NodeTemplate) {
self.replace_op_with(src, dest, ReplacementOptions::default())
}

/// Configures this instance to change occurrences of `src` to `dest`.
///
/// Note that if `src` is an instance of a *parametrized* [`OpDef`], this takes
/// precedence over [`Self::replace_parametrized_op`] where the `src`s overlap. Thus,
/// this should only be used on already-*[monomorphize](super::monomorphize())d*
/// Hugrs, as substitution (parametric polymorphism) happening later will not respect
/// this replacement.
pub fn replace_op(&mut self, src: &ExtensionOp, dest: NodeTemplate) {
self.op_map.insert(OpHashWrapper::from(src), dest);
pub fn replace_op_with(
&mut self,
src: &ExtensionOp,
dest: NodeTemplate,
opts: ReplacementOptions,
) {
self.op_map.insert(OpHashWrapper::from(src), (dest, opts));
}

/// Configures this instance to change occurrences of a parametrized op `src`
/// via a callback that builds the replacement type given the [`TypeArg`]s.
/// Equivalent to [Self::replace_parametrized_op_with] with default [ReplacementOptions].
pub fn replace_parametrized_op(
&mut self,
src: &OpDef,
dest_fn: impl Fn(&[TypeArg]) -> Option<NodeTemplate> + 'static,
) {
self.replace_parametrized_op_with(src, dest_fn, ReplacementOptions::default())
}

/// Configures this instance to change occurrences of a parametrized op `src`
Expand All @@ -352,12 +398,13 @@ impl ReplaceTypes {
/// fit the bounds of the original op).
///
/// If the Callback returns None, the new typeargs will be applied to the original op.
pub fn replace_parametrized_op(
pub fn replace_parametrized_op_with(
&mut self,
src: &OpDef,
dest_fn: impl Fn(&[TypeArg]) -> Option<NodeTemplate> + 'static,
opts: ReplacementOptions,
) {
self.param_ops.insert(src.into(), Arc::new(dest_fn));
self.param_ops.insert(src.into(), (Arc::new(dest_fn), opts));
}

/// Configures this instance to change [Const]s of type `src_ty`, using
Expand Down Expand Up @@ -447,34 +494,40 @@ impl ReplaceTypes {
| rest.transform(self)?),

OpType::Const(Const { value, .. }) => self.change_value(value),
OpType::ExtensionOp(ext_op) => Ok(
// Copy/discard insertion done by caller
if let Some(replacement) = self.op_map.get(&OpHashWrapper::from(&*ext_op)) {
OpType::ExtensionOp(ext_op) => Ok({
let def = ext_op.def_arc();
let mut changed = false;
let replacement = match self.op_map.get(&OpHashWrapper::from(&*ext_op)) {
r @ Some(_) => r.cloned(),
None => {
let mut args = ext_op.args().to_vec();
changed = args.transform(self)?;
let r2 = self
.param_ops
.get(&def.as_ref().into())
.and_then(|(rep_fn, opts)| rep_fn(&args).map(|nt| (nt, opts.clone())));
if r2.is_none() && changed {
*ext_op = ExtensionOp::new(def.clone(), args)?;
}
r2
}
};
if let Some((replacement, opts)) = replacement {
replacement
.replace(hugr, n)
.map_err(|e| ReplaceTypesError::AddTemplateError(n, Box::new(e)))?;
true
} else {
let def = ext_op.def_arc();
let mut args = ext_op.args().to_vec();
let ch = args.transform(self)?;
if let Some(replacement) = self
.param_ops
.get(&def.as_ref().into())
.and_then(|rep_fn| rep_fn(&args))
{
replacement
.replace(hugr, n)
.map_err(|e| ReplaceTypesError::AddTemplateError(n, Box::new(e)))?;
true
} else {
if ch {
*ext_op = ExtensionOp::new(def.clone(), args)?;
if opts.linearize {
for d in hugr.descendants(n).collect::<Vec<_>>() {
if d != n {
self.linearize_outputs(hugr, d)?;
}
}
ch
}
},
),
true
} else {
changed
}
}),

OpType::OpaqueOp(_) => panic!("OpaqueOp should not be in a Hugr"),

Expand Down Expand Up @@ -518,6 +571,27 @@ impl ReplaceTypes {
Value::Function { hugr } => self.run(&mut **hugr),
}
}

fn linearize_outputs<H: HugrMut<Node = Node>>(
&self,
hugr: &mut H,
n: H::Node,
) -> Result<(), LinearizeError> {
if let Some(new_sig) = hugr.get_optype(n).dataflow_signature() {
let new_sig = new_sig.into_owned();
for outp in new_sig.output_ports() {
if !new_sig.out_port_type(outp).unwrap().copyable() {
let targets = hugr.linked_inputs(n, outp).collect::<Vec<_>>();
if targets.len() != 1 {
hugr.disconnect(n, outp);
let src = Wire::new(n, outp);
self.linearize.insert_copy_discard(hugr, src, &targets)?;
}
}
}
}
Ok(())
}
}

impl<H: HugrMut<Node = Node>> ComposablePass<H> for ReplaceTypes {
Expand All @@ -528,21 +602,8 @@ impl<H: HugrMut<Node = Node>> ComposablePass<H> for ReplaceTypes {
let mut changed = false;
for n in hugr.entry_descendants().collect::<Vec<_>>() {
changed |= self.change_node(hugr, n)?;
let new_dfsig = hugr.get_optype(n).dataflow_signature();
if let Some(new_sig) = new_dfsig
.filter(|_| changed && n != hugr.entrypoint())
.map(Cow::into_owned)
{
for outp in new_sig.output_ports() {
if !new_sig.out_port_type(outp).unwrap().copyable() {
let targets = hugr.linked_inputs(n, outp).collect::<Vec<_>>();
if targets.len() != 1 {
hugr.disconnect(n, outp);
let src = Wire::new(n, outp);
self.linearize.insert_copy_discard(hugr, src, &targets)?;
}
}
}
if n != hugr.entrypoint() && changed {
self.linearize_outputs(hugr, n)?;
}
}
Ok(changed)
Expand Down
82 changes: 72 additions & 10 deletions hugr-passes/src/replace_types/linearize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,10 @@ pub trait Linearizer {
src: Wire,
targets: &[(Node, IncomingPort)],
) -> Result<(), LinearizeError> {
let sig = hugr.signature(src.node()).unwrap();
let typ = sig.port_type(src.source()).unwrap();
let (tgt_node, tgt_inport) = if targets.len() == 1 {
*targets.first().unwrap()
} else {
// Fail fast if the edges are nonlocal. (TODO transform to local edges!)
// Fail fast if the edges are nonlocal.
let src_parent = hugr
.get_parent(src.node())
.expect("Root node cannot have out edges");
Expand All @@ -74,7 +72,8 @@ pub trait Linearizer {
tgt_parent,
});
}
let typ = typ.clone(); // Stop borrowing hugr in order to add_hugr to it
let sig = hugr.signature(src.node()).unwrap();
let typ = sig.port_type(src.source()).unwrap().clone();
let copy_discard_op = self
.copy_discard_op(&typ, targets.len())?
.add_hugr(hugr, src_parent)
Expand Down Expand Up @@ -148,7 +147,8 @@ pub enum LinearizeError {
sig: Option<Box<Signature>>,
},
#[error(
"Cannot add nonlocal edge for linear type from {src} (with parent {src_parent}) to {tgt} (with parent {tgt_parent})"
"Cannot add nonlocal edge for linear type from {src} (with parent {src_parent}) to {tgt} (with parent {tgt_parent}).
Try using LocalizeEdges pass first."
)]
NoLinearNonLocalEdges {
src: Node,
Expand Down Expand Up @@ -367,11 +367,11 @@ mod test {
use std::sync::Arc;

use hugr_core::builder::{
BuildError, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder,
inout_sig,
BuildError, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer,
HugrBuilder, inout_sig,
};

use hugr_core::extension::prelude::{option_type, usize_t};
use hugr_core::extension::prelude::{option_type, qb_t, usize_t};
use hugr_core::extension::simple_op::MakeExtensionOp;
use hugr_core::extension::{
CustomSignatureFunc, OpDef, SignatureError, SignatureFunc, TypeDefBound, Version,
Expand All @@ -385,14 +385,16 @@ mod test {
};
use hugr_core::types::type_param::TypeParam;
use hugr_core::types::{
FuncValueType, PolyFuncTypeRV, Signature, Type, TypeArg, TypeEnum, TypeRow,
FuncValueType, PolyFuncTypeRV, Signature, Type, TypeArg, TypeBound, TypeEnum, TypeRow,
};
use hugr_core::{Extension, Hugr, HugrView, Node, hugr::IdentList, type_row};
use itertools::Itertools;
use rstest::rstest;

use crate::replace_types::handlers::linearize_value_array;
use crate::replace_types::{LinearizeError, NodeTemplate, ReplaceTypesError};
use crate::replace_types::{
LinearizeError, NodeTemplate, ReplaceTypesError, ReplacementOptions,
};
use crate::{ComposablePass, ReplaceTypes};

const LIN_T: &str = "Lin";
Expand Down Expand Up @@ -855,4 +857,64 @@ mod test {
panic!("Expected error");
}
}

#[test]
fn use_in_op_callback() {
let (e, mut lowerer) = ext_lowerer();
let drop_ext = Extension::new_arc(
IdentList::new_unchecked("DropExt"),
Version::new(0, 0, 0),
|e, w| {
e.add_op(
"drop".into(),
String::new(),
PolyFuncTypeRV::new(
[TypeBound::Linear.into()], // It won't *lower* for any type tho!
Signature::new(Type::new_var_use(0, TypeBound::Linear), vec![]),
),
w,
)
.unwrap();
},
);
let drop_op = drop_ext.get_op("drop").unwrap();
lowerer.replace_parametrized_op_with(
drop_op,
|args| {
let [TypeArg::Runtime(ty)] = args else {
panic!("Expected just one type")
};
// The Hugr here is invalid, so we have to pull it out manually
let mut dfb = DFGBuilder::new(Signature::new(ty.clone(), vec![])).unwrap();
let h = std::mem::take(dfb.hugr_mut());
Comment on lines +887 to +889
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah this feels a bit odd, however, I see your point about the explicit version becoming quite cumbersome if we want to compose stuff. Overall, the approach in this PR feels like the better way to go 👍

Some(NodeTemplate::CompoundOp(Box::new(h)))
},
ReplacementOptions::default().with_linearization(true),
);

let build_hugr = |ty: Type| {
let mut dfb = DFGBuilder::new(Signature::new(ty.clone(), vec![])).unwrap();
let [inp] = dfb.input_wires_arr();
let drop_op = drop_ext
.instantiate_extension_op("drop", [ty.into()])
.unwrap();
dfb.add_dataflow_op(drop_op, [inp]).unwrap();
dfb.finish_hugr().unwrap()
};
// We can drop a tuple of 2* lin_t
let lin_t = Type::from(e.get_type(LIN_T).unwrap().instantiate([]).unwrap());
let mut h = build_hugr(Type::new_tuple(vec![lin_t; 2]));
lowerer.run(&mut h).unwrap();
h.validate().unwrap();
let mut exts = h.nodes().filter_map(|n| h.get_optype(n).as_extension_op());
assert_eq!(exts.clone().count(), 2);
assert!(exts.all(|eo| eo.qualified_id() == "TestExt.discard"));

// We cannot drop a qubit
let mut h = build_hugr(qb_t());
assert_eq!(
lowerer.run(&mut h).unwrap_err(),
ReplaceTypesError::LinearizeError(LinearizeError::NeedCopyDiscard(Box::new(qb_t())))
);
}
}
Loading