From 2d08fc14b33263dfeaf76835c326b8915f01c11f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= <121866228+aborgna-q@users.noreply.github.com> Date: Fri, 13 Dec 2024 16:10:50 +0000 Subject: [PATCH] feat!: Use registries of `Weak`s when doing resolution (#1781) We'll need this to resolve extension references inside an extension that's being created. drive-by: Make "node" optional in resolution errors. BREAKING CHANGE: `::extension::resolve` operations now use `WeakExtensionRegistry`es. --- hugr-core/src/extension/prelude.rs | 4 +- hugr-core/src/extension/resolution.rs | 55 ++++----- hugr-core/src/extension/resolution/ops.rs | 8 +- hugr-core/src/extension/resolution/test.rs | 8 +- hugr-core/src/extension/resolution/types.rs | 33 +++--- .../src/extension/resolution/types_mut.rs | 64 ++++++----- .../src/extension/resolution/weak_registry.rs | 107 ++++++++++++++++++ hugr-core/src/hugr.rs | 9 +- hugr-core/src/ops/constant.rs | 5 +- hugr-core/src/ops/constant/custom.rs | 10 +- .../src/std_extensions/collections/list.rs | 3 +- hugr-core/src/types/signature.rs | 8 +- 12 files changed, 221 insertions(+), 93 deletions(-) create mode 100644 hugr-core/src/extension/resolution/weak_registry.rs diff --git a/hugr-core/src/extension/prelude.rs b/hugr-core/src/extension/prelude.rs index a6c58cfc9..b28814c81 100644 --- a/hugr-core/src/extension/prelude.rs +++ b/hugr-core/src/extension/prelude.rs @@ -25,7 +25,7 @@ use crate::{type_row, Extension}; use strum_macros::{EnumIter, EnumString, IntoStaticStr}; -use super::resolution::{resolve_type_extensions, ExtensionResolutionError}; +use super::resolution::{resolve_type_extensions, ExtensionResolutionError, WeakExtensionRegistry}; use super::ExtensionRegistry; mod unwrap_builder; @@ -507,7 +507,7 @@ impl CustomConst for ConstExternalSymbol { fn update_extensions( &mut self, - extensions: &ExtensionRegistry, + extensions: &WeakExtensionRegistry, ) -> Result<(), ExtensionResolutionError> { resolve_type_extensions(&mut self.typ, extensions) } diff --git a/hugr-core/src/extension/resolution.rs b/hugr-core/src/extension/resolution.rs index 4901502b9..6c4dc683e 100644 --- a/hugr-core/src/extension/resolution.rs +++ b/hugr-core/src/extension/resolution.rs @@ -22,6 +22,9 @@ mod ops; mod types; mod types_mut; +mod weak_registry; + +pub use weak_registry::WeakExtensionRegistry; pub(crate) use ops::{collect_op_extension, resolve_op_extensions}; pub(crate) use types::{collect_op_types_extensions, collect_signature_exts}; @@ -42,49 +45,37 @@ use crate::Node; /// Update all weak Extension pointers inside a type. pub fn resolve_type_extensions( typ: &mut TypeBase, - extensions: &ExtensionRegistry, + extensions: &WeakExtensionRegistry, ) -> Result<(), ExtensionResolutionError> { - // This public export is used for implementing `CustomConst::update_extensions`, so we don't need the full internal API here. - // TODO: Make `node` optional in `ExtensionResolutionError` - let node: Node = portgraph::NodeIndex::new(0).into(); - let mut used_extensions = ExtensionRegistry::default(); - resolve_type_exts(node, typ, extensions, &mut used_extensions) + let mut used_extensions = WeakExtensionRegistry::default(); + resolve_type_exts(None, typ, extensions, &mut used_extensions) } /// Update all weak Extension pointers in a custom type. pub fn resolve_custom_type_extensions( typ: &mut CustomType, - extensions: &ExtensionRegistry, + extensions: &WeakExtensionRegistry, ) -> Result<(), ExtensionResolutionError> { - // This public export is used for implementing `CustomConst::update_extensions`, so we don't need the full internal API here. - // TODO: Make `node` optional in `ExtensionResolutionError` - let node: Node = portgraph::NodeIndex::new(0).into(); - let mut used_extensions = ExtensionRegistry::default(); - resolve_custom_type_exts(node, typ, extensions, &mut used_extensions) + let mut used_extensions = WeakExtensionRegistry::default(); + resolve_custom_type_exts(None, typ, extensions, &mut used_extensions) } /// Update all weak Extension pointers inside a type argument. pub fn resolve_typearg_extensions( arg: &mut TypeArg, - extensions: &ExtensionRegistry, + extensions: &WeakExtensionRegistry, ) -> Result<(), ExtensionResolutionError> { - // This public export is used for implementing `CustomConst::update_extensions`, so we don't need the full internal API here. - // TODO: Make `node` optional in `ExtensionResolutionError` - let node: Node = portgraph::NodeIndex::new(0).into(); - let mut used_extensions = ExtensionRegistry::default(); - resolve_typearg_exts(node, arg, extensions, &mut used_extensions) + let mut used_extensions = WeakExtensionRegistry::default(); + resolve_typearg_exts(None, arg, extensions, &mut used_extensions) } /// Update all weak Extension pointers inside a constant value. pub fn resolve_value_extensions( value: &mut Value, - extensions: &ExtensionRegistry, + extensions: &WeakExtensionRegistry, ) -> Result<(), ExtensionResolutionError> { - // This public export is used for implementing `CustomConst::update_extensions`, so we don't need the full internal API here. - // TODO: Make `node` optional in `ExtensionResolutionError` - let node: Node = portgraph::NodeIndex::new(0).into(); - let mut used_extensions = ExtensionRegistry::default(); - resolve_value_exts(node, value, extensions, &mut used_extensions) + let mut used_extensions = WeakExtensionRegistry::default(); + resolve_value_exts(None, value, extensions, &mut used_extensions) } /// Errors that can occur during extension resolution. @@ -97,12 +88,13 @@ pub enum ExtensionResolutionError { OpaqueOpError(OpaqueOpError), /// An operation requires an extension that is not in the given registry. #[display( - "{op} ({node}) requires extension {missing_extension}, but it could not be found in the extension list used during resolution. The available extensions are: {}", + "{op}{} requires extension {missing_extension}, but it could not be found in the extension list used during resolution. The available extensions are: {}", + node.map(|n| format!(" in {}", n)).unwrap_or_default(), available_extensions.join(", ") )] MissingOpExtension { /// The node that requires the extension. - node: Node, + node: Option, /// The operation that requires the extension. op: OpName, /// The missing extension @@ -111,13 +103,14 @@ pub enum ExtensionResolutionError { available_extensions: Vec, }, #[display( - "Type {ty} in {node} requires extension {missing_extension}, but it could not be found in the extension list used during resolution. The available extensions are: {}", + "Type {ty}{} requires extension {missing_extension}, but it could not be found in the extension list used during resolution. The available extensions are: {}", + node.map(|n| format!(" in {}", n)).unwrap_or_default(), available_extensions.join(", ") )] /// A type references an extension that is not in the given registry. MissingTypeExtension { /// The node that requires the extension. - node: Node, + node: Option, /// The type that requires the extension. ty: TypeName, /// The missing extension @@ -138,7 +131,7 @@ pub enum ExtensionResolutionError { impl ExtensionResolutionError { /// Create a new error for missing operation extensions. pub fn missing_op_extension( - node: Node, + node: Option, op: &OpType, missing_extension: &ExtensionId, extensions: &ExtensionRegistry, @@ -153,10 +146,10 @@ impl ExtensionResolutionError { /// Create a new error for missing type extensions. pub fn missing_type_extension( - node: Node, + node: Option, ty: &TypeName, missing_extension: &ExtensionId, - extensions: &ExtensionRegistry, + extensions: &WeakExtensionRegistry, ) -> Self { Self::MissingTypeExtension { node, diff --git a/hugr-core/src/extension/resolution/ops.rs b/hugr-core/src/extension/resolution/ops.rs index 8e9591688..78e6b3fbc 100644 --- a/hugr-core/src/extension/resolution/ops.rs +++ b/hugr-core/src/extension/resolution/ops.rs @@ -7,7 +7,8 @@ use std::sync::Arc; -use super::{Extension, ExtensionCollectionError, ExtensionRegistry, ExtensionResolutionError}; +use super::{Extension, ExtensionCollectionError, ExtensionResolutionError}; +use crate::extension::ExtensionRegistry; use crate::ops::custom::OpaqueOpError; use crate::ops::{DataflowOpTrait, ExtensionOp, NamedOp, OpType}; use crate::Node; @@ -124,7 +125,10 @@ fn operation_extension<'e>( match extensions.get(ext) { Some(e) => Ok(Some(e)), None => Err(ExtensionResolutionError::missing_op_extension( - node, op, ext, extensions, + Some(node), + op, + ext, + extensions, )), } } diff --git a/hugr-core/src/extension/resolution/test.rs b/hugr-core/src/extension/resolution/test.rs index 0cf9fb0dc..b3113a030 100644 --- a/hugr-core/src/extension/resolution/test.rs +++ b/hugr-core/src/extension/resolution/test.rs @@ -11,6 +11,7 @@ use crate::builder::{ Container, Dataflow, DataflowSubContainer, FunctionBuilder, HugrBuilder, ModuleBuilder, }; use crate::extension::prelude::{bool_t, usize_custom_t, ConstUsize}; +use crate::extension::resolution::WeakExtensionRegistry; use crate::extension::resolution::{ resolve_op_extensions, resolve_op_types_extensions, ExtensionCollectionError, }; @@ -52,9 +53,12 @@ fn resolve_type_extensions(#[case] op: impl Into, #[case] extensions: Ex let dummy_node = portgraph::NodeIndex::new(0).into(); - let mut used_exts = ExtensionRegistry::default(); resolve_op_extensions(dummy_node, &mut deser_op, &extensions).unwrap(); - resolve_op_types_extensions(dummy_node, &mut deser_op, &extensions, &mut used_exts).unwrap(); + + let weak_extensions: WeakExtensionRegistry = (&extensions).into(); + resolve_op_types_extensions(Some(dummy_node), &mut deser_op, &weak_extensions) + .unwrap() + .for_each(|_| ()); let deser_extensions = deser_op.used_extensions().unwrap(); diff --git a/hugr-core/src/extension/resolution/types.rs b/hugr-core/src/extension/resolution/types.rs index 81ad5e1f6..7559ea3fe 100644 --- a/hugr-core/src/extension/resolution/types.rs +++ b/hugr-core/src/extension/resolution/types.rs @@ -6,7 +6,7 @@ //! See [`super::resolve_op_types_extensions`] for a mutating version that //! updates the weak links to point to the correct extensions. -use super::ExtensionCollectionError; +use super::{ExtensionCollectionError, WeakExtensionRegistry}; use crate::extension::{ExtensionRegistry, ExtensionSet}; use crate::ops::{DataflowOpTrait, OpType, Value}; use crate::types::type_row::TypeRowBase; @@ -32,7 +32,7 @@ pub(crate) fn collect_op_types_extensions( node: Option, op: &OpType, ) -> Result { - let mut used = ExtensionRegistry::default(); + let mut used = WeakExtensionRegistry::default(); let mut missing = ExtensionSet::new(); match op { @@ -101,12 +101,15 @@ pub(crate) fn collect_op_types_extensions( OpType::Module(_) | OpType::AliasDecl(_) | OpType::AliasDefn(_) => {} }; - missing - .is_empty() - .then_some(used) - .ok_or(ExtensionCollectionError::dropped_op_extension( + match missing.is_empty() { + true => { + // We know there are no missing extensions, so this should not fail. + Ok(used.try_into().expect("All extensions are valid")) + } + false => Err(ExtensionCollectionError::dropped_op_extension( node, op, missing, - )) + )), + } } /// Collect the Extension pointers in the [`CustomType`]s inside a signature. @@ -119,7 +122,7 @@ pub(crate) fn collect_op_types_extensions( /// `Weak` pointer has been invalidated. pub(crate) fn collect_signature_exts( signature: &FuncTypeBase, - used_extensions: &mut ExtensionRegistry, + used_extensions: &mut WeakExtensionRegistry, missing_extensions: &mut ExtensionSet, ) { // Note that we do not include the signature's `runtime_reqs` here, as those refer @@ -138,7 +141,7 @@ pub(crate) fn collect_signature_exts( /// `Weak` pointer has been invalidated. fn collect_type_row_exts( row: &TypeRowBase, - used_extensions: &mut ExtensionRegistry, + used_extensions: &mut WeakExtensionRegistry, missing_extensions: &mut ExtensionSet, ) { for ty in row.iter() { @@ -156,7 +159,7 @@ fn collect_type_row_exts( /// `Weak` pointer has been invalidated. pub(super) fn collect_type_exts( typ: &TypeBase, - used_extensions: &mut ExtensionRegistry, + used_extensions: &mut WeakExtensionRegistry, missing_extensions: &mut ExtensionSet, ) { match typ.as_type_enum() { @@ -164,9 +167,11 @@ pub(super) fn collect_type_exts( for arg in custom.args() { collect_typearg_exts(arg, used_extensions, missing_extensions); } - match custom.extension_ref().upgrade() { + let ext_ref = custom.extension_ref(); + // Check if the extension reference is still valid. + match ext_ref.upgrade() { Some(ext) => { - used_extensions.register_updated(ext); + used_extensions.register(ext.name().clone(), ext_ref); } None => { missing_extensions.insert(custom.extension().clone()); @@ -200,7 +205,7 @@ pub(super) fn collect_type_exts( /// `Weak` pointer has been invalidated. fn collect_typearg_exts( arg: &TypeArg, - used_extensions: &mut ExtensionRegistry, + used_extensions: &mut WeakExtensionRegistry, missing_extensions: &mut ExtensionSet, ) { match arg { @@ -226,7 +231,7 @@ fn collect_typearg_exts( /// `Weak` pointer has been invalidated. fn collect_value_exts( value: &Value, - used_extensions: &mut ExtensionRegistry, + used_extensions: &mut WeakExtensionRegistry, missing_extensions: &mut ExtensionSet, ) { match value { diff --git a/hugr-core/src/extension/resolution/types_mut.rs b/hugr-core/src/extension/resolution/types_mut.rs index 520bf2919..f84935aeb 100644 --- a/hugr-core/src/extension/resolution/types_mut.rs +++ b/hugr-core/src/extension/resolution/types_mut.rs @@ -3,30 +3,30 @@ //! //! For a non-mutating option see [`super::collect_op_types_extensions`]. -use std::sync::Arc; +use std::sync::Weak; use super::types::collect_type_exts; -use super::{ExtensionRegistry, ExtensionResolutionError}; +use super::{ExtensionResolutionError, WeakExtensionRegistry}; use crate::extension::ExtensionSet; use crate::ops::{OpType, Value}; use crate::types::type_row::TypeRowBase; use crate::types::{CustomType, MaybeRV, Signature, SumType, TypeArg, TypeBase, TypeEnum}; -use crate::Node; +use crate::{Extension, Node}; /// Replace the dangling extension pointer in the [`CustomType`]s inside an /// optype with a valid pointer to the extension in the `extensions` /// registry. /// -/// When a pointer is replaced, the extension is added to the -/// `used_extensions` registry. +/// Returns an iterator over the used extensions. /// /// This is a helper function used right after deserializing a Hugr. pub fn resolve_op_types_extensions( - node: Node, + node: Option, op: &mut OpType, - extensions: &ExtensionRegistry, - used_extensions: &mut ExtensionRegistry, -) -> Result<(), ExtensionResolutionError> { + extensions: &WeakExtensionRegistry, +) -> Result>, ExtensionResolutionError> { + let mut used = WeakExtensionRegistry::default(); + let used_extensions = &mut used; match op { OpType::ExtensionOp(ext) => { for arg in ext.args_mut() { @@ -106,17 +106,17 @@ pub fn resolve_op_types_extensions( // Ignore optypes that do not store a signature. OpType::Module(_) | OpType::AliasDecl(_) | OpType::AliasDefn(_) => {} } - Ok(()) + Ok(used.into_iter()) } /// Update all weak Extension pointers in the [`CustomType`]s inside a signature. /// /// Adds the extensions used in the signature to the `used_extensions` registry. fn resolve_signature_exts( - node: Node, + node: Option, signature: &mut Signature, - extensions: &ExtensionRegistry, - used_extensions: &mut ExtensionRegistry, + extensions: &WeakExtensionRegistry, + used_extensions: &mut WeakExtensionRegistry, ) -> Result<(), ExtensionResolutionError> { // Note that we do not include the signature's `runtime_reqs` here, as those refer // to _runtime_ requirements that may not be currently present. @@ -129,10 +129,10 @@ fn resolve_signature_exts( /// /// Adds the extensions used in the row to the `used_extensions` registry. fn resolve_type_row_exts( - node: Node, + node: Option, row: &mut TypeRowBase, - extensions: &ExtensionRegistry, - used_extensions: &mut ExtensionRegistry, + extensions: &WeakExtensionRegistry, + used_extensions: &mut WeakExtensionRegistry, ) -> Result<(), ExtensionResolutionError> { for ty in row.iter_mut() { resolve_type_exts(node, ty, extensions, used_extensions)?; @@ -144,10 +144,10 @@ fn resolve_type_row_exts( /// /// Adds the extensions used in the type to the `used_extensions` registry. pub(super) fn resolve_type_exts( - node: Node, + node: Option, typ: &mut TypeBase, - extensions: &ExtensionRegistry, - used_extensions: &mut ExtensionRegistry, + extensions: &WeakExtensionRegistry, + used_extensions: &mut WeakExtensionRegistry, ) -> Result<(), ExtensionResolutionError> { match typ.as_type_enum_mut() { TypeEnum::Extension(custom) => { @@ -175,10 +175,10 @@ pub(super) fn resolve_type_exts( /// /// Adds the extensions used in the type to the `used_extensions` registry. pub(super) fn resolve_custom_type_exts( - node: Node, + node: Option, custom: &mut CustomType, - extensions: &ExtensionRegistry, - used_extensions: &mut ExtensionRegistry, + extensions: &WeakExtensionRegistry, + used_extensions: &mut WeakExtensionRegistry, ) -> Result<(), ExtensionResolutionError> { for arg in custom.args_mut() { resolve_typearg_exts(node, arg, extensions, used_extensions)?; @@ -191,8 +191,8 @@ pub(super) fn resolve_custom_type_exts( // Add the extension to the used extensions registry, // and update the CustomType with the valid pointer. - used_extensions.register_updated_ref(ext); - custom.update_extension(Arc::downgrade(ext)); + used_extensions.register(ext_id.clone(), ext.clone()); + custom.update_extension(ext.clone()); Ok(()) } @@ -201,10 +201,10 @@ pub(super) fn resolve_custom_type_exts( /// /// Adds the extensions used in the type to the `used_extensions` registry. pub(super) fn resolve_typearg_exts( - node: Node, + node: Option, arg: &mut TypeArg, - extensions: &ExtensionRegistry, - used_extensions: &mut ExtensionRegistry, + extensions: &WeakExtensionRegistry, + used_extensions: &mut WeakExtensionRegistry, ) -> Result<(), ExtensionResolutionError> { match arg { TypeArg::Type { ty } => resolve_type_exts(node, ty, extensions, used_extensions)?, @@ -222,10 +222,10 @@ pub(super) fn resolve_typearg_exts( /// /// Adds the extensions used in the row to the `used_extensions` registry. pub(super) fn resolve_value_exts( - node: Node, + node: Option, value: &mut Value, - extensions: &ExtensionRegistry, - used_extensions: &mut ExtensionRegistry, + extensions: &WeakExtensionRegistry, + used_extensions: &mut WeakExtensionRegistry, ) -> Result<(), ExtensionResolutionError> { match value { Value::Extension { e } => { @@ -246,7 +246,9 @@ pub(super) fn resolve_value_exts( Value::Function { hugr } => { // We don't need to add the nested hugr's extensions to the main one here, // but we run resolution on it independently. - hugr.resolve_extension_defs(extensions)?; + if let Ok(exts) = extensions.try_into() { + hugr.resolve_extension_defs(&exts)?; + } } Value::Sum(s) => { if let SumType::General { rows } = &mut s.sum_type { diff --git a/hugr-core/src/extension/resolution/weak_registry.rs b/hugr-core/src/extension/resolution/weak_registry.rs new file mode 100644 index 000000000..85cff7dc8 --- /dev/null +++ b/hugr-core/src/extension/resolution/weak_registry.rs @@ -0,0 +1,107 @@ +use std::collections::BTreeMap; +use std::sync::{Arc, Weak}; + +use itertools::Itertools; + +use derive_more::Display; + +use crate::extension::{ExtensionId, ExtensionRegistry}; +use crate::Extension; + +/// The equivalent to an [`ExtensionRegistry`] that only contains weak +/// references. +/// +/// This is used to resolve extensions pointers while the extensions themselves +/// (and the [`Arc`] that contains them) are being initialized. +#[derive(Debug, Display, Default, Clone)] +#[display("WeakExtensionRegistry[{}]", exts.keys().join(", "))] +pub struct WeakExtensionRegistry { + /// The extensions in the registry. + exts: BTreeMap>, +} + +impl WeakExtensionRegistry { + /// Create a new weak registry from a list of extensions and their ids. + pub fn new(extensions: impl IntoIterator)>) -> Self { + let mut res = Self::default(); + for (id, ext) in extensions.into_iter() { + res.register(id, ext); + } + res + } + + /// Gets the Extension with the given name + pub fn get(&self, name: &str) -> Option<&Weak> { + self.exts.get(name) + } + + /// Returns `true` if the registry contains an extension with the given name. + pub fn contains(&self, name: &str) -> bool { + self.exts.contains_key(name) + } + + /// Register a new extension in the registry. + /// + /// Returns `true` if the extension was added, `false` if it was already present. + pub fn register(&mut self, id: ExtensionId, ext: impl Into>) -> bool { + self.exts.insert(id, ext.into()).is_none() + } + + /// Returns an iterator over the weak references in the registry and their ids. + pub fn iter(&self) -> impl Iterator)> { + self.exts.iter() + } + + /// Returns an iterator over the weak references in the registry. + pub fn extensions(&self) -> impl Iterator> { + self.exts.values() + } + + /// Returns an iterator over the extension ids in the registry. + pub fn ids(&self) -> impl Iterator { + self.exts.keys() + } +} + +impl IntoIterator for WeakExtensionRegistry { + type Item = Weak; + type IntoIter = std::collections::btree_map::IntoValues>; + + fn into_iter(self) -> Self::IntoIter { + self.exts.into_values() + } +} + +impl<'a> TryFrom<&'a WeakExtensionRegistry> for ExtensionRegistry { + type Error = (); + + fn try_from(weak: &'a WeakExtensionRegistry) -> Result { + let exts: Vec> = weak + .extensions() + .map(|w| w.upgrade().ok_or(())) + .try_collect()?; + Ok(ExtensionRegistry::new(exts)) + } +} + +impl TryFrom for ExtensionRegistry { + type Error = (); + + fn try_from(weak: WeakExtensionRegistry) -> Result { + let exts: Vec> = weak + .into_iter() + .map(|w| w.upgrade().ok_or(())) + .try_collect()?; + Ok(ExtensionRegistry::new(exts)) + } +} + +impl<'a> From<&'a ExtensionRegistry> for WeakExtensionRegistry { + fn from(reg: &'a ExtensionRegistry) -> Self { + let exts = reg + .iter() + .map(|ext| (ext.name().clone(), Arc::downgrade(ext))) + .collect(); + Self { exts } + } +} diff --git a/hugr-core/src/hugr.rs b/hugr-core/src/hugr.rs index e3392b601..b97435610 100644 --- a/hugr-core/src/hugr.rs +++ b/hugr-core/src/hugr.rs @@ -27,6 +27,7 @@ pub use self::views::{HugrView, RootTagged}; use crate::core::NodeIndex; use crate::extension::resolution::{ resolve_op_extensions, resolve_op_types_extensions, ExtensionResolutionError, + WeakExtensionRegistry, }; use crate::extension::{ExtensionRegistry, ExtensionSet, TO_BE_INFERRED}; use crate::ops::{OpTag, OpTrait}; @@ -231,6 +232,7 @@ impl Hugr { // Since we don't have a non-borrowing iterator over all the possible // NodeIds, we have to simulate it by iterating over all possible // indices and checking if the node exists. + let weak_extensions: WeakExtensionRegistry = extensions.into(); for n in 0..self.graph.node_capacity() { let pg_node = portgraph::NodeIndex::new(n); let node: Node = pg_node.into(); @@ -243,7 +245,12 @@ impl Hugr { if let Some(extension) = resolve_op_extensions(node, op, extensions)? { used_extensions.register_updated_ref(extension); } - resolve_op_types_extensions(node, op, extensions, &mut used_extensions)?; + used_extensions.extend( + resolve_op_types_extensions(Some(node), op, &weak_extensions)?.map(|weak| { + weak.upgrade() + .expect("Extension comes from a valid registry") + }), + ); } self.extensions = used_extensions; diff --git a/hugr-core/src/ops/constant.rs b/hugr-core/src/ops/constant.rs index a3cfd6c3f..dc8fb8d1b 100644 --- a/hugr-core/src/ops/constant.rs +++ b/hugr-core/src/ops/constant.rs @@ -579,8 +579,9 @@ pub(crate) mod test { use crate::extension::prelude::{bool_t, usize_custom_t}; use crate::extension::resolution::{ resolve_custom_type_extensions, resolve_typearg_extensions, ExtensionResolutionError, + WeakExtensionRegistry, }; - use crate::extension::{ExtensionRegistry, PRELUDE}; + use crate::extension::PRELUDE; use crate::std_extensions::arithmetic::int_types::ConstInt; use crate::{ builder::{BuildError, DFGBuilder, Dataflow, DataflowHugr}, @@ -614,7 +615,7 @@ pub(crate) mod test { fn update_extensions( &mut self, - extensions: &ExtensionRegistry, + extensions: &WeakExtensionRegistry, ) -> Result<(), ExtensionResolutionError> { resolve_custom_type_extensions(&mut self.0, extensions)?; // This loop is redundant, but we use it to test the public diff --git a/hugr-core/src/ops/constant/custom.rs b/hugr-core/src/ops/constant/custom.rs index 0b1a77899..985e15594 100644 --- a/hugr-core/src/ops/constant/custom.rs +++ b/hugr-core/src/ops/constant/custom.rs @@ -10,8 +10,10 @@ use std::hash::{Hash, Hasher}; use downcast_rs::{impl_downcast, Downcast}; use thiserror::Error; -use crate::extension::resolution::{resolve_type_extensions, ExtensionResolutionError}; -use crate::extension::{ExtensionRegistry, ExtensionSet}; +use crate::extension::resolution::{ + resolve_type_extensions, ExtensionResolutionError, WeakExtensionRegistry, +}; +use crate::extension::ExtensionSet; use crate::macros::impl_box_clone; use crate::types::{CustomCheckFailure, Type}; use crate::IncomingPort; @@ -93,7 +95,7 @@ pub trait CustomConst: /// See the helper methods in [`crate::extension::resolution`]. fn update_extensions( &mut self, - _extensions: &ExtensionRegistry, + _extensions: &WeakExtensionRegistry, ) -> Result<(), ExtensionResolutionError> { Ok(()) } @@ -316,7 +318,7 @@ impl CustomConst for CustomSerialized { } fn update_extensions( &mut self, - extensions: &ExtensionRegistry, + extensions: &WeakExtensionRegistry, ) -> Result<(), ExtensionResolutionError> { resolve_type_extensions(&mut self.typ, extensions) } diff --git a/hugr-core/src/std_extensions/collections/list.rs b/hugr-core/src/std_extensions/collections/list.rs index 1d5bbd4f0..426055251 100644 --- a/hugr-core/src/std_extensions/collections/list.rs +++ b/hugr-core/src/std_extensions/collections/list.rs @@ -15,6 +15,7 @@ use strum_macros::{EnumIter, EnumString, IntoStaticStr}; use crate::extension::prelude::{either_type, option_type, usize_t}; use crate::extension::resolution::{ resolve_type_extensions, resolve_value_extensions, ExtensionResolutionError, + WeakExtensionRegistry, }; use crate::extension::simple_op::{MakeOpDef, MakeRegisteredOp}; use crate::extension::{ExtensionBuildError, OpDef, SignatureFunc, PRELUDE}; @@ -132,7 +133,7 @@ impl CustomConst for ListValue { fn update_extensions( &mut self, - extensions: &ExtensionRegistry, + extensions: &WeakExtensionRegistry, ) -> Result<(), ExtensionResolutionError> { for val in &mut self.0 { resolve_value_extensions(val, extensions)?; diff --git a/hugr-core/src/types/signature.rs b/hugr-core/src/types/signature.rs index 506869331..cac530291 100644 --- a/hugr-core/src/types/signature.rs +++ b/hugr-core/src/types/signature.rs @@ -10,7 +10,9 @@ use super::type_row::TypeRowBase; use super::{MaybeRV, NoRV, RowVariable, Substitution, Type, TypeRow}; use crate::core::PortIndex; -use crate::extension::resolution::{collect_signature_exts, ExtensionCollectionError}; +use crate::extension::resolution::{ + collect_signature_exts, ExtensionCollectionError, WeakExtensionRegistry, +}; use crate::extension::{ExtensionRegistry, ExtensionSet, SignatureError}; use crate::{Direction, IncomingPort, OutgoingPort, Port}; @@ -131,13 +133,13 @@ impl FuncTypeBase { /// refer to _runtime_ extensions, which may not be in all places that /// manipulate a HUGR. pub fn used_extensions(&self) -> Result { - let mut used = ExtensionRegistry::default(); + let mut used = WeakExtensionRegistry::default(); let mut missing = ExtensionSet::new(); collect_signature_exts(self, &mut used, &mut missing); if missing.is_empty() { - Ok(used) + Ok(used.try_into().expect("all extensions are present")) } else { Err(ExtensionCollectionError::dropped_signature(self, missing)) }