diff --git a/src/algorithm/nest_cfgs.rs b/src/algorithm/nest_cfgs.rs index 15154d4f2..70d5aa26c 100644 --- a/src/algorithm/nest_cfgs.rs +++ b/src/algorithm/nest_cfgs.rs @@ -605,8 +605,8 @@ pub(crate) mod test { // \-> right -/ \-<--<-/ let mut cfg_builder = CFGBuilder::new(FunctionType::new(type_row![NAT], type_row![NAT]))?; - let pred_const = cfg_builder.add_constant(Const::unit_sum(0, 2), ExtensionSet::new())?; // Nothing here cares which - let const_unit = cfg_builder.add_constant(Const::unary_unit_sum(), ExtensionSet::new())?; + let pred_const = cfg_builder.add_constant(Const::unit_sum(0, 2))?; // Nothing here cares which + let const_unit = cfg_builder.add_constant(Const::unary_unit_sum())?; let entry = n_identity( cfg_builder.simple_entry_builder(type_row![NAT], 1, ExtensionSet::new())?, @@ -887,8 +887,8 @@ pub(crate) mod test { separate: bool, ) -> Result<(Hugr, BasicBlockID, BasicBlockID), BuildError> { let mut cfg_builder = CFGBuilder::new(FunctionType::new(type_row![NAT], type_row![NAT]))?; - let pred_const = cfg_builder.add_constant(Const::unit_sum(0, 2), ExtensionSet::new())?; // Nothing here cares which - let const_unit = cfg_builder.add_constant(Const::unary_unit_sum(), ExtensionSet::new())?; + let pred_const = cfg_builder.add_constant(Const::unit_sum(0, 2))?; // Nothing here cares which + let const_unit = cfg_builder.add_constant(Const::unary_unit_sum())?; let entry = n_identity( cfg_builder.simple_entry_builder(type_row![NAT], 2, ExtensionSet::new())?, @@ -929,8 +929,8 @@ pub(crate) mod test { cfg_builder: &mut CFGBuilder, separate_headers: bool, ) -> Result<(BasicBlockID, BasicBlockID), BuildError> { - let pred_const = cfg_builder.add_constant(Const::unit_sum(0, 2), ExtensionSet::new())?; // Nothing here cares which - let const_unit = cfg_builder.add_constant(Const::unary_unit_sum(), ExtensionSet::new())?; + let pred_const = cfg_builder.add_constant(Const::unit_sum(0, 2))?; // Nothing here cares which + let const_unit = cfg_builder.add_constant(Const::unary_unit_sum())?; let entry = n_identity( cfg_builder.simple_entry_builder(type_row![NAT], 1, ExtensionSet::new())?, diff --git a/src/builder/build_traits.rs b/src/builder/build_traits.rs index 833e456c8..641ef1ae2 100644 --- a/src/builder/build_traits.rs +++ b/src/builder/build_traits.rs @@ -70,12 +70,8 @@ pub trait Container { /// /// This function will return an error if there is an error in adding the /// [`OpType::Const`] node. - fn add_constant( - &mut self, - constant: ops::Const, - extensions: impl Into>, - ) -> Result { - let const_n = self.add_child_node(NodeType::new(constant, extensions.into()))?; + fn add_constant(&mut self, constant: ops::Const) -> Result { + let const_n = self.add_child_node(NodeType::new(constant, ExtensionSet::new()))?; Ok(const_n.into()) } @@ -356,20 +352,16 @@ pub trait Dataflow: Container { fn load_const(&mut self, cid: &ConstID) -> Result { let const_node = cid.node(); let nodetype = self.hugr().get_nodetype(const_node); - let input_extensions = nodetype.input_extensions().cloned(); let op: ops::Const = nodetype .op() .clone() .try_into() .expect("ConstID does not refer to Const op."); - let load_n = self.add_dataflow_node( - NodeType::new( - ops::LoadConstant { - datatype: op.const_type().clone(), - }, - input_extensions, - ), + let load_n = self.add_dataflow_op( + ops::LoadConstant { + datatype: op.const_type().clone(), + }, // Constant wire from the constant value node vec![Wire::new(const_node, OutgoingPort::from(0))], )?; @@ -382,12 +374,8 @@ pub trait Dataflow: Container { /// # Errors /// /// This function will return an error if there is an error when adding the node. - fn add_load_const( - &mut self, - constant: ops::Const, - extensions: ExtensionSet, - ) -> Result { - let cid = self.add_constant(constant, extensions)?; + fn add_load_const(&mut self, constant: ops::Const) -> Result { + let cid = self.add_constant(constant)?; self.load_const(&cid) } diff --git a/src/builder/cfg.rs b/src/builder/cfg.rs index 99781ea2c..9de97bc9c 100644 --- a/src/builder/cfg.rs +++ b/src/builder/cfg.rs @@ -385,7 +385,7 @@ mod test { let mut middle_b = cfg_builder .simple_block_builder(FunctionType::new(type_row![NAT], type_row![NAT]), 1)?; let middle = { - let c = middle_b.add_load_const(ops::Const::unary_unit_sum(), ExtensionSet::new())?; + let c = middle_b.add_load_const(ops::Const::unary_unit_sum())?; let [inw] = middle_b.input_wires_arr(); middle_b.finish_with_outputs(c, [inw])? }; @@ -398,8 +398,7 @@ mod test { #[test] fn test_dom_edge() -> Result<(), BuildError> { let mut cfg_builder = CFGBuilder::new(FunctionType::new(type_row![NAT], type_row![NAT]))?; - let sum_tuple_const = - cfg_builder.add_constant(ops::Const::unary_unit_sum(), ExtensionSet::new())?; + let sum_tuple_const = cfg_builder.add_constant(ops::Const::unary_unit_sum())?; let sum_variants = vec![type_row![]]; let mut entry_b = @@ -427,8 +426,7 @@ mod test { #[test] fn test_non_dom_edge() -> Result<(), BuildError> { let mut cfg_builder = CFGBuilder::new(FunctionType::new(type_row![NAT], type_row![NAT]))?; - let sum_tuple_const = - cfg_builder.add_constant(ops::Const::unary_unit_sum(), ExtensionSet::new())?; + let sum_tuple_const = cfg_builder.add_constant(ops::Const::unary_unit_sum())?; let sum_variants = vec![type_row![]]; let mut middle_b = cfg_builder .simple_block_builder(FunctionType::new(type_row![NAT], type_row![NAT]), 1)?; diff --git a/src/builder/conditional.rs b/src/builder/conditional.rs index 0238d14a4..1e3441968 100644 --- a/src/builder/conditional.rs +++ b/src/builder/conditional.rs @@ -242,7 +242,7 @@ mod test { "main", FunctionType::new(type_row![NAT], type_row![NAT]).into(), )?; - let tru_const = fbuild.add_constant(Const::true_val(), ExtensionSet::new())?; + let tru_const = fbuild.add_constant(Const::true_val())?; let _fdef = { let const_wire = fbuild.load_const(&tru_const)?; let [int] = fbuild.input_wires_arr(); diff --git a/src/builder/tail_loop.rs b/src/builder/tail_loop.rs index 9ab71182b..bbcddade7 100644 --- a/src/builder/tail_loop.rs +++ b/src/builder/tail_loop.rs @@ -109,10 +109,7 @@ mod test { let build_result: Result = { let mut loop_b = TailLoopBuilder::new(vec![], vec![BIT], vec![USIZE_T])?; let [i1] = loop_b.input_wires_arr(); - let const_wire = loop_b.add_load_const( - ConstUsize::new(1).into(), - ExtensionSet::singleton(&PRELUDE_ID), - )?; + let const_wire = loop_b.add_load_const(ConstUsize::new(1).into())?; let break_wire = loop_b.make_break(loop_b.loop_signature()?.clone(), [const_wire])?; loop_b.set_outputs(break_wire, [i1])?; @@ -148,8 +145,7 @@ mod test { fbuild.tail_loop_builder(vec![(BIT, b1)], vec![], type_row![NAT])?; let signature = loop_b.loop_signature()?.clone(); let const_val = Const::true_val(); - let const_wire = - loop_b.add_load_const(Const::true_val(), ExtensionSet::new())?; + let const_wire = loop_b.add_load_const(Const::true_val())?; let lift_node = loop_b.add_dataflow_op( ops::LeafOp::Lift { type_row: vec![const_val.const_type().clone()].into(), @@ -177,10 +173,7 @@ mod test { let mut branch_1 = conditional_b.case_builder(1)?; let [_b1] = branch_1.input_wires_arr(); - let wire = branch_1.add_load_const( - ConstUsize::new(2).into(), - ExtensionSet::singleton(&PRELUDE_ID), - )?; + let wire = branch_1.add_load_const(ConstUsize::new(2).into())?; let break_wire = branch_1.make_break(signature, [wire])?; branch_1.finish_with_outputs([break_wire])?; diff --git a/src/extension.rs b/src/extension.rs index f18834baa..95b0474ea 100644 --- a/src/extension.rs +++ b/src/extension.rs @@ -387,6 +387,16 @@ impl ExtensionSet { self } + /// Returns the union of an arbitrary collection of [ExtensionSet]s + pub fn union_over(sets: impl IntoIterator) -> Self { + // `union` clones the receiver, which we do not need to do here + let mut res = ExtensionSet::new(); + for s in sets { + res.0.extend(s.0) + } + res + } + /// The things in other which are in not in self pub fn missing_from(&self, other: &Self) -> Self { ExtensionSet::from_iter(other.0.difference(&self.0).cloned()) diff --git a/src/extension/infer.rs b/src/extension/infer.rs index 0b99789c2..84e22b65a 100644 --- a/src/extension/infer.rs +++ b/src/extension/infer.rs @@ -317,15 +317,11 @@ impl UnificationContext { match node_type.io_extensions() { // Input extensions are open None => { - let c = if let Some(sig) = node_type.op_signature() { - let delta = sig.extension_reqs; - if delta.is_empty() { - Constraint::Equal(m_input) - } else { - Constraint::Plus(delta, m_input) - } - } else { + let delta = node_type.op().extension_delta(); + let c = if delta.is_empty() { Constraint::Equal(m_input) + } else { + Constraint::Plus(delta, m_input) }; self.add_constraint(m_output, c); } @@ -703,7 +699,7 @@ impl UnificationContext { }); let (rs, other_ms): (Vec<_>, Vec<_>) = plus_constraints.unzip(); - let solution = rs.iter().fold(ExtensionSet::new(), ExtensionSet::union); + let solution = ExtensionSet::union_over(rs); let unresolved_metas = other_ms .into_iter() .filter(|other_m| m != *other_m) @@ -731,7 +727,7 @@ impl UnificationContext { Constraint::Plus(_, other_m) => solutions.get(&self.resolve(*other_m)), Constraint::Equal(_) => None, }) - .fold(ExtensionSet::new(), |a, b| a.union(b)); + .fold(ExtensionSet::new(), ExtensionSet::union); for m in cc.iter() { self.add_solution(*m, combined_solution.clone()); diff --git a/src/extension/prelude.rs b/src/extension/prelude.rs index c5f587975..f96046ba8 100644 --- a/src/extension/prelude.rs +++ b/src/extension/prelude.rs @@ -14,7 +14,7 @@ use crate::{ Extension, }; -use super::{ExtensionRegistry, SignatureError, SignatureFromArgs}; +use super::{ExtensionRegistry, ExtensionSet, SignatureError, SignatureFromArgs}; struct ArrayOpCustom; const MAX: &[TypeParam; 1] = &[TypeParam::max_nat()]; @@ -181,6 +181,10 @@ impl CustomConst for ConstUsize { fn equal_consts(&self, other: &dyn CustomConst) -> bool { crate::values::downcast_equal_consts(self, other) } + + fn extension_reqs(&self) -> ExtensionSet { + ExtensionSet::singleton(&PRELUDE_ID) + } } impl KnownTypeConst for ConstUsize { diff --git a/src/hugr.rs b/src/hugr.rs index 2971885a8..9672f3dbb 100644 --- a/src/hugr.rs +++ b/src/hugr.rs @@ -125,16 +125,9 @@ impl NodeType { /// `None`` if the [Self::input_extensions] is `None`. /// Otherwise, will return Some, with the output extensions computed from the node's delta pub fn io_extensions(&self) -> Option<(&ExtensionSet, ExtensionSet)> { - self.input_extensions.as_ref().map(|e| { - ( - e, - self.op - .dataflow_signature() - .map(|ft| ft.extension_reqs) - .unwrap_or_default() - .union(e), - ) - }) + self.input_extensions + .as_ref() + .map(|e| (e, self.op.extension_delta().union(e))) } /// Gets the underlying [OpType] i.e. without any [input_extensions] diff --git a/src/hugr/rewrite/outline_cfg.rs b/src/hugr/rewrite/outline_cfg.rs index d0640048a..c13a4183f 100644 --- a/src/hugr/rewrite/outline_cfg.rs +++ b/src/hugr/rewrite/outline_cfg.rs @@ -142,7 +142,7 @@ impl Rewrite for OutlineCfg { .unwrap(); let cfg = cfg.finish_sub_container().unwrap(); let unit_sum = new_block_bldr - .add_constant(ops::Const::unary_unit_sum(), ExtensionSet::new()) + .add_constant(ops::Const::unary_unit_sum()) .unwrap(); let pred_wire = new_block_bldr.load_const(&unit_sum).unwrap(); new_block_bldr diff --git a/src/hugr/rewrite/replace.rs b/src/hugr/rewrite/replace.rs index 135f62162..1ca02b4ac 100644 --- a/src/hugr/rewrite/replace.rs +++ b/src/hugr/rewrite/replace.rs @@ -477,14 +477,15 @@ mod test { .unwrap() .into(); let just_list = TypeRow::from(vec![listy.clone()]); - let exset = ExtensionSet::singleton(&collections::EXTENSION_NAME); let intermed = TypeRow::from(vec![listy.clone(), USIZE_T]); let mut cfg = CFGBuilder::new( - FunctionType::new_endo(just_list.clone()).with_extension_delta(&exset), + // One might expect an extension_delta of "collections" here, but push/pop + // have an empty delta themselves, pending https://github.com/CQCL/hugr/issues/388 + FunctionType::new_endo(just_list.clone()), )?; - let pred_const = cfg.add_constant(ops::Const::unary_unit_sum(), None)?; + let pred_const = cfg.add_constant(ops::Const::unary_unit_sum())?; let entry = single_node_block(&mut cfg, pop, &pred_const, true)?; let bb2 = single_node_block(&mut cfg, push, &pred_const, false)?; diff --git a/src/hugr/validate/test.rs b/src/hugr/validate/test.rs index 8b1545049..dc8e9add2 100644 --- a/src/hugr/validate/test.rs +++ b/src/hugr/validate/test.rs @@ -888,7 +888,7 @@ fn no_polymorphic_consts() -> Result<(), Box> { let empty_list = Value::Extension { c: (Box::new(collections::ListValue::new(vec![])),), }; - let cst = def.add_load_const(Const::new(empty_list, list_of_var)?, just_colns)?; + let cst = def.add_load_const(Const::new(empty_list, list_of_var)?)?; let res = def.finish_hugr_with_outputs([cst], ®); assert_matches!( res.unwrap_err(), diff --git a/src/hugr/views/tests.rs b/src/hugr/views/tests.rs index a2a3274b9..97fb50861 100644 --- a/src/hugr/views/tests.rs +++ b/src/hugr/views/tests.rs @@ -132,18 +132,22 @@ fn value_types() { #[rustversion::since(1.75)] // uses impl in return position #[test] fn static_targets() { - use crate::extension::prelude::{ConstUsize, USIZE_T}; + use crate::extension::{ + prelude::{ConstUsize, PRELUDE_ID, USIZE_T}, + ExtensionSet, + }; use itertools::Itertools; + let mut dfg = DFGBuilder::new( + FunctionType::new(type_row![], type_row![USIZE_T]) + .with_extension_delta(&ExtensionSet::singleton(&PRELUDE_ID)), + ) + .unwrap(); - let mut dfg = DFGBuilder::new(FunctionType::new(type_row![], type_row![USIZE_T])).unwrap(); - - let c = dfg.add_constant(ConstUsize::new(1).into(), None).unwrap(); + let c = dfg.add_constant(ConstUsize::new(1).into()).unwrap(); let load = dfg.load_const(&c).unwrap(); - let h = dfg - .finish_hugr_with_outputs([load], &crate::extension::PRELUDE_REGISTRY) - .unwrap(); + let h = dfg.finish_prelude_hugr_with_outputs([load]).unwrap(); assert_eq!(h.static_source(load.node()), Some(c.node())); diff --git a/src/ops.rs b/src/ops.rs index 7deb6328d..9c44f7f00 100644 --- a/src/ops.rs +++ b/src/ops.rs @@ -9,6 +9,7 @@ pub mod leaf; pub mod module; pub mod tag; pub mod validate; +use crate::extension::ExtensionSet; use crate::types::{EdgeKind, FunctionType, Type}; use crate::{Direction, OutgoingPort, Port}; use crate::{IncomingPort, PortIndex}; @@ -278,6 +279,13 @@ pub trait OpTrait { fn dataflow_signature(&self) -> Option { None } + + /// The delta between the input extensions specified for a node, + /// and the output extensions calculated for that node + fn extension_delta(&self) -> ExtensionSet { + ExtensionSet::new() + } + /// The edge kind for the non-dataflow or constant inputs of the operation, /// not described by the signature. /// diff --git a/src/ops/constant.rs b/src/ops/constant.rs index 5f87c96d2..6db0006b3 100644 --- a/src/ops/constant.rs +++ b/src/ops/constant.rs @@ -1,6 +1,7 @@ //! Constant value definitions. use crate::{ + extension::ExtensionSet, types::{ConstTypeError, EdgeKind, Type, TypeRow}, values::{CustomConst, KnownTypeConst, Value}, }; @@ -96,6 +97,10 @@ impl OpTrait for Const { self.value.description() } + fn extension_delta(&self) -> ExtensionSet { + self.value.extension_reqs() + } + fn tag(&self) -> OpTag { ::TAG } @@ -156,22 +161,19 @@ mod test { type_row![], TypeRow::from(vec![pred_ty.clone()]), ))?; - let c = b.add_constant( - Const::tuple_sum( - 0, - Value::tuple([CustomTestValue(TypeBound::Eq).into(), serialized_float(5.1)]), - pred_rows.clone(), - )?, - ExtensionSet::new(), - )?; + let c = b.add_constant(Const::tuple_sum( + 0, + Value::tuple([ + CustomTestValue(TypeBound::Eq, ExtensionSet::new()).into(), + serialized_float(5.1), + ]), + pred_rows.clone(), + )?)?; let w = b.load_const(&c)?; b.finish_hugr_with_outputs([w], &test_registry()).unwrap(); let mut b = DFGBuilder::new(FunctionType::new(type_row![], TypeRow::from(vec![pred_ty])))?; - let c = b.add_constant( - Const::tuple_sum(1, Value::unit(), pred_rows)?, - ExtensionSet::new(), - )?; + let c = b.add_constant(Const::tuple_sum(1, Value::unit(), pred_rows)?)?; let w = b.load_const(&c)?; b.finish_hugr_with_outputs([w], &test_registry()).unwrap(); @@ -233,7 +235,12 @@ mod test { ex_id.clone(), TypeBound::Eq, ); - let val: Value = CustomSerialized::new(typ_int.clone(), YamlValue::Number(6.into())).into(); + let val: Value = CustomSerialized::new( + typ_int.clone(), + YamlValue::Number(6.into()), + ExtensionSet::singleton(&ex_id), + ) + .into(); let classic_t = Type::new_extension(typ_int.clone()); assert_matches!(classic_t.least_upper_bound(), TypeBound::Eq); classic_t.check_type(&val).unwrap(); diff --git a/src/ops/controlflow.rs b/src/ops/controlflow.rs index d2121cd36..afd0ef5f8 100644 --- a/src/ops/controlflow.rs +++ b/src/ops/controlflow.rs @@ -228,6 +228,10 @@ impl OpTrait for Case { "A case node inside a conditional" } + fn extension_delta(&self) -> ExtensionSet { + self.signature.extension_reqs.clone() + } + fn tag(&self) -> OpTag { ::TAG } diff --git a/src/ops/dataflow.rs b/src/ops/dataflow.rs index bf4fd83a9..d830fd0d2 100644 --- a/src/ops/dataflow.rs +++ b/src/ops/dataflow.rs @@ -115,6 +115,9 @@ impl OpTrait for T { fn dataflow_signature(&self) -> Option { Some(DataflowOpTrait::signature(self)) } + fn extension_delta(&self) -> ExtensionSet { + DataflowOpTrait::signature(self).extension_reqs.clone() + } fn other_input(&self) -> Option { DataflowOpTrait::other_input(self) } diff --git a/src/std_extensions/arithmetic/float_types.rs b/src/std_extensions/arithmetic/float_types.rs index 32b7815ef..582c0eee7 100644 --- a/src/std_extensions/arithmetic/float_types.rs +++ b/src/std_extensions/arithmetic/float_types.rs @@ -3,7 +3,7 @@ use smol_str::SmolStr; use crate::{ - extension::ExtensionId, + extension::{ExtensionId, ExtensionSet}, types::{CustomCheckFailure, CustomType, Type, TypeBound}, values::{CustomConst, KnownTypeConst}, Extension, @@ -66,6 +66,10 @@ impl CustomConst for ConstF64 { fn equal_consts(&self, other: &dyn CustomConst) -> bool { crate::values::downcast_equal_consts(self, other) } + + fn extension_reqs(&self) -> ExtensionSet { + ExtensionSet::singleton(&EXTENSION_ID) + } } /// Extension for basic floating-point types. diff --git a/src/std_extensions/arithmetic/int_types.rs b/src/std_extensions/arithmetic/int_types.rs index f45d93964..ac79215f0 100644 --- a/src/std_extensions/arithmetic/int_types.rs +++ b/src/std_extensions/arithmetic/int_types.rs @@ -5,7 +5,7 @@ use std::num::NonZeroU64; use smol_str::SmolStr; use crate::{ - extension::ExtensionId, + extension::{ExtensionId, ExtensionSet}, types::{ type_param::{TypeArg, TypeArgError, TypeParam}, ConstTypeError, CustomCheckFailure, CustomType, Type, TypeBound, @@ -161,6 +161,10 @@ impl CustomConst for ConstIntU { fn equal_consts(&self, other: &dyn CustomConst) -> bool { crate::values::downcast_equal_consts(self, other) } + + fn extension_reqs(&self) -> ExtensionSet { + ExtensionSet::singleton(&EXTENSION_ID) + } } #[typetag::serde] @@ -180,6 +184,10 @@ impl CustomConst for ConstIntS { fn equal_consts(&self, other: &dyn CustomConst) -> bool { crate::values::downcast_equal_consts(self, other) } + + fn extension_reqs(&self) -> ExtensionSet { + ExtensionSet::singleton(&EXTENSION_ID) + } } /// Extension for basic integer types. diff --git a/src/std_extensions/collections.rs b/src/std_extensions/collections.rs index 8b6e227ea..a78f4793a 100644 --- a/src/std_extensions/collections.rs +++ b/src/std_extensions/collections.rs @@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize}; use smol_str::SmolStr; use crate::{ - extension::{ExtensionId, TypeDef, TypeDefBound}, + extension::{ExtensionId, ExtensionSet, TypeDef, TypeDefBound}, types::{ type_param::{TypeArg, TypeParam}, CustomCheckFailure, CustomType, FunctionType, PolyFuncType, Type, TypeBound, @@ -66,6 +66,11 @@ impl CustomConst for ListValue { fn equal_consts(&self, other: &dyn CustomConst) -> bool { crate::values::downcast_equal_consts(self, other) } + + fn extension_reqs(&self) -> ExtensionSet { + ExtensionSet::union_over(self.0.iter().map(Value::extension_reqs)) + .union(&ExtensionSet::singleton(&EXTENSION_NAME)) + } } const TP: TypeParam = TypeParam::Type { b: TypeBound::Any }; diff --git a/src/values.rs b/src/values.rs index 428654066..17d173a00 100644 --- a/src/values.rs +++ b/src/values.rs @@ -8,6 +8,7 @@ use std::any::Any; use downcast_rs::{impl_downcast, Downcast}; use smol_str::SmolStr; +use crate::extension::ExtensionSet; use crate::macros::impl_box_clone; use crate::{Hugr, HugrView}; @@ -115,6 +116,16 @@ impl Value { None } } + + /// The Extensions that must be supported to handle the value at runtime + pub fn extension_reqs(&self) -> ExtensionSet { + match self { + Value::Extension { c } => c.0.extension_reqs().clone(), + Value::Function { .. } => ExtensionSet::new(), // no extensions reqd to load Hugr (only to run) + Value::Tuple { vs } => ExtensionSet::union_over(vs.iter().map(Value::extension_reqs)), + Value::Sum { value, .. } => value.extension_reqs(), + } + } } impl From for Value { @@ -134,6 +145,13 @@ pub trait CustomConst: /// An identifier for the constant. fn name(&self) -> SmolStr; + /// The extension(s) defining the custom value + /// (a set to allow, say, a [List] of [USize]) + /// + /// [List]: crate::std_extensions::collections::LIST_TYPENAME + /// [USize]: crate::extension::prelude::USIZE_T + fn extension_reqs(&self) -> ExtensionSet; + /// Check the value is a valid instance of the provided type. fn check_custom_type(&self, typ: &CustomType) -> Result<(), CustomCheckFailure>; @@ -184,12 +202,17 @@ impl_box_clone!(CustomConst, CustomConstBoxClone); pub struct CustomSerialized { typ: CustomType, value: serde_yaml::Value, + extensions: ExtensionSet, } impl CustomSerialized { /// Creates a new [`CustomSerialized`]. - pub fn new(typ: CustomType, value: serde_yaml::Value) -> Self { - Self { typ, value } + pub fn new(typ: CustomType, value: serde_yaml::Value, extensions: ExtensionSet) -> Self { + Self { + typ, + value, + extensions, + } } } @@ -213,6 +236,10 @@ impl CustomConst for CustomSerialized { fn equal_consts(&self, other: &dyn CustomConst) -> bool { Some(self) == other.downcast_ref() } + + fn extension_reqs(&self) -> ExtensionSet { + self.extensions.clone() + } } impl PartialEq for dyn CustomConst { @@ -227,7 +254,7 @@ pub(crate) mod test { use super::*; use crate::builder::test::simple_dfg_hugr; - use crate::std_extensions::arithmetic::float_types::FLOAT64_CUSTOM_TYPE; + use crate::std_extensions::arithmetic::float_types::{self, FLOAT64_CUSTOM_TYPE}; use crate::type_row; use crate::types::{FunctionType, Type, TypeBound}; @@ -235,7 +262,7 @@ pub(crate) mod test { /// A custom constant value used in testing that purports to be an instance /// of a custom type with a specific type bound. - pub(crate) struct CustomTestValue(pub TypeBound); + pub(crate) struct CustomTestValue(pub TypeBound, pub ExtensionSet); #[typetag::serde] impl CustomConst for CustomTestValue { fn name(&self) -> SmolStr { @@ -251,12 +278,17 @@ pub(crate) mod test { )) } } + + fn extension_reqs(&self) -> ExtensionSet { + self.1.clone() + } } pub(crate) fn serialized_float(f: f64) -> Value { Value::custom(CustomSerialized { typ: FLOAT64_CUSTOM_TYPE, value: serde_yaml::Value::Number(f.into()), + extensions: ExtensionSet::singleton(&float_types::EXTENSION_ID), }) }