Skip to content

Commit

Permalink
feat!: Resolve OpaqueOps and CustomType extensions (#1735)
Browse files Browse the repository at this point in the history
When we do a serialization roundtrip on a hugr, its custom optypes are
downgraded to `OpaqueOp`s and its `CustomType`s lose the weak link to
their extension.

This PR moves the pre-existing `OpaqueOp` resolution into a new
`crate::extension::resolution` module and expands it to also update the
custom type pointers. This can be run via the new
`Hugr::resolve_extension_defs` method (currently called by
`update_validate`).
In addition, we accumulate the exact list of extensions required to
define the hugr (this will be necessary for #1613).

Note that this will probably be no longer necessary after we stabilize
`hugr-model`, as we manually build the hugr when importing the model so
the extensions should already be present at that point.

In contrast to extension inference, this detects extensions needed to
_define_ a hugr (vs runtime requirements). The inference one will be
renamed as per #1734.

A big chunk of this PR is fixing tests that used `finish_prelude_hugr`
even though they required more extensions -.-'
Look at the first commit for the actual changes.

BREAKING CHANGE: Removed `resolve_opaque_op` and
`resolve_extension_ops`. Use `Hugr::resolve_extension_defs` instead.
  • Loading branch information
aborgna-q authored Dec 3, 2024
1 parent 4586db3 commit f5af79c
Show file tree
Hide file tree
Showing 35 changed files with 1,041 additions and 661 deletions.
4 changes: 2 additions & 2 deletions hugr-cli/src/extensions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ impl ExtArgs {
pub fn run_dump(&self, registry: &ExtensionRegistry) {
let base_dir = &self.outdir;

for (name, ext) in registry.iter() {
for ext in registry.iter() {
let mut path = base_dir.clone();
for part in name.split('.') {
for part in ext.name().split('.') {
path.push(part);
}
path.set_extension("json");
Expand Down
2 changes: 1 addition & 1 deletion hugr-cli/tests/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ fn test_no_std_fail(float_hugr_string: String, mut val_cmd: Command) {
val_cmd
.assert()
.failure()
.stderr(contains(" Extension 'arithmetic.float.types' not found"));
.stderr(contains(" requires extension arithmetic.float.types"));
}

#[rstest]
Expand Down
4 changes: 2 additions & 2 deletions hugr-core/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,8 @@ pub(crate) mod test {
use crate::extension::prelude::{bool_t, usize_t};
use crate::hugr::{views::HugrView, HugrMut};
use crate::ops;
use crate::std_extensions::arithmetic::float_ops::FLOAT_OPS_REGISTRY;
use crate::types::{PolyFuncType, Signature};
use crate::utils::test_quantum_extension;
use crate::Hugr;

use super::handle::BuildHandle;
Expand All @@ -269,7 +269,7 @@ pub(crate) mod test {

f(f_builder)?;

Ok(module_builder.finish_hugr(&FLOAT_OPS_REGISTRY)?)
Ok(module_builder.finish_hugr(&test_quantum_extension::REG)?)
}

#[fixture]
Expand Down
60 changes: 36 additions & 24 deletions hugr-core/src/builder/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,7 @@ mod test {
use super::*;
use cool_asserts::assert_matches;

use crate::builder::{Container, HugrBuilder, ModuleBuilder};
use crate::extension::prelude::{qb_t, usize_t};
use crate::extension::{ExtensionId, ExtensionSet, PRELUDE_REGISTRY};
use crate::std_extensions::arithmetic::float_types::{self, ConstF64};
Expand Down Expand Up @@ -308,33 +309,44 @@ mod test {
.instantiate_extension_op("MyOp", [], &PRELUDE_REGISTRY)
.unwrap();

let build_res = build_main(
Signature::new(
vec![qb_t(), qb_t(), usize_t()],
vec![qb_t(), qb_t(), bool_t()],
let mut module_builder = ModuleBuilder::new();
let mut f_build = module_builder
.define_function(
"main",
Signature::new(
vec![qb_t(), qb_t(), usize_t()],
vec![qb_t(), qb_t(), bool_t()],
)
.with_extension_delta(ExtensionSet::from_iter([
test_quantum_extension::EXTENSION_ID,
my_ext_name,
])),
)
.with_extension_delta(ExtensionSet::from_iter([
test_quantum_extension::EXTENSION_ID,
my_ext_name,
]))
.into(),
|mut f_build| {
let [q0, q1, angle]: [Wire; 3] = f_build.input_wires_arr();
.unwrap();

let mut linear = f_build.as_circuit([q0, q1]);
let [q0, q1, angle]: [Wire; 3] = f_build.input_wires_arr();

let measure_out = linear
.append(cx_gate(), [0, 1])?
.append_and_consume(
my_custom_op,
[CircuitUnit::Linear(0), CircuitUnit::Wire(angle)],
)?
.append_with_outputs(measure(), [0])?;

let out_qbs = linear.finish();
f_build.finish_with_outputs(out_qbs.into_iter().chain(measure_out))
},
);
let mut linear = f_build.as_circuit([q0, q1]);

let measure_out = linear
.append(cx_gate(), [0, 1])
.unwrap()
.append_and_consume(
my_custom_op,
[CircuitUnit::Linear(0), CircuitUnit::Wire(angle)],
)
.unwrap()
.append_with_outputs(measure(), [0])
.unwrap();

let out_qbs = linear.finish();
f_build
.finish_with_outputs(out_qbs.into_iter().chain(measure_out))
.unwrap();

let mut registry = test_quantum_extension::REG.clone();
registry.register(my_ext).unwrap();
let build_res = module_builder.finish_hugr(&registry);

assert_matches!(build_res, Ok(_));
}
Expand Down
12 changes: 7 additions & 5 deletions hugr-core/src/builder/dataflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ pub(crate) mod test {
use crate::std_extensions::logic::test::and_op;
use crate::types::type_param::TypeParam;
use crate::types::{EdgeKind, FuncValueType, RowVariable, Signature, Type, TypeBound, TypeRV};
use crate::utils::test_quantum_extension::h_gate;
use crate::utils::test_quantum_extension::{self, h_gate};
use crate::{builder::test::n_identity, type_row, Wire};

use super::super::test::simple_dfg_hugr;
Expand All @@ -342,8 +342,10 @@ pub(crate) mod test {
let inner_builder = outer_builder.dfg_builder_endo([(usize_t(), int)])?;
let inner_id = n_identity(inner_builder)?;

outer_builder
.finish_prelude_hugr_with_outputs(inner_id.outputs().chain(q_out.outputs()))
outer_builder.finish_hugr_with_outputs(
inner_id.outputs().chain(q_out.outputs()),
&test_quantum_extension::REG,
)
};

assert_eq!(build_result.err(), None);
Expand All @@ -361,7 +363,7 @@ pub(crate) mod test {

f(&mut builder)?;

builder.finish_hugr(&EMPTY_REG)
builder.finish_hugr(&test_quantum_extension::REG)
};
assert_matches!(build_result, Ok(_), "Failed on example: {}", msg);

Expand Down Expand Up @@ -583,7 +585,7 @@ pub(crate) mod test {

let add_c = add_c.finish_with_outputs(wires)?;
let [w] = add_c.outputs_arr();
parent.finish_hugr_with_outputs([w], &EMPTY_REG)?;
parent.finish_hugr_with_outputs([w], &test_quantum_extension::REG)?;

Ok(())
}
Expand Down
78 changes: 56 additions & 22 deletions hugr-core/src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,24 +21,26 @@ use crate::types::RowVariable;
use crate::types::{check_typevar_decl, CustomType, Substitution, TypeBound, TypeName};
use crate::types::{Signature, TypeNameRef};

mod const_fold;
mod op_def;
pub mod prelude;
pub mod resolution;
pub mod simple_op;
mod type_def;

pub use const_fold::{fold_out_row, ConstFold, ConstFoldResult, Folder};
pub use op_def::{
CustomSignatureFunc, CustomValidator, LowerFunc, OpDef, SignatureFromArgs, SignatureFunc,
ValidateJustArgs, ValidateTypeArgs,
};
mod type_def;
pub use type_def::{TypeDef, TypeDefBound};
mod const_fold;
pub mod prelude;
pub mod simple_op;
pub use const_fold::{fold_out_row, ConstFold, ConstFoldResult, Folder};
pub use prelude::{PRELUDE, PRELUDE_REGISTRY};
pub use type_def::{TypeDef, TypeDefBound};

#[cfg(feature = "declarative")]
pub mod declarative;

/// Extension Registries store extensions to be looked up e.g. during validation.
#[derive(Clone, Debug, PartialEq)]
#[derive(Clone, Debug, Default, PartialEq)]
pub struct ExtensionRegistry(BTreeMap<ExtensionId, Arc<Extension>>);

impl ExtensionRegistry {
Expand Down Expand Up @@ -96,14 +98,15 @@ impl ExtensionRegistry {
}
}

/// Registers a new extension to the registry, keeping most up to date if extension exists.
/// Registers a new extension to the registry, keeping the one most up to
/// date if the extension already exists.
///
/// If extension IDs match, the extension with the higher version is kept.
/// If versions match, the original extension is kept.
/// Returns a reference to the registered extension if successful.
/// If versions match, the original extension is kept. Returns a reference
/// to the registered extension if successful.
///
/// Takes an Arc to the extension. To avoid cloning Arcs unless necessary, see
/// [`ExtensionRegistry::register_updated_ref`].
/// Takes an Arc to the extension. To avoid cloning Arcs unless necessary,
/// see [`ExtensionRegistry::register_updated_ref`].
pub fn register_updated(&mut self, extension: impl Into<Arc<Extension>>) {
let extension = extension.into();
match self.0.entry(extension.name().clone()) {
Expand All @@ -118,8 +121,8 @@ impl ExtensionRegistry {
}
}

/// Registers a new extension to the registry, keeping most up to date if
/// extension exists.
/// Registers a new extension to the registry, keeping the one most up to
/// date if the extension already exists.
///
/// If extension IDs match, the extension with the higher version is kept.
/// If versions match, the original extension is kept. Returns a reference
Expand Down Expand Up @@ -151,8 +154,8 @@ impl ExtensionRegistry {
}

/// Returns an iterator over the extensions in the registry.
pub fn iter(&self) -> impl Iterator<Item = (&ExtensionId, &Arc<Extension>)> {
self.0.iter()
pub fn iter(&self) -> <&Self as IntoIterator>::IntoIter {
self.0.values()
}

/// Returns an iterator over the extensions ids in the registry.
Expand All @@ -167,12 +170,38 @@ impl ExtensionRegistry {
}

impl IntoIterator for ExtensionRegistry {
type Item = (ExtensionId, Arc<Extension>);
type Item = Arc<Extension>;

type IntoIter = <BTreeMap<ExtensionId, Arc<Extension>> as IntoIterator>::IntoIter;
type IntoIter = std::collections::btree_map::IntoValues<ExtensionId, Arc<Extension>>;

fn into_iter(self) -> Self::IntoIter {
self.0.into_iter()
self.0.into_values()
}
}

impl<'a> IntoIterator for &'a ExtensionRegistry {
type Item = &'a Arc<Extension>;

type IntoIter = std::collections::btree_map::Values<'a, ExtensionId, Arc<Extension>>;

fn into_iter(self) -> Self::IntoIter {
self.0.values()
}
}

impl<'a> Extend<&'a Arc<Extension>> for ExtensionRegistry {
fn extend<T: IntoIterator<Item = &'a Arc<Extension>>>(&mut self, iter: T) {
for ext in iter {
self.register_updated_ref(ext);
}
}
}

impl Extend<Arc<Extension>> for ExtensionRegistry {
fn extend<T: IntoIterator<Item = Arc<Extension>>>(&mut self, iter: T) {
for ext in iter {
self.register_updated(ext);
}
}
}

Expand All @@ -197,8 +226,13 @@ pub enum SignatureError {
#[error("Invalid type arguments for operation")]
InvalidTypeArgs,
/// The Extension Registry did not contain an Extension referenced by the Signature
#[error("Extension '{0}' not found")]
ExtensionNotFound(ExtensionId),
#[error("Extension '{missing}' not found. Available extensions: {}",
available.iter().map(|e| e.to_string()).collect::<Vec<_>>().join(", ")
)]
ExtensionNotFound {
missing: ExtensionId,
available: Vec<ExtensionId>,
},
/// The Extension was found in the registry, but did not contain the Type(Def) referenced in the Signature
#[error("Extension '{exn}' did not contain expected TypeDef '{typ}'")]
ExtensionTypeNotFound { exn: ExtensionId, typ: TypeName },
Expand Down Expand Up @@ -537,7 +571,7 @@ impl Extension {
ExtensionOp::new(op_def.clone(), args, ext_reg)
}

// Validates against a registry, which we can assume includes this extension itself.
/// Validates against a registry, which we can assume includes this extension itself.
// (TODO deal with the registry itself containing invalid extensions!)
fn validate(&self, all_exts: &ExtensionRegistry) -> Result<(), SignatureError> {
// We should validate TypeParams of TypeDefs too - https://github.com/CQCL/hugr/issues/624
Expand Down
18 changes: 6 additions & 12 deletions hugr-core/src/extension/declarative.rs
Original file line number Diff line number Diff line change
Expand Up @@ -354,12 +354,9 @@ extensions:
let new_exts = new_extensions(&reg, dependencies).collect_vec();

assert_eq!(new_exts.len(), num_declarations);
assert_eq!(new_exts.iter().flat_map(|e| e.types()).count(), num_types);
assert_eq!(
new_exts.iter().flat_map(|(_, e)| e.types()).count(),
num_types
);
assert_eq!(
new_exts.iter().flat_map(|(_, e)| e.operations()).count(),
new_exts.iter().flat_map(|e| e.operations()).count(),
num_operations
);
Ok(())
Expand All @@ -381,12 +378,9 @@ extensions:
let new_exts = new_extensions(&reg, dependencies).collect_vec();

assert_eq!(new_exts.len(), num_declarations);
assert_eq!(new_exts.iter().flat_map(|e| e.types()).count(), num_types);
assert_eq!(
new_exts.iter().flat_map(|(_, e)| e.types()).count(),
num_types
);
assert_eq!(
new_exts.iter().flat_map(|(_, e)| e.operations()).count(),
new_exts.iter().flat_map(|e| e.operations()).count(),
num_operations
);
Ok(())
Expand All @@ -413,8 +407,8 @@ extensions:
fn new_extensions<'a>(
reg: &'a ExtensionRegistry,
dependencies: &'a ExtensionRegistry,
) -> impl Iterator<Item = (&'a ExtensionId, &'a Arc<Extension>)> {
) -> impl Iterator<Item = &'a Arc<Extension>> {
reg.iter()
.filter(move |(id, _)| !dependencies.contains(id) && *id != &PRELUDE_ID)
.filter(move |ext| !dependencies.contains(ext.name()) && ext.name() != &PRELUDE_ID)
}
}
4 changes: 3 additions & 1 deletion hugr-core/src/extension/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -960,6 +960,7 @@ mod test {
use crate::builder::inout_sig;
use crate::std_extensions::arithmetic::float_ops::FLOAT_OPS_REGISTRY;
use crate::std_extensions::arithmetic::float_types::{float64_type, ConstF64};
use crate::utils::test_quantum_extension;
use crate::{
builder::{endo_sig, DFGBuilder, Dataflow, DataflowHugr},
utils::test_quantum_extension::cx_gate,
Expand Down Expand Up @@ -1150,7 +1151,8 @@ mod test {
.add_dataflow_op(panic_op, [err, q0, q1])
.unwrap()
.outputs_arr();
b.finish_prelude_hugr_with_outputs([q0, q1]).unwrap();
b.finish_hugr_with_outputs([q0, q1], &test_quantum_extension::REG)
.unwrap();
}

#[test]
Expand Down
Loading

0 comments on commit f5af79c

Please sign in to comment.