Skip to content

Commit

Permalink
feat!: Update extension pointers in customConsts
Browse files Browse the repository at this point in the history
  • Loading branch information
aborgna-q committed Dec 12, 2024
1 parent 080eaae commit 23f8533
Show file tree
Hide file tree
Showing 6 changed files with 105 additions and 10 deletions.
8 changes: 8 additions & 0 deletions hugr-core/src/extension/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ use crate::{type_row, Extension};

use strum_macros::{EnumIter, EnumString, IntoStaticStr};

use super::resolution::{resolve_type_extensions, ExtensionResolutionError};
use super::ExtensionRegistry;

mod unwrap_builder;
Expand Down Expand Up @@ -504,6 +505,13 @@ impl CustomConst for ConstExternalSymbol {
self.typ.clone()
}

fn update_extensions(
&mut self,
extensions: &ExtensionRegistry,
) -> Result<(), ExtensionResolutionError> {
resolve_type_extensions(&mut self.typ, extensions)
}

fn validate(&self) -> Result<(), CustomCheckFailure> {
if self.symbol.is_empty() {
Err(CustomCheckFailure::Message(
Expand Down
41 changes: 39 additions & 2 deletions hugr-core/src/extension/resolution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,53 @@ mod types_mut;
pub(crate) use ops::{collect_op_extension, resolve_op_extensions};
pub(crate) use types::{collect_op_types_extensions, collect_signature_exts};
pub(crate) use types_mut::resolve_op_types_extensions;
use types_mut::{resolve_type_exts, resolve_typearg_exts, resolve_value_exts};

use derive_more::{Display, Error, From};

use super::{Extension, ExtensionId, ExtensionRegistry, ExtensionSet};
use crate::ops::constant::ValueName;
use crate::ops::custom::OpaqueOpError;
use crate::ops::{NamedOp, OpName, OpType};
use crate::types::{FuncTypeBase, MaybeRV, TypeName};
use crate::ops::{NamedOp, OpName, OpType, Value};
use crate::types::{FuncTypeBase, MaybeRV, TypeArg, TypeBase, TypeName};
use crate::Node;

/// Update all weak Extension pointers inside a type.
pub fn resolve_type_extensions<RV: MaybeRV>(
typ: &mut TypeBase<RV>,
extensions: &ExtensionRegistry,
) -> Result<(), ExtensionResolutionError> {
// This public export is used for implementing `CustomConst::update_extensions`, so we don't need the full internal API here.
// TODO: Make `node` optional in `ExtensionResolutionError`
let node: Node = portgraph::NodeIndex::new(0).into();
let mut used_extensions = ExtensionRegistry::default();
resolve_type_exts(node, typ, extensions, &mut used_extensions)
}

/// Update all weak Extension pointers inside a type argument.
pub fn resolve_typearg_extensions(
arg: &mut TypeArg,
extensions: &ExtensionRegistry,
) -> Result<(), ExtensionResolutionError> {
// This public export is used for implementing `CustomConst::update_extensions`, so we don't need the full internal API here.
// TODO: Make `node` optional in `ExtensionResolutionError`
let node: Node = portgraph::NodeIndex::new(0).into();
let mut used_extensions = ExtensionRegistry::default();
resolve_typearg_exts(node, arg, extensions, &mut used_extensions)
}

/// Update all weak Extension pointers inside a constant value.
pub fn resolve_value_extensions(
value: &mut Value,
extensions: &ExtensionRegistry,
) -> Result<(), ExtensionResolutionError> {
// This public export is used for implementing `CustomConst::update_extensions`, so we don't need the full internal API here.
// TODO: Make `node` optional in `ExtensionResolutionError`
let node: Node = portgraph::NodeIndex::new(0).into();
let mut used_extensions = ExtensionRegistry::default();
resolve_value_exts(node, value, extensions, &mut used_extensions)
}

/// Errors that can occur during extension resolution.
#[derive(Debug, Display, Clone, Error, From, PartialEq)]
#[non_exhaustive]
Expand Down
13 changes: 7 additions & 6 deletions hugr-core/src/extension/resolution/types_mut.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ fn resolve_type_row_exts<RV: MaybeRV>(
/// Update all weak Extension pointers in the [`CustomType`]s inside a type.
///
/// Adds the extensions used in the type to the `used_extensions` registry.
fn resolve_type_exts<RV: MaybeRV>(
pub(super) fn resolve_type_exts<RV: MaybeRV>(
node: Node,
typ: &mut TypeBase<RV>,
extensions: &ExtensionRegistry,
Expand Down Expand Up @@ -191,7 +191,7 @@ fn resolve_type_exts<RV: MaybeRV>(
/// Update all weak Extension pointers in the [`CustomType`]s inside a type arg.
///
/// Adds the extensions used in the type to the `used_extensions` registry.
fn resolve_typearg_exts(
pub(super) fn resolve_typearg_exts(
node: Node,
arg: &mut TypeArg,
extensions: &ExtensionRegistry,
Expand All @@ -212,17 +212,18 @@ fn resolve_typearg_exts(
/// Update all weak Extension pointers in the [`CustomType`]s inside a [`Value`].
///
/// Adds the extensions used in the row to the `used_extensions` registry.
fn resolve_value_exts(
pub(super) fn resolve_value_exts(
node: Node,
value: &mut Value,
extensions: &ExtensionRegistry,
used_extensions: &mut ExtensionRegistry,
) -> Result<(), ExtensionResolutionError> {
match value {
Value::Extension { e } => {
// We expect that the `CustomConst::get_type` binary calls always
// return types with valid extensions.
// So here we just collect the used extensions.
e.value_mut().update_extensions(extensions)?;

// We expect that the `CustomConst::get_type` binary calls
// return types with valid extensions after we call `update_extensions`.
let typ = e.get_type();
let mut missing = ExtensionSet::new();
collect_type_exts(&typ, used_extensions, &mut missing);
Expand Down
18 changes: 17 additions & 1 deletion hugr-core/src/ops/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,11 @@ impl OpaqueValue {
self.v.as_ref()
}

/// Returns a reference to the internal [`CustomConst`].
pub(crate) fn value_mut(&mut self) -> &mut dyn CustomConst {
self.v.as_mut()
}

delegate! {
to self.value() {
/// Returns the type of the internal [`CustomConst`].
Expand Down Expand Up @@ -572,7 +577,8 @@ mod test {
use crate::builder::inout_sig;
use crate::builder::test::simple_dfg_hugr;
use crate::extension::prelude::{bool_t, usize_custom_t};
use crate::extension::PRELUDE;
use crate::extension::resolution::{resolve_typearg_extensions, ExtensionResolutionError};
use crate::extension::{ExtensionRegistry, PRELUDE};
use crate::std_extensions::arithmetic::int_types::ConstInt;
use crate::{
builder::{BuildError, DFGBuilder, Dataflow, DataflowHugr},
Expand Down Expand Up @@ -604,6 +610,16 @@ mod test {
ExtensionSet::singleton(self.0.extension().clone())
}

fn update_extensions(
&mut self,
extensions: &ExtensionRegistry,
) -> Result<(), ExtensionResolutionError> {
for arg in self.0.args_mut() {
resolve_typearg_extensions(arg, extensions)?;
}
Ok(())
}

fn get_type(&self) -> Type {
self.0.clone().into()
}
Expand Down
22 changes: 21 additions & 1 deletion hugr-core/src/ops/constant/custom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ use std::hash::{Hash, Hasher};
use downcast_rs::{impl_downcast, Downcast};
use thiserror::Error;

use crate::extension::ExtensionSet;
use crate::extension::resolution::{resolve_type_extensions, ExtensionResolutionError};
use crate::extension::{ExtensionRegistry, ExtensionSet};
use crate::macros::impl_box_clone;
use crate::types::{CustomCheckFailure, Type};
use crate::IncomingPort;
Expand Down Expand Up @@ -84,6 +85,19 @@ pub trait CustomConst:
false
}

/// Update the extensions associated with the internal values.
///
/// This is used to ensure that any extension reference [`CustomConst::get_type`] remains
/// valid when serializing and deserializing the constant.
///
/// See the helper methods in [`crate::extension::resolution`].
fn update_extensions(
&mut self,
_extensions: &ExtensionRegistry,
) -> Result<(), ExtensionResolutionError> {
Ok(())
}

/// Report the type.
fn get_type(&self) -> Type;
}
Expand Down Expand Up @@ -300,6 +314,12 @@ impl CustomConst for CustomSerialized {
fn extension_reqs(&self) -> ExtensionSet {
self.extensions.clone()
}
fn update_extensions(
&mut self,
extensions: &ExtensionRegistry,
) -> Result<(), ExtensionResolutionError> {
resolve_type_extensions(&mut self.typ, extensions)
}
fn get_type(&self) -> Type {
self.typ.clone()
}
Expand Down
13 changes: 13 additions & 0 deletions hugr-core/src/std_extensions/collections/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ use serde::{Deserialize, Serialize};
use strum_macros::{EnumIter, EnumString, IntoStaticStr};

use crate::extension::prelude::{either_type, option_type, usize_t};
use crate::extension::resolution::{
resolve_type_extensions, resolve_value_extensions, ExtensionResolutionError,
};
use crate::extension::simple_op::{MakeOpDef, MakeRegisteredOp};
use crate::extension::{ExtensionBuildError, OpDef, SignatureFunc, PRELUDE};
use crate::ops::constant::{maybe_hash_values, TryHash, ValueName};
Expand Down Expand Up @@ -126,6 +129,16 @@ impl CustomConst for ListValue {
ExtensionSet::union_over(self.0.iter().map(Value::extension_reqs))
.union(EXTENSION_ID.into())
}

fn update_extensions(
&mut self,
extensions: &ExtensionRegistry,
) -> Result<(), ExtensionResolutionError> {
for val in &mut self.0 {
resolve_value_extensions(val, extensions)?;
}
resolve_type_extensions(&mut self.1, extensions)
}
}

/// A list operation
Expand Down

0 comments on commit 23f8533

Please sign in to comment.