Skip to content

Commit

Permalink
Use weak registries in resolution
Browse files Browse the repository at this point in the history
  • Loading branch information
aborgna-q committed Dec 12, 2024
1 parent e94171d commit 207ae65
Show file tree
Hide file tree
Showing 12 changed files with 181 additions and 63 deletions.
4 changes: 2 additions & 2 deletions hugr-core/src/extension/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use crate::{type_row, Extension};

use strum_macros::{EnumIter, EnumString, IntoStaticStr};

use super::resolution::{resolve_type_extensions, ExtensionResolutionError};
use super::resolution::{resolve_type_extensions, ExtensionResolutionError, WeakExtensionRegistry};
use super::ExtensionRegistry;

mod unwrap_builder;
Expand Down Expand Up @@ -507,7 +507,7 @@ impl CustomConst for ConstExternalSymbol {

fn update_extensions(
&mut self,
extensions: &ExtensionRegistry,
extensions: &WeakExtensionRegistry,
) -> Result<(), ExtensionResolutionError> {
resolve_type_extensions(&mut self.typ, extensions)
}
Expand Down
21 changes: 12 additions & 9 deletions hugr-core/src/extension/resolution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
mod ops;
mod types;
mod types_mut;
mod weak_registry;

pub use weak_registry::WeakExtensionRegistry;

pub(crate) use ops::{collect_op_extension, resolve_op_extensions};
pub(crate) use types::{collect_op_types_extensions, collect_signature_exts};
Expand All @@ -42,36 +45,36 @@ use crate::Node;
/// Update all weak Extension pointers inside a type.
pub fn resolve_type_extensions<RV: MaybeRV>(
typ: &mut TypeBase<RV>,
extensions: &ExtensionRegistry,
extensions: &WeakExtensionRegistry,
) -> Result<(), ExtensionResolutionError> {
let mut used_extensions = ExtensionRegistry::default();
let mut used_extensions = WeakExtensionRegistry::default();
resolve_type_exts(None, typ, extensions, &mut used_extensions)
}

/// Update all weak Extension pointers in a custom type.
pub fn resolve_custom_type_extensions(
typ: &mut CustomType,
extensions: &ExtensionRegistry,
extensions: &WeakExtensionRegistry,
) -> Result<(), ExtensionResolutionError> {
let mut used_extensions = ExtensionRegistry::default();
let mut used_extensions = WeakExtensionRegistry::default();
resolve_custom_type_exts(None, typ, extensions, &mut used_extensions)
}

/// Update all weak Extension pointers inside a type argument.
pub fn resolve_typearg_extensions(
arg: &mut TypeArg,
extensions: &ExtensionRegistry,
extensions: &WeakExtensionRegistry,
) -> Result<(), ExtensionResolutionError> {
let mut used_extensions = ExtensionRegistry::default();
let mut used_extensions = WeakExtensionRegistry::default();
resolve_typearg_exts(None, arg, extensions, &mut used_extensions)
}

/// Update all weak Extension pointers inside a constant value.
pub fn resolve_value_extensions(
value: &mut Value,
extensions: &ExtensionRegistry,
extensions: &WeakExtensionRegistry,
) -> Result<(), ExtensionResolutionError> {
let mut used_extensions = ExtensionRegistry::default();
let mut used_extensions = WeakExtensionRegistry::default();
resolve_value_exts(None, value, extensions, &mut used_extensions)
}

Expand Down Expand Up @@ -146,7 +149,7 @@ impl ExtensionResolutionError {
node: Option<Node>,
ty: &TypeName,
missing_extension: &ExtensionId,
extensions: &ExtensionRegistry,
extensions: &WeakExtensionRegistry,
) -> Self {
Self::MissingTypeExtension {
node,
Expand Down
3 changes: 2 additions & 1 deletion hugr-core/src/extension/resolution/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
use std::sync::Arc;

use super::{Extension, ExtensionCollectionError, ExtensionRegistry, ExtensionResolutionError};
use super::{Extension, ExtensionCollectionError, ExtensionResolutionError};
use crate::extension::ExtensionRegistry;
use crate::ops::custom::OpaqueOpError;
use crate::ops::{DataflowOpTrait, ExtensionOp, NamedOp, OpType};
use crate::Node;
Expand Down
11 changes: 7 additions & 4 deletions hugr-core/src/extension/resolution/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ use rstest::rstest;
use crate::builder::{
Container, Dataflow, DataflowSubContainer, FunctionBuilder, HugrBuilder, ModuleBuilder,
};
use crate::extension::prelude::{bool_t, usize_custom_t, ConstUsize};
use crate::extension::prelude::{bool_t, usize_custom_t, ConstUsize, PRELUDE_ID};
use crate::extension::resolution::WeakExtensionRegistry;
use crate::extension::resolution::{
resolve_op_extensions, resolve_op_types_extensions, ExtensionCollectionError,
};
Expand Down Expand Up @@ -52,10 +53,12 @@ fn resolve_type_extensions(#[case] op: impl Into<OpType>, #[case] extensions: Ex

let dummy_node = portgraph::NodeIndex::new(0).into();

let mut used_exts = ExtensionRegistry::default();
resolve_op_extensions(dummy_node, &mut deser_op, &extensions).unwrap();
resolve_op_types_extensions(Some(dummy_node), &mut deser_op, &extensions, &mut used_exts)
.unwrap();

let weak_extensions: WeakExtensionRegistry = (&extensions).into();
resolve_op_types_extensions(Some(dummy_node), &mut deser_op, &weak_extensions)
.unwrap()
.for_each(|_| ());

let deser_extensions = deser_op.used_extensions().unwrap();

Expand Down
33 changes: 19 additions & 14 deletions hugr-core/src/extension/resolution/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
//! See [`super::resolve_op_types_extensions`] for a mutating version that
//! updates the weak links to point to the correct extensions.
use super::ExtensionCollectionError;
use super::{ExtensionCollectionError, WeakExtensionRegistry};
use crate::extension::{ExtensionRegistry, ExtensionSet};
use crate::ops::{DataflowOpTrait, OpType, Value};
use crate::types::type_row::TypeRowBase;
Expand All @@ -32,7 +32,7 @@ pub(crate) fn collect_op_types_extensions(
node: Option<Node>,
op: &OpType,
) -> Result<ExtensionRegistry, ExtensionCollectionError> {
let mut used = ExtensionRegistry::default();
let mut used = WeakExtensionRegistry::default();
let mut missing = ExtensionSet::new();

match op {
Expand Down Expand Up @@ -101,12 +101,15 @@ pub(crate) fn collect_op_types_extensions(
OpType::Module(_) | OpType::AliasDecl(_) | OpType::AliasDefn(_) => {}
};

missing
.is_empty()
.then_some(used)
.ok_or(ExtensionCollectionError::dropped_op_extension(
match missing.is_empty() {
true => {
// We know there are no missing extensions, so this should not fail.
Ok(used.try_into().expect("All extensions are valid"))
}
false => Err(ExtensionCollectionError::dropped_op_extension(
node, op, missing,
))
)),
}
}

/// Collect the Extension pointers in the [`CustomType`]s inside a signature.
Expand All @@ -119,7 +122,7 @@ pub(crate) fn collect_op_types_extensions(
/// `Weak<Extension>` pointer has been invalidated.
pub(crate) fn collect_signature_exts<RV: MaybeRV>(
signature: &FuncTypeBase<RV>,
used_extensions: &mut ExtensionRegistry,
used_extensions: &mut WeakExtensionRegistry,
missing_extensions: &mut ExtensionSet,
) {
// Note that we do not include the signature's `runtime_reqs` here, as those refer
Expand All @@ -138,7 +141,7 @@ pub(crate) fn collect_signature_exts<RV: MaybeRV>(
/// `Weak<Extension>` pointer has been invalidated.
fn collect_type_row_exts<RV: MaybeRV>(
row: &TypeRowBase<RV>,
used_extensions: &mut ExtensionRegistry,
used_extensions: &mut WeakExtensionRegistry,
missing_extensions: &mut ExtensionSet,
) {
for ty in row.iter() {
Expand All @@ -156,17 +159,19 @@ fn collect_type_row_exts<RV: MaybeRV>(
/// `Weak<Extension>` pointer has been invalidated.
pub(super) fn collect_type_exts<RV: MaybeRV>(
typ: &TypeBase<RV>,
used_extensions: &mut ExtensionRegistry,
used_extensions: &mut WeakExtensionRegistry,
missing_extensions: &mut ExtensionSet,
) {
match typ.as_type_enum() {
TypeEnum::Extension(custom) => {
for arg in custom.args() {
collect_typearg_exts(arg, used_extensions, missing_extensions);
}
match custom.extension_ref().upgrade() {
let ext_ref = custom.extension_ref();
// Check if the extension reference is still valid.
match ext_ref.upgrade() {
Some(ext) => {
used_extensions.register_updated(ext);
used_extensions.register(ext.name().clone(), ext_ref);
}
None => {
missing_extensions.insert(custom.extension().clone());
Expand Down Expand Up @@ -200,7 +205,7 @@ pub(super) fn collect_type_exts<RV: MaybeRV>(
/// `Weak<Extension>` pointer has been invalidated.
fn collect_typearg_exts(
arg: &TypeArg,
used_extensions: &mut ExtensionRegistry,
used_extensions: &mut WeakExtensionRegistry,
missing_extensions: &mut ExtensionSet,
) {
match arg {
Expand All @@ -226,7 +231,7 @@ fn collect_typearg_exts(
/// `Weak<Extension>` pointer has been invalidated.
fn collect_value_exts(
value: &Value,
used_extensions: &mut ExtensionRegistry,
used_extensions: &mut WeakExtensionRegistry,
missing_extensions: &mut ExtensionSet,
) {
match value {
Expand Down
49 changes: 27 additions & 22 deletions hugr-core/src/extension/resolution/types_mut.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
//!
//! For a non-mutating option see [`super::collect_op_types_extensions`].
use std::sync::Arc;
use std::sync::Weak;

use super::types::collect_type_exts;
use super::{ExtensionRegistry, ExtensionResolutionError};
use super::{ExtensionResolutionError, WeakExtensionRegistry};
use crate::extension::ExtensionSet;
use crate::ops::{OpType, Value};
use crate::types::type_row::TypeRowBase;
use crate::types::{CustomType, MaybeRV, Signature, SumType, TypeArg, TypeBase, TypeEnum};
use crate::Node;
use crate::{Extension, Node};

/// Replace the dangling extension pointer in the [`CustomType`]s inside an
/// optype with a valid pointer to the extension in the `extensions`
Expand All @@ -20,13 +20,16 @@ use crate::Node;
/// When a pointer is replaced, the extension is added to the
/// `used_extensions` registry.
///
/// Returns
///
/// This is a helper function used right after deserializing a Hugr.
pub fn resolve_op_types_extensions(
node: Option<Node>,
op: &mut OpType,
extensions: &ExtensionRegistry,
used_extensions: &mut ExtensionRegistry,
) -> Result<(), ExtensionResolutionError> {
extensions: &WeakExtensionRegistry,
) -> Result<impl Iterator<Item = Weak<Extension>>, ExtensionResolutionError> {
let mut used = WeakExtensionRegistry::default();
let used_extensions = &mut used;
match op {
OpType::ExtensionOp(ext) => {
for arg in ext.args_mut() {
Expand Down Expand Up @@ -106,7 +109,7 @@ pub fn resolve_op_types_extensions(
// Ignore optypes that do not store a signature.
OpType::Module(_) | OpType::AliasDecl(_) | OpType::AliasDefn(_) => {}
}
Ok(())
Ok(used.into_iter())
}

/// Update all weak Extension pointers in the [`CustomType`]s inside a signature.
Expand All @@ -115,8 +118,8 @@ pub fn resolve_op_types_extensions(
fn resolve_signature_exts(
node: Option<Node>,
signature: &mut Signature,
extensions: &ExtensionRegistry,
used_extensions: &mut ExtensionRegistry,
extensions: &WeakExtensionRegistry,
used_extensions: &mut WeakExtensionRegistry,
) -> Result<(), ExtensionResolutionError> {
// Note that we do not include the signature's `runtime_reqs` here, as those refer
// to _runtime_ requirements that may not be currently present.
Expand All @@ -131,8 +134,8 @@ fn resolve_signature_exts(
fn resolve_type_row_exts<RV: MaybeRV>(
node: Option<Node>,
row: &mut TypeRowBase<RV>,
extensions: &ExtensionRegistry,
used_extensions: &mut ExtensionRegistry,
extensions: &WeakExtensionRegistry,
used_extensions: &mut WeakExtensionRegistry,
) -> Result<(), ExtensionResolutionError> {
for ty in row.iter_mut() {
resolve_type_exts(node, ty, extensions, used_extensions)?;
Expand All @@ -146,8 +149,8 @@ fn resolve_type_row_exts<RV: MaybeRV>(
pub(super) fn resolve_type_exts<RV: MaybeRV>(
node: Option<Node>,
typ: &mut TypeBase<RV>,
extensions: &ExtensionRegistry,
used_extensions: &mut ExtensionRegistry,
extensions: &WeakExtensionRegistry,
used_extensions: &mut WeakExtensionRegistry,
) -> Result<(), ExtensionResolutionError> {
match typ.as_type_enum_mut() {
TypeEnum::Extension(custom) => {
Expand Down Expand Up @@ -177,8 +180,8 @@ pub(super) fn resolve_type_exts<RV: MaybeRV>(
pub(super) fn resolve_custom_type_exts(
node: Option<Node>,
custom: &mut CustomType,
extensions: &ExtensionRegistry,
used_extensions: &mut ExtensionRegistry,
extensions: &WeakExtensionRegistry,
used_extensions: &mut WeakExtensionRegistry,
) -> Result<(), ExtensionResolutionError> {
for arg in custom.args_mut() {
resolve_typearg_exts(node, arg, extensions, used_extensions)?;
Expand All @@ -191,8 +194,8 @@ pub(super) fn resolve_custom_type_exts(

// Add the extension to the used extensions registry,
// and update the CustomType with the valid pointer.
used_extensions.register_updated_ref(ext);
custom.update_extension(Arc::downgrade(ext));
used_extensions.register(ext_id.clone(), ext.clone());
custom.update_extension(ext.clone());

Ok(())
}
Expand All @@ -203,8 +206,8 @@ pub(super) fn resolve_custom_type_exts(
pub(super) fn resolve_typearg_exts(
node: Option<Node>,
arg: &mut TypeArg,
extensions: &ExtensionRegistry,
used_extensions: &mut ExtensionRegistry,
extensions: &WeakExtensionRegistry,
used_extensions: &mut WeakExtensionRegistry,
) -> Result<(), ExtensionResolutionError> {
match arg {
TypeArg::Type { ty } => resolve_type_exts(node, ty, extensions, used_extensions)?,
Expand All @@ -224,8 +227,8 @@ pub(super) fn resolve_typearg_exts(
pub(super) fn resolve_value_exts(
node: Option<Node>,
value: &mut Value,
extensions: &ExtensionRegistry,
used_extensions: &mut ExtensionRegistry,
extensions: &WeakExtensionRegistry,
used_extensions: &mut WeakExtensionRegistry,
) -> Result<(), ExtensionResolutionError> {
match value {
Value::Extension { e } => {
Expand All @@ -246,7 +249,9 @@ pub(super) fn resolve_value_exts(
Value::Function { hugr } => {
// We don't need to add the nested hugr's extensions to the main one here,
// but we run resolution on it independently.
hugr.resolve_extension_defs(extensions)?;
if let Ok(exts) = extensions.try_into() {
hugr.resolve_extension_defs(&exts)?;
}
}
Value::Sum(s) => {
if let SumType::General { rows } = &mut s.sum_type {
Expand Down
Loading

0 comments on commit 207ae65

Please sign in to comment.