Skip to content

Commit

Permalink
feat!: Use registries of Weak<Extension>s when doing resolution (#1781
Browse files Browse the repository at this point in the history
)

We'll need this to resolve extension references inside an extension
that's being created.

drive-by: Make "node" optional in resolution errors.

BREAKING CHANGE: `::extension::resolve` operations now use
`WeakExtensionRegistry`es.
  • Loading branch information
aborgna-q authored Dec 13, 2024
1 parent 8722c10 commit 2d08fc1
Show file tree
Hide file tree
Showing 12 changed files with 221 additions and 93 deletions.
4 changes: 2 additions & 2 deletions hugr-core/src/extension/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use crate::{type_row, Extension};

use strum_macros::{EnumIter, EnumString, IntoStaticStr};

use super::resolution::{resolve_type_extensions, ExtensionResolutionError};
use super::resolution::{resolve_type_extensions, ExtensionResolutionError, WeakExtensionRegistry};
use super::ExtensionRegistry;

mod unwrap_builder;
Expand Down Expand Up @@ -507,7 +507,7 @@ impl CustomConst for ConstExternalSymbol {

fn update_extensions(
&mut self,
extensions: &ExtensionRegistry,
extensions: &WeakExtensionRegistry,
) -> Result<(), ExtensionResolutionError> {
resolve_type_extensions(&mut self.typ, extensions)
}
Expand Down
55 changes: 24 additions & 31 deletions hugr-core/src/extension/resolution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
mod ops;
mod types;
mod types_mut;
mod weak_registry;

pub use weak_registry::WeakExtensionRegistry;

pub(crate) use ops::{collect_op_extension, resolve_op_extensions};
pub(crate) use types::{collect_op_types_extensions, collect_signature_exts};
Expand All @@ -42,49 +45,37 @@ use crate::Node;
/// Update all weak Extension pointers inside a type.
pub fn resolve_type_extensions<RV: MaybeRV>(
typ: &mut TypeBase<RV>,
extensions: &ExtensionRegistry,
extensions: &WeakExtensionRegistry,
) -> Result<(), ExtensionResolutionError> {
// This public export is used for implementing `CustomConst::update_extensions`, so we don't need the full internal API here.
// TODO: Make `node` optional in `ExtensionResolutionError`
let node: Node = portgraph::NodeIndex::new(0).into();
let mut used_extensions = ExtensionRegistry::default();
resolve_type_exts(node, typ, extensions, &mut used_extensions)
let mut used_extensions = WeakExtensionRegistry::default();
resolve_type_exts(None, typ, extensions, &mut used_extensions)
}

/// Update all weak Extension pointers in a custom type.
pub fn resolve_custom_type_extensions(
typ: &mut CustomType,
extensions: &ExtensionRegistry,
extensions: &WeakExtensionRegistry,
) -> Result<(), ExtensionResolutionError> {
// This public export is used for implementing `CustomConst::update_extensions`, so we don't need the full internal API here.
// TODO: Make `node` optional in `ExtensionResolutionError`
let node: Node = portgraph::NodeIndex::new(0).into();
let mut used_extensions = ExtensionRegistry::default();
resolve_custom_type_exts(node, typ, extensions, &mut used_extensions)
let mut used_extensions = WeakExtensionRegistry::default();
resolve_custom_type_exts(None, typ, extensions, &mut used_extensions)
}

/// Update all weak Extension pointers inside a type argument.
pub fn resolve_typearg_extensions(
arg: &mut TypeArg,
extensions: &ExtensionRegistry,
extensions: &WeakExtensionRegistry,
) -> Result<(), ExtensionResolutionError> {
// This public export is used for implementing `CustomConst::update_extensions`, so we don't need the full internal API here.
// TODO: Make `node` optional in `ExtensionResolutionError`
let node: Node = portgraph::NodeIndex::new(0).into();
let mut used_extensions = ExtensionRegistry::default();
resolve_typearg_exts(node, arg, extensions, &mut used_extensions)
let mut used_extensions = WeakExtensionRegistry::default();
resolve_typearg_exts(None, arg, extensions, &mut used_extensions)
}

/// Update all weak Extension pointers inside a constant value.
pub fn resolve_value_extensions(
value: &mut Value,
extensions: &ExtensionRegistry,
extensions: &WeakExtensionRegistry,
) -> Result<(), ExtensionResolutionError> {
// This public export is used for implementing `CustomConst::update_extensions`, so we don't need the full internal API here.
// TODO: Make `node` optional in `ExtensionResolutionError`
let node: Node = portgraph::NodeIndex::new(0).into();
let mut used_extensions = ExtensionRegistry::default();
resolve_value_exts(node, value, extensions, &mut used_extensions)
let mut used_extensions = WeakExtensionRegistry::default();
resolve_value_exts(None, value, extensions, &mut used_extensions)
}

/// Errors that can occur during extension resolution.
Expand All @@ -97,12 +88,13 @@ pub enum ExtensionResolutionError {
OpaqueOpError(OpaqueOpError),
/// An operation requires an extension that is not in the given registry.
#[display(
"{op} ({node}) requires extension {missing_extension}, but it could not be found in the extension list used during resolution. The available extensions are: {}",
"{op}{} requires extension {missing_extension}, but it could not be found in the extension list used during resolution. The available extensions are: {}",
node.map(|n| format!(" in {}", n)).unwrap_or_default(),
available_extensions.join(", ")
)]
MissingOpExtension {
/// The node that requires the extension.
node: Node,
node: Option<Node>,
/// The operation that requires the extension.
op: OpName,
/// The missing extension
Expand All @@ -111,13 +103,14 @@ pub enum ExtensionResolutionError {
available_extensions: Vec<ExtensionId>,
},
#[display(
"Type {ty} in {node} requires extension {missing_extension}, but it could not be found in the extension list used during resolution. The available extensions are: {}",
"Type {ty}{} requires extension {missing_extension}, but it could not be found in the extension list used during resolution. The available extensions are: {}",
node.map(|n| format!(" in {}", n)).unwrap_or_default(),
available_extensions.join(", ")
)]
/// A type references an extension that is not in the given registry.
MissingTypeExtension {
/// The node that requires the extension.
node: Node,
node: Option<Node>,
/// The type that requires the extension.
ty: TypeName,
/// The missing extension
Expand All @@ -138,7 +131,7 @@ pub enum ExtensionResolutionError {
impl ExtensionResolutionError {
/// Create a new error for missing operation extensions.
pub fn missing_op_extension(
node: Node,
node: Option<Node>,
op: &OpType,
missing_extension: &ExtensionId,
extensions: &ExtensionRegistry,
Expand All @@ -153,10 +146,10 @@ impl ExtensionResolutionError {

/// Create a new error for missing type extensions.
pub fn missing_type_extension(
node: Node,
node: Option<Node>,
ty: &TypeName,
missing_extension: &ExtensionId,
extensions: &ExtensionRegistry,
extensions: &WeakExtensionRegistry,
) -> Self {
Self::MissingTypeExtension {
node,
Expand Down
8 changes: 6 additions & 2 deletions hugr-core/src/extension/resolution/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
use std::sync::Arc;

use super::{Extension, ExtensionCollectionError, ExtensionRegistry, ExtensionResolutionError};
use super::{Extension, ExtensionCollectionError, ExtensionResolutionError};
use crate::extension::ExtensionRegistry;
use crate::ops::custom::OpaqueOpError;
use crate::ops::{DataflowOpTrait, ExtensionOp, NamedOp, OpType};
use crate::Node;
Expand Down Expand Up @@ -124,7 +125,10 @@ fn operation_extension<'e>(
match extensions.get(ext) {
Some(e) => Ok(Some(e)),
None => Err(ExtensionResolutionError::missing_op_extension(
node, op, ext, extensions,
Some(node),
op,
ext,
extensions,
)),
}
}
8 changes: 6 additions & 2 deletions hugr-core/src/extension/resolution/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use crate::builder::{
Container, Dataflow, DataflowSubContainer, FunctionBuilder, HugrBuilder, ModuleBuilder,
};
use crate::extension::prelude::{bool_t, usize_custom_t, ConstUsize};
use crate::extension::resolution::WeakExtensionRegistry;
use crate::extension::resolution::{
resolve_op_extensions, resolve_op_types_extensions, ExtensionCollectionError,
};
Expand Down Expand Up @@ -52,9 +53,12 @@ fn resolve_type_extensions(#[case] op: impl Into<OpType>, #[case] extensions: Ex

let dummy_node = portgraph::NodeIndex::new(0).into();

let mut used_exts = ExtensionRegistry::default();
resolve_op_extensions(dummy_node, &mut deser_op, &extensions).unwrap();
resolve_op_types_extensions(dummy_node, &mut deser_op, &extensions, &mut used_exts).unwrap();

let weak_extensions: WeakExtensionRegistry = (&extensions).into();
resolve_op_types_extensions(Some(dummy_node), &mut deser_op, &weak_extensions)
.unwrap()
.for_each(|_| ());

let deser_extensions = deser_op.used_extensions().unwrap();

Expand Down
33 changes: 19 additions & 14 deletions hugr-core/src/extension/resolution/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
//! See [`super::resolve_op_types_extensions`] for a mutating version that
//! updates the weak links to point to the correct extensions.
use super::ExtensionCollectionError;
use super::{ExtensionCollectionError, WeakExtensionRegistry};
use crate::extension::{ExtensionRegistry, ExtensionSet};
use crate::ops::{DataflowOpTrait, OpType, Value};
use crate::types::type_row::TypeRowBase;
Expand All @@ -32,7 +32,7 @@ pub(crate) fn collect_op_types_extensions(
node: Option<Node>,
op: &OpType,
) -> Result<ExtensionRegistry, ExtensionCollectionError> {
let mut used = ExtensionRegistry::default();
let mut used = WeakExtensionRegistry::default();
let mut missing = ExtensionSet::new();

match op {
Expand Down Expand Up @@ -101,12 +101,15 @@ pub(crate) fn collect_op_types_extensions(
OpType::Module(_) | OpType::AliasDecl(_) | OpType::AliasDefn(_) => {}
};

missing
.is_empty()
.then_some(used)
.ok_or(ExtensionCollectionError::dropped_op_extension(
match missing.is_empty() {
true => {
// We know there are no missing extensions, so this should not fail.
Ok(used.try_into().expect("All extensions are valid"))
}
false => Err(ExtensionCollectionError::dropped_op_extension(
node, op, missing,
))
)),
}
}

/// Collect the Extension pointers in the [`CustomType`]s inside a signature.
Expand All @@ -119,7 +122,7 @@ pub(crate) fn collect_op_types_extensions(
/// `Weak<Extension>` pointer has been invalidated.
pub(crate) fn collect_signature_exts<RV: MaybeRV>(
signature: &FuncTypeBase<RV>,
used_extensions: &mut ExtensionRegistry,
used_extensions: &mut WeakExtensionRegistry,
missing_extensions: &mut ExtensionSet,
) {
// Note that we do not include the signature's `runtime_reqs` here, as those refer
Expand All @@ -138,7 +141,7 @@ pub(crate) fn collect_signature_exts<RV: MaybeRV>(
/// `Weak<Extension>` pointer has been invalidated.
fn collect_type_row_exts<RV: MaybeRV>(
row: &TypeRowBase<RV>,
used_extensions: &mut ExtensionRegistry,
used_extensions: &mut WeakExtensionRegistry,
missing_extensions: &mut ExtensionSet,
) {
for ty in row.iter() {
Expand All @@ -156,17 +159,19 @@ fn collect_type_row_exts<RV: MaybeRV>(
/// `Weak<Extension>` pointer has been invalidated.
pub(super) fn collect_type_exts<RV: MaybeRV>(
typ: &TypeBase<RV>,
used_extensions: &mut ExtensionRegistry,
used_extensions: &mut WeakExtensionRegistry,
missing_extensions: &mut ExtensionSet,
) {
match typ.as_type_enum() {
TypeEnum::Extension(custom) => {
for arg in custom.args() {
collect_typearg_exts(arg, used_extensions, missing_extensions);
}
match custom.extension_ref().upgrade() {
let ext_ref = custom.extension_ref();
// Check if the extension reference is still valid.
match ext_ref.upgrade() {
Some(ext) => {
used_extensions.register_updated(ext);
used_extensions.register(ext.name().clone(), ext_ref);
}
None => {
missing_extensions.insert(custom.extension().clone());
Expand Down Expand Up @@ -200,7 +205,7 @@ pub(super) fn collect_type_exts<RV: MaybeRV>(
/// `Weak<Extension>` pointer has been invalidated.
fn collect_typearg_exts(
arg: &TypeArg,
used_extensions: &mut ExtensionRegistry,
used_extensions: &mut WeakExtensionRegistry,
missing_extensions: &mut ExtensionSet,
) {
match arg {
Expand All @@ -226,7 +231,7 @@ fn collect_typearg_exts(
/// `Weak<Extension>` pointer has been invalidated.
fn collect_value_exts(
value: &Value,
used_extensions: &mut ExtensionRegistry,
used_extensions: &mut WeakExtensionRegistry,
missing_extensions: &mut ExtensionSet,
) {
match value {
Expand Down
Loading

0 comments on commit 2d08fc1

Please sign in to comment.