Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: Use registries of Weak<Extension>s when doing resolution #1781

Merged
merged 5 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions hugr-core/src/extension/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use crate::{type_row, Extension};

use strum_macros::{EnumIter, EnumString, IntoStaticStr};

use super::resolution::{resolve_type_extensions, ExtensionResolutionError};
use super::resolution::{resolve_type_extensions, ExtensionResolutionError, WeakExtensionRegistry};
use super::ExtensionRegistry;

mod unwrap_builder;
Expand Down Expand Up @@ -507,7 +507,7 @@ impl CustomConst for ConstExternalSymbol {

fn update_extensions(
&mut self,
extensions: &ExtensionRegistry,
extensions: &WeakExtensionRegistry,
) -> Result<(), ExtensionResolutionError> {
resolve_type_extensions(&mut self.typ, extensions)
}
Expand Down
55 changes: 24 additions & 31 deletions hugr-core/src/extension/resolution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
mod ops;
mod types;
mod types_mut;
mod weak_registry;

pub use weak_registry::WeakExtensionRegistry;

pub(crate) use ops::{collect_op_extension, resolve_op_extensions};
pub(crate) use types::{collect_op_types_extensions, collect_signature_exts};
Expand All @@ -42,49 +45,37 @@ use crate::Node;
/// Update all weak Extension pointers inside a type.
pub fn resolve_type_extensions<RV: MaybeRV>(
typ: &mut TypeBase<RV>,
extensions: &ExtensionRegistry,
extensions: &WeakExtensionRegistry,
) -> Result<(), ExtensionResolutionError> {
// This public export is used for implementing `CustomConst::update_extensions`, so we don't need the full internal API here.
// TODO: Make `node` optional in `ExtensionResolutionError`
let node: Node = portgraph::NodeIndex::new(0).into();
let mut used_extensions = ExtensionRegistry::default();
resolve_type_exts(node, typ, extensions, &mut used_extensions)
let mut used_extensions = WeakExtensionRegistry::default();
resolve_type_exts(None, typ, extensions, &mut used_extensions)
}

/// Update all weak Extension pointers in a custom type.
pub fn resolve_custom_type_extensions(
typ: &mut CustomType,
extensions: &ExtensionRegistry,
extensions: &WeakExtensionRegistry,
) -> Result<(), ExtensionResolutionError> {
// This public export is used for implementing `CustomConst::update_extensions`, so we don't need the full internal API here.
// TODO: Make `node` optional in `ExtensionResolutionError`
let node: Node = portgraph::NodeIndex::new(0).into();
let mut used_extensions = ExtensionRegistry::default();
resolve_custom_type_exts(node, typ, extensions, &mut used_extensions)
let mut used_extensions = WeakExtensionRegistry::default();
resolve_custom_type_exts(None, typ, extensions, &mut used_extensions)
}

/// Update all weak Extension pointers inside a type argument.
pub fn resolve_typearg_extensions(
arg: &mut TypeArg,
extensions: &ExtensionRegistry,
extensions: &WeakExtensionRegistry,
) -> Result<(), ExtensionResolutionError> {
// This public export is used for implementing `CustomConst::update_extensions`, so we don't need the full internal API here.
// TODO: Make `node` optional in `ExtensionResolutionError`
let node: Node = portgraph::NodeIndex::new(0).into();
let mut used_extensions = ExtensionRegistry::default();
resolve_typearg_exts(node, arg, extensions, &mut used_extensions)
let mut used_extensions = WeakExtensionRegistry::default();
resolve_typearg_exts(None, arg, extensions, &mut used_extensions)
}

/// Update all weak Extension pointers inside a constant value.
pub fn resolve_value_extensions(
value: &mut Value,
extensions: &ExtensionRegistry,
extensions: &WeakExtensionRegistry,
) -> Result<(), ExtensionResolutionError> {
// This public export is used for implementing `CustomConst::update_extensions`, so we don't need the full internal API here.
// TODO: Make `node` optional in `ExtensionResolutionError`
let node: Node = portgraph::NodeIndex::new(0).into();
let mut used_extensions = ExtensionRegistry::default();
resolve_value_exts(node, value, extensions, &mut used_extensions)
let mut used_extensions = WeakExtensionRegistry::default();
resolve_value_exts(None, value, extensions, &mut used_extensions)
}

/// Errors that can occur during extension resolution.
Expand All @@ -97,12 +88,13 @@ pub enum ExtensionResolutionError {
OpaqueOpError(OpaqueOpError),
/// An operation requires an extension that is not in the given registry.
#[display(
"{op} ({node}) requires extension {missing_extension}, but it could not be found in the extension list used during resolution. The available extensions are: {}",
"{op}{} requires extension {missing_extension}, but it could not be found in the extension list used during resolution. The available extensions are: {}",
node.map(|n| format!(" in {}", n)).unwrap_or_default(),
available_extensions.join(", ")
)]
MissingOpExtension {
/// The node that requires the extension.
node: Node,
node: Option<Node>,
/// The operation that requires the extension.
op: OpName,
/// The missing extension
Expand All @@ -111,13 +103,14 @@ pub enum ExtensionResolutionError {
available_extensions: Vec<ExtensionId>,
},
#[display(
"Type {ty} in {node} requires extension {missing_extension}, but it could not be found in the extension list used during resolution. The available extensions are: {}",
"Type {ty}{} requires extension {missing_extension}, but it could not be found in the extension list used during resolution. The available extensions are: {}",
node.map(|n| format!(" in {}", n)).unwrap_or_default(),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's the default behaviour?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If there's a node: "Type Unit in Node(4) requires..."
If there's no node: "Type Unit requires ..."

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah yeah sorry misread, thanks

available_extensions.join(", ")
)]
/// A type references an extension that is not in the given registry.
MissingTypeExtension {
/// The node that requires the extension.
node: Node,
node: Option<Node>,
/// The type that requires the extension.
ty: TypeName,
/// The missing extension
Expand All @@ -138,7 +131,7 @@ pub enum ExtensionResolutionError {
impl ExtensionResolutionError {
/// Create a new error for missing operation extensions.
pub fn missing_op_extension(
node: Node,
node: Option<Node>,
op: &OpType,
missing_extension: &ExtensionId,
extensions: &ExtensionRegistry,
Expand All @@ -153,10 +146,10 @@ impl ExtensionResolutionError {

/// Create a new error for missing type extensions.
pub fn missing_type_extension(
node: Node,
node: Option<Node>,
ty: &TypeName,
missing_extension: &ExtensionId,
extensions: &ExtensionRegistry,
extensions: &WeakExtensionRegistry,
) -> Self {
Self::MissingTypeExtension {
node,
Expand Down
8 changes: 6 additions & 2 deletions hugr-core/src/extension/resolution/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@

use std::sync::Arc;

use super::{Extension, ExtensionCollectionError, ExtensionRegistry, ExtensionResolutionError};
use super::{Extension, ExtensionCollectionError, ExtensionResolutionError};
use crate::extension::ExtensionRegistry;
use crate::ops::custom::OpaqueOpError;
use crate::ops::{DataflowOpTrait, ExtensionOp, NamedOp, OpType};
use crate::Node;
Expand Down Expand Up @@ -124,7 +125,10 @@ fn operation_extension<'e>(
match extensions.get(ext) {
Some(e) => Ok(Some(e)),
None => Err(ExtensionResolutionError::missing_op_extension(
node, op, ext, extensions,
Some(node),
op,
ext,
extensions,
)),
}
}
8 changes: 6 additions & 2 deletions hugr-core/src/extension/resolution/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use crate::builder::{
Container, Dataflow, DataflowSubContainer, FunctionBuilder, HugrBuilder, ModuleBuilder,
};
use crate::extension::prelude::{bool_t, usize_custom_t, ConstUsize};
use crate::extension::resolution::WeakExtensionRegistry;
use crate::extension::resolution::{
resolve_op_extensions, resolve_op_types_extensions, ExtensionCollectionError,
};
Expand Down Expand Up @@ -52,9 +53,12 @@ fn resolve_type_extensions(#[case] op: impl Into<OpType>, #[case] extensions: Ex

let dummy_node = portgraph::NodeIndex::new(0).into();

let mut used_exts = ExtensionRegistry::default();
resolve_op_extensions(dummy_node, &mut deser_op, &extensions).unwrap();
resolve_op_types_extensions(dummy_node, &mut deser_op, &extensions, &mut used_exts).unwrap();

let weak_extensions: WeakExtensionRegistry = (&extensions).into();
resolve_op_types_extensions(Some(dummy_node), &mut deser_op, &weak_extensions)
.unwrap()
.for_each(|_| ());

let deser_extensions = deser_op.used_extensions().unwrap();

Expand Down
33 changes: 19 additions & 14 deletions hugr-core/src/extension/resolution/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
//! See [`super::resolve_op_types_extensions`] for a mutating version that
//! updates the weak links to point to the correct extensions.

use super::ExtensionCollectionError;
use super::{ExtensionCollectionError, WeakExtensionRegistry};
use crate::extension::{ExtensionRegistry, ExtensionSet};
use crate::ops::{DataflowOpTrait, OpType, Value};
use crate::types::type_row::TypeRowBase;
Expand All @@ -32,7 +32,7 @@ pub(crate) fn collect_op_types_extensions(
node: Option<Node>,
op: &OpType,
) -> Result<ExtensionRegistry, ExtensionCollectionError> {
let mut used = ExtensionRegistry::default();
let mut used = WeakExtensionRegistry::default();
let mut missing = ExtensionSet::new();

match op {
Expand Down Expand Up @@ -101,12 +101,15 @@ pub(crate) fn collect_op_types_extensions(
OpType::Module(_) | OpType::AliasDecl(_) | OpType::AliasDefn(_) => {}
};

missing
.is_empty()
.then_some(used)
.ok_or(ExtensionCollectionError::dropped_op_extension(
match missing.is_empty() {
true => {
// We know there are no missing extensions, so this should not fail.
Ok(used.try_into().expect("All extensions are valid"))
}
false => Err(ExtensionCollectionError::dropped_op_extension(
node, op, missing,
))
)),
}
}

/// Collect the Extension pointers in the [`CustomType`]s inside a signature.
Expand All @@ -119,7 +122,7 @@ pub(crate) fn collect_op_types_extensions(
/// `Weak<Extension>` pointer has been invalidated.
pub(crate) fn collect_signature_exts<RV: MaybeRV>(
signature: &FuncTypeBase<RV>,
used_extensions: &mut ExtensionRegistry,
used_extensions: &mut WeakExtensionRegistry,
missing_extensions: &mut ExtensionSet,
) {
// Note that we do not include the signature's `runtime_reqs` here, as those refer
Expand All @@ -138,7 +141,7 @@ pub(crate) fn collect_signature_exts<RV: MaybeRV>(
/// `Weak<Extension>` pointer has been invalidated.
fn collect_type_row_exts<RV: MaybeRV>(
row: &TypeRowBase<RV>,
used_extensions: &mut ExtensionRegistry,
used_extensions: &mut WeakExtensionRegistry,
missing_extensions: &mut ExtensionSet,
) {
for ty in row.iter() {
Expand All @@ -156,17 +159,19 @@ fn collect_type_row_exts<RV: MaybeRV>(
/// `Weak<Extension>` pointer has been invalidated.
pub(super) fn collect_type_exts<RV: MaybeRV>(
typ: &TypeBase<RV>,
used_extensions: &mut ExtensionRegistry,
used_extensions: &mut WeakExtensionRegistry,
missing_extensions: &mut ExtensionSet,
) {
match typ.as_type_enum() {
TypeEnum::Extension(custom) => {
for arg in custom.args() {
collect_typearg_exts(arg, used_extensions, missing_extensions);
}
match custom.extension_ref().upgrade() {
let ext_ref = custom.extension_ref();
// Check if the extension reference is still valid.
match ext_ref.upgrade() {
Some(ext) => {
used_extensions.register_updated(ext);
used_extensions.register(ext.name().clone(), ext_ref);
}
None => {
missing_extensions.insert(custom.extension().clone());
Expand Down Expand Up @@ -200,7 +205,7 @@ pub(super) fn collect_type_exts<RV: MaybeRV>(
/// `Weak<Extension>` pointer has been invalidated.
fn collect_typearg_exts(
arg: &TypeArg,
used_extensions: &mut ExtensionRegistry,
used_extensions: &mut WeakExtensionRegistry,
missing_extensions: &mut ExtensionSet,
) {
match arg {
Expand All @@ -226,7 +231,7 @@ fn collect_typearg_exts(
/// `Weak<Extension>` pointer has been invalidated.
fn collect_value_exts(
value: &Value,
used_extensions: &mut ExtensionRegistry,
used_extensions: &mut WeakExtensionRegistry,
missing_extensions: &mut ExtensionSet,
) {
match value {
Expand Down
Loading
Loading