Skip to content

Commit

Permalink
feat!: OpDefs and TypeDefs keep a reference to their extension (#1719)
Browse files Browse the repository at this point in the history
This change was extracted from the work towards #1613.
Now `OpDef`s and `TypeDef`s keep a `Weak` reference to their extension's
`Arc`.

This way we will be able to automatically set the extension requirements
when adding operations, so we can get rid of `update_validate` and the
explicit registries when building hugrs.

To implement this, the building interface for `Extension`s is sightly
modified.
Once an `Arc` is built it cannot be modified without doing internal
mutation.
But we need the `Arc`'s weak reference to define the ops and types.
Thankfully, we can use `Arc::new_cyclic` which provides us with a `Weak`
ref at build time so we are able to define things as needed.

This is wrapped in a new `Extension::new_arc` method, so the user
doesn't need to think about that.

BREAKING CHANGE: Renamed `OpDef::extension` and `TypeDef::extension` to
`extension_id`. `extension` now returns weak references to the
`Extension` defining them.
BREAKING CHANGE: `Extension::with_reqs` moved to `set_reqs`, which takes
`&mut self` instead of `self`.
BREAKING CHANGE: `Extension::add_type` and `Extension::add_op` now take
an extra parameter. See docs for example usage.
BREAKING CHANGE: `ExtensionRegistry::register_updated` and
`register_updated_ref` are no longer fallible.
  • Loading branch information
aborgna-q authored Nov 27, 2024
1 parent 344a7e4 commit 517fd3d
Show file tree
Hide file tree
Showing 30 changed files with 837 additions and 568 deletions.
4 changes: 1 addition & 3 deletions hugr-cli/src/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
use clap::Parser;
use clap_verbosity_flag::Level;
use hugr::package::PackageValidationError;
use hugr::{extension::ExtensionRegistry, Extension, Hugr};

use crate::{CliError, HugrArgs};
Expand Down Expand Up @@ -64,8 +63,7 @@ impl HugrArgs {
for ext in &self.extensions {
let f = std::fs::File::open(ext)?;
let ext: Extension = serde_json::from_reader(f)?;
reg.register_updated(ext)
.map_err(PackageValidationError::Extension)?;
reg.register_updated(ext);
}

package.update_validate(&mut reg)?;
Expand Down
16 changes: 13 additions & 3 deletions hugr-core/src/builder/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ mod test {
use super::*;
use cool_asserts::assert_matches;

use crate::extension::{ExtensionId, ExtensionSet};
use crate::extension::{ExtensionId, ExtensionSet, PRELUDE_REGISTRY};
use crate::std_extensions::arithmetic::float_types::{self, ConstF64};
use crate::utils::test_quantum_extension::{
self, cx_gate, h_gate, measure, q_alloc, q_discard, rz_f64,
Expand Down Expand Up @@ -298,8 +298,18 @@ mod test {
#[test]
fn with_nonlinear_and_outputs() {
let my_ext_name: ExtensionId = "MyExt".try_into().unwrap();
let mut my_ext = Extension::new_test(my_ext_name.clone());
let my_custom_op = my_ext.simple_ext_op("MyOp", Signature::new(vec![QB, NAT], vec![QB]));
let my_ext = Extension::new_test_arc(my_ext_name.clone(), |ext, extension_ref| {
ext.add_op(
"MyOp".into(),
"".to_string(),
Signature::new(vec![QB, NAT], vec![QB]),
extension_ref,
)
.unwrap();
});
let my_custom_op = my_ext
.instantiate_extension_op("MyOp", [], &PRELUDE_REGISTRY)
.unwrap();

let build_res = build_main(
Signature::new(type_row![QB, QB, NAT], type_row![QB, QB, BOOL_T])
Expand Down
6 changes: 3 additions & 3 deletions hugr-core/src/export.rs
Original file line number Diff line number Diff line change
Expand Up @@ -443,10 +443,10 @@ impl<'a> Context<'a> {

let poly_func_type = match opdef.signature_func() {
SignatureFunc::PolyFuncType(poly_func_type) => poly_func_type,
_ => return self.make_named_global_ref(opdef.extension(), opdef.name()),
_ => return self.make_named_global_ref(opdef.extension_id(), opdef.name()),
};

let key = (opdef.extension().clone(), opdef.name().clone());
let key = (opdef.extension_id().clone(), opdef.name().clone());
let entry = self.decl_operations.entry(key);

let node = match entry {
Expand All @@ -467,7 +467,7 @@ impl<'a> Context<'a> {
};

let decl = self.with_local_scope(node, |this| {
let name = this.make_qualified_name(opdef.extension(), opdef.name());
let name = this.make_qualified_name(opdef.extension_id(), opdef.name());
let (params, constraints, r#type) = this.export_poly_func_type(poly_func_type);
let decl = this.bump.alloc(model::OperationDecl {
name,
Expand Down
153 changes: 121 additions & 32 deletions hugr-core/src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ pub use semver::Version;
use std::collections::btree_map;
use std::collections::{BTreeMap, BTreeSet};
use std::fmt::{Debug, Display, Formatter};
use std::sync::Arc;
use std::mem;
use std::sync::{Arc, Weak};

use thiserror::Error;

Expand Down Expand Up @@ -103,10 +104,7 @@ impl ExtensionRegistry {
///
/// Takes an Arc to the extension. To avoid cloning Arcs unless necessary, see
/// [`ExtensionRegistry::register_updated_ref`].
pub fn register_updated(
&mut self,
extension: impl Into<Arc<Extension>>,
) -> Result<(), ExtensionRegistryError> {
pub fn register_updated(&mut self, extension: impl Into<Arc<Extension>>) {
let extension = extension.into();
match self.0.entry(extension.name().clone()) {
btree_map::Entry::Occupied(mut prev) => {
Expand All @@ -118,7 +116,6 @@ impl ExtensionRegistry {
ve.insert(extension);
}
}
Ok(())
}

/// Registers a new extension to the registry, keeping most up to date if
Expand All @@ -130,10 +127,7 @@ impl ExtensionRegistry {
///
/// Clones the Arc only when required. For no-cloning version see
/// [`ExtensionRegistry::register_updated`].
pub fn register_updated_ref(
&mut self,
extension: &Arc<Extension>,
) -> Result<(), ExtensionRegistryError> {
pub fn register_updated_ref(&mut self, extension: &Arc<Extension>) {
match self.0.entry(extension.name().clone()) {
btree_map::Entry::Occupied(mut prev) => {
if prev.get().version() < extension.version() {
Expand All @@ -144,7 +138,6 @@ impl ExtensionRegistry {
ve.insert(extension.clone());
}
}
Ok(())
}

/// Returns the number of extensions in the registry.
Expand Down Expand Up @@ -335,6 +328,45 @@ impl ExtensionValue {
pub type ExtensionId = IdentList;

/// A extension is a set of capabilities required to execute a graph.
///
/// These are normally defined once and shared across multiple graphs and
/// operations wrapped in [`Arc`]s inside [`ExtensionRegistry`].
///
/// # Example
///
/// The following example demonstrates how to define a new extension with a
/// custom operation and a custom type.
///
/// When using `arc`s, the extension can only be modified at creation time. The
/// defined operations and types keep a [`Weak`] reference to their extension. We provide a
/// helper method [`Extension::new_arc`] to aid their definition.
///
/// ```
/// # use hugr_core::types::Signature;
/// # use hugr_core::extension::{Extension, ExtensionId, Version};
/// # use hugr_core::extension::{TypeDefBound};
/// Extension::new_arc(
/// ExtensionId::new_unchecked("my.extension"),
/// Version::new(0, 1, 0),
/// |ext, extension_ref| {
/// // Add a custom type definition
/// ext.add_type(
/// "MyType".into(),
/// vec![], // No type parameters
/// "Some type".into(),
/// TypeDefBound::any(),
/// extension_ref,
/// );
/// // Add a custom operation
/// ext.add_op(
/// "MyOp".into(),
/// "Some operation".into(),
/// Signature::new_endo(vec![]),
/// extension_ref,
/// );
/// },
/// );
/// ```
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct Extension {
/// Extension version, follows semver.
Expand All @@ -361,6 +393,12 @@ pub struct Extension {

impl Extension {
/// Creates a new extension with the given name.
///
/// In most cases extensions are contained inside an [`Arc`] so that they
/// can be shared across hugr instances and operation definitions.
///
/// See [`Extension::new_arc`] for a more ergonomic way to create boxed
/// extensions.
pub fn new(name: ExtensionId, version: Version) -> Self {
Self {
name,
Expand All @@ -372,14 +410,63 @@ impl Extension {
}
}

/// Extend the requirements of this extension with another set of extensions.
pub fn with_reqs(self, extension_reqs: impl Into<ExtensionSet>) -> Self {
Self {
extension_reqs: self.extension_reqs.union(extension_reqs.into()),
..self
/// Creates a new extension wrapped in an [`Arc`].
///
/// The closure lets us use a weak reference to the arc while the extension
/// is being built. This is necessary for calling [`Extension::add_op`] and
/// [`Extension::add_type`].
pub fn new_arc(
name: ExtensionId,
version: Version,
init: impl FnOnce(&mut Extension, &Weak<Extension>),
) -> Arc<Self> {
Arc::new_cyclic(|extension_ref| {
let mut ext = Self::new(name, version);
init(&mut ext, extension_ref);
ext
})
}

/// Creates a new extension wrapped in an [`Arc`], using a fallible
/// initialization function.
///
/// The closure lets us use a weak reference to the arc while the extension
/// is being built. This is necessary for calling [`Extension::add_op`] and
/// [`Extension::add_type`].
pub fn try_new_arc<E>(
name: ExtensionId,
version: Version,
init: impl FnOnce(&mut Extension, &Weak<Extension>) -> Result<(), E>,
) -> Result<Arc<Self>, E> {
// Annoying hack around not having `Arc::try_new_cyclic` that can return
// a Result.
// https://github.com/rust-lang/rust/issues/75861#issuecomment-980455381
//
// When there is an error, we store it in `error` and return it at the
// end instead of the partially-initialized extension.
let mut error = None;
let ext = Arc::new_cyclic(|extension_ref| {
let mut ext = Self::new(name, version);
match init(&mut ext, extension_ref) {
Ok(_) => ext,
Err(e) => {
error = Some(e);
ext
}
}
});
match error {
Some(e) => Err(e),
None => Ok(ext),
}
}

/// Extend the requirements of this extension with another set of extensions.
pub fn add_requirements(&mut self, extension_reqs: impl Into<ExtensionSet>) {
let reqs = mem::take(&mut self.extension_reqs);
self.extension_reqs = reqs.union(extension_reqs.into());
}

/// Allows read-only access to the operations in this Extension
pub fn get_op(&self, name: &OpNameRef) -> Option<&Arc<op_def::OpDef>> {
self.operations.get(name)
Expand Down Expand Up @@ -634,20 +721,22 @@ pub mod test {

impl Extension {
/// Create a new extension for testing, with a 0 version.
pub(crate) fn new_test(name: ExtensionId) -> Self {
Self::new(name, Version::new(0, 0, 0))
pub(crate) fn new_test_arc(
name: ExtensionId,
init: impl FnOnce(&mut Extension, &Weak<Extension>),
) -> Arc<Self> {
Self::new_arc(name, Version::new(0, 0, 0), init)
}

/// Add a simple OpDef to the extension and return an extension op for it.
/// No description, no type parameters.
pub(crate) fn simple_ext_op(
&mut self,
name: &str,
signature: impl Into<SignatureFunc>,
) -> ExtensionOp {
self.add_op(name.into(), "".to_string(), signature).unwrap();
self.instantiate_extension_op(name, [], &PRELUDE_REGISTRY)
.unwrap()
/// Create a new extension for testing, with a 0 version.
pub(crate) fn try_new_test_arc(
name: ExtensionId,
init: impl FnOnce(
&mut Extension,
&Weak<Extension>,
) -> Result<(), Box<dyn std::error::Error>>,
) -> Result<Arc<Self>, Box<dyn std::error::Error>> {
Self::try_new_arc(name, Version::new(0, 0, 0), init)
}
}

Expand Down Expand Up @@ -680,14 +769,14 @@ pub mod test {
);

// register with update works
reg_ref.register_updated_ref(&ext1_1).unwrap();
reg.register_updated(ext1_1.clone()).unwrap();
reg_ref.register_updated_ref(&ext1_1);
reg.register_updated(ext1_1.clone());
assert_eq!(reg.get("ext1").unwrap().version(), &Version::new(1, 1, 0));
assert_eq!(&reg, &reg_ref);

// register with lower version does not change version
reg_ref.register_updated_ref(&ext1_2).unwrap();
reg.register_updated(ext1_2.clone()).unwrap();
reg_ref.register_updated_ref(&ext1_2);
reg.register_updated(ext1_2.clone());
assert_eq!(reg.get("ext1").unwrap().version(), &Version::new(1, 1, 0));
assert_eq!(&reg, &reg_ref);

Expand Down
32 changes: 19 additions & 13 deletions hugr-core/src/extension/declarative.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ mod types;

use std::fs::File;
use std::path::Path;
use std::sync::Arc;

use crate::extension::prelude::PRELUDE_ID;
use crate::ops::OpName;
Expand Down Expand Up @@ -150,19 +151,24 @@ impl ExtensionDeclaration {
&self,
imports: &ExtensionSet,
ctx: DeclarationContext<'_>,
) -> Result<Extension, ExtensionDeclarationError> {
let mut ext = Extension::new(self.name.clone(), crate::extension::Version::new(0, 0, 0))
.with_reqs(imports.clone());

for t in &self.types {
t.register(&mut ext, ctx)?;
}

for o in &self.operations {
o.register(&mut ext, ctx)?;
}

Ok(ext)
) -> Result<Arc<Extension>, ExtensionDeclarationError> {
Extension::try_new_arc(
self.name.clone(),
// TODO: Get the version as a parameter.
crate::extension::Version::new(0, 0, 0),
|ext, extension_ref| {
for t in &self.types {
t.register(ext, ctx, extension_ref)?;
}

for o in &self.operations {
o.register(ext, ctx, extension_ref)?;
}
ext.add_requirements(imports.clone());

Ok(())
},
)
}
}

Expand Down
12 changes: 11 additions & 1 deletion hugr-core/src/extension/declarative/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
//! [`ExtensionSetDeclaration`]: super::ExtensionSetDeclaration
use std::collections::HashMap;
use std::sync::Weak;

use serde::{Deserialize, Serialize};
use smol_str::SmolStr;
Expand Down Expand Up @@ -55,10 +56,14 @@ pub(super) struct OperationDeclaration {

impl OperationDeclaration {
/// Register this operation in the given extension.
///
/// Requires a [`Weak`] reference to the extension defining the operation.
/// This method is intended to be used inside the closure passed to [`Extension::new_arc`].
pub fn register<'ext>(
&self,
ext: &'ext mut Extension,
ctx: DeclarationContext<'_>,
extension_ref: &Weak<Extension>,
) -> Result<&'ext mut OpDef, ExtensionDeclarationError> {
// We currently only support explicit signatures.
//
Expand Down Expand Up @@ -88,7 +93,12 @@ impl OperationDeclaration {

let signature_func: SignatureFunc = signature.make_signature(ext, ctx, &params)?;

let op_def = ext.add_op(self.name.clone(), self.description.clone(), signature_func)?;
let op_def = ext.add_op(
self.name.clone(),
self.description.clone(),
signature_func,
extension_ref,
)?;

for (k, v) in &self.misc {
op_def.add_misc(k, v.clone());
Expand Down
Loading

0 comments on commit 517fd3d

Please sign in to comment.