diff --git a/Cargo.lock b/Cargo.lock index 64ecd0749c..36e6af0eef 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2451,7 +2451,7 @@ version = "0.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38" dependencies = [ - "getrandom 0.3.2", + "getrandom", ] [[package]] diff --git a/hugr-core/src/builder/dataflow.rs b/hugr-core/src/builder/dataflow.rs index 0f7f078ec7..cc6b1536fc 100644 --- a/hugr-core/src/builder/dataflow.rs +++ b/hugr-core/src/builder/dataflow.rs @@ -627,7 +627,7 @@ pub(crate) mod test { FunctionBuilder::new( "bad_eval", PolyFuncType::new( - [TypeParam::new_list(TypeBound::Copyable)], + [TypeParam::new_list_type(TypeBound::Copyable)], Signature::new( Type::new_function(FuncValueType::new(usize_t(), tv.clone())), vec![], diff --git a/hugr-core/src/export.rs b/hugr-core/src/export.rs index 9a392dfaaf..ff6fac8b43 100644 --- a/hugr-core/src/export.rs +++ b/hugr-core/src/export.rs @@ -1,6 +1,7 @@ //! Exporting HUGR graphs to their `hugr-model` representation. use crate::extension::ExtensionRegistry; use crate::hugr::internal::HugrInternals; +use crate::types::type_param::Term; use crate::{ Direction, Hugr, HugrView, IncomingPort, Node, NodeIndex as _, Port, extension::{ExtensionId, OpDef, SignatureFunc}, @@ -14,9 +15,7 @@ use crate::{ }, types::{ CustomType, EdgeKind, FuncTypeBase, MaybeRV, PolyFuncTypeBase, RowVariable, SumType, - TypeArg, TypeBase, TypeBound, TypeEnum, - type_param::{TypeArgVariable, TypeParam}, - type_row::TypeRowBase, + TypeBase, TypeBound, TypeEnum, type_param::TermVar, type_row::TypeRowBase, }, }; @@ -385,7 +384,7 @@ impl<'a> Context<'a> { let node = self.connected_function(node).unwrap(); let symbol = self.node_to_id[&node]; let mut args = BumpVec::new_in(self.bump); - args.extend(call.type_args.iter().map(|arg| self.export_type_arg(arg))); + args.extend(call.type_args.iter().map(|arg| self.export_term(arg, None))); let args = args.into_bump_slice(); let func = self.make_term(table::Term::Apply(symbol, args)); @@ -401,7 +400,7 @@ impl<'a> Context<'a> { let node = self.connected_function(node).unwrap(); let symbol = self.node_to_id[&node]; let mut args = BumpVec::new_in(self.bump); - args.extend(load.type_args.iter().map(|arg| self.export_type_arg(arg))); + args.extend(load.type_args.iter().map(|arg| self.export_term(arg, None))); let args = args.into_bump_slice(); let func = self.make_term(table::Term::Apply(symbol, args)); let runtime_type = self.make_term(table::Term::Wildcard); @@ -464,7 +463,7 @@ impl<'a> Context<'a> { let node = self.export_opdef(op.def()); let params = self .bump - .alloc_slice_fill_iter(op.args().iter().map(|arg| self.export_type_arg(arg))); + .alloc_slice_fill_iter(op.args().iter().map(|arg| self.export_term(arg, None))); let operation = self.make_term(table::Term::Apply(node, params)); table::Operation::Custom(operation) } @@ -473,7 +472,7 @@ impl<'a> Context<'a> { let node = self.make_named_global_ref(op.extension(), op.unqualified_id()); let params = self .bump - .alloc_slice_fill_iter(op.args().iter().map(|arg| self.export_type_arg(arg))); + .alloc_slice_fill_iter(op.args().iter().map(|arg| self.export_term(arg, None))); let operation = self.make_term(table::Term::Apply(node, params)); table::Operation::Custom(operation) } @@ -806,7 +805,7 @@ impl<'a> Context<'a> { for (i, param) in t.params().iter().enumerate() { let name = self.bump.alloc_str(&i.to_string()); - let r#type = self.export_type_param(param, Some((scope, i as _))); + let r#type = self.export_term(param, Some((scope, i as _))); let param = table::Param { name, r#type }; params.push(param); } @@ -854,40 +853,12 @@ impl<'a> Context<'a> { let args = self .bump - .alloc_slice_fill_iter(t.args().iter().map(|p| self.export_type_arg(p))); + .alloc_slice_fill_iter(t.args().iter().map(|p| self.export_term(p, None))); let term = table::Term::Apply(symbol, args); self.make_term(term) } - pub fn export_type_arg(&mut self, t: &TypeArg) -> table::TermId { - match t { - TypeArg::Type { ty } => self.export_type(ty), - TypeArg::BoundedNat { n } => self.make_term(model::Literal::Nat(*n).into()), - TypeArg::String { arg } => self.make_term(model::Literal::Str(arg.into()).into()), - TypeArg::Float { value } => self.make_term(model::Literal::Float(*value).into()), - TypeArg::Bytes { value } => self.make_term(model::Literal::Bytes(value.clone()).into()), - TypeArg::List { elems } => { - // For now we assume that the sequence is meant to be a list. - let parts = self.bump.alloc_slice_fill_iter( - elems - .iter() - .map(|elem| table::SeqPart::Item(self.export_type_arg(elem))), - ); - self.make_term(table::Term::List(parts)) - } - TypeArg::Tuple { elems } => { - let parts = self.bump.alloc_slice_fill_iter( - elems - .iter() - .map(|elem| table::SeqPart::Item(self.export_type_arg(elem))), - ); - self.make_term(table::Term::Tuple(parts)) - } - TypeArg::Variable { v } => self.export_type_arg_var(v), - } - } - - pub fn export_type_arg_var(&mut self, var: &TypeArgVariable) -> table::TermId { + pub fn export_type_arg_var(&mut self, var: &TermVar) -> table::TermId { let node = self.local_scope.expect("local variable out of scope"); self.make_term(table::Term::Var(table::VarId(node, var.index() as _))) } @@ -953,19 +924,19 @@ impl<'a> Context<'a> { self.make_term(table::Term::List(parts)) } - /// Exports a `TypeParam` to a term. + /// Exports a term. /// - /// The `var` argument is set when the type parameter being exported is the + /// The `var` argument is set when the term being exported is the /// type of a parameter to a polymorphic definition. In that case we can /// generate a `nonlinear` constraint for the type of runtime types marked as /// `TypeBound::Copyable`. - pub fn export_type_param( + pub fn export_term( &mut self, - t: &TypeParam, + t: &Term, var: Option<(table::NodeId, table::VarIndex)>, ) -> table::TermId { match t { - TypeParam::Type { b } => { + Term::RuntimeType(b) => { if let (Some((node, index)), TypeBound::Copyable) = (var, b) { let term = self.make_term(table::Term::Var(table::VarId(node, index))); let non_linear = self.make_term_apply(model::CORE_NON_LINEAR, &[term]); @@ -974,24 +945,46 @@ impl<'a> Context<'a> { self.make_term_apply(model::CORE_TYPE, &[]) } - // This ignores the bound on the natural for now. - TypeParam::BoundedNat { .. } => self.make_term_apply(model::CORE_NAT_TYPE, &[]), - TypeParam::String => self.make_term_apply(model::CORE_STR_TYPE, &[]), - TypeParam::Bytes => self.make_term_apply(model::CORE_BYTES_TYPE, &[]), - TypeParam::Float => self.make_term_apply(model::CORE_FLOAT_TYPE, &[]), - TypeParam::List { param } => { - let item_type = self.export_type_param(param, None); + Term::BoundedNatType { .. } => self.make_term_apply(model::CORE_NAT_TYPE, &[]), + Term::StringType => self.make_term_apply(model::CORE_STR_TYPE, &[]), + Term::BytesType => self.make_term_apply(model::CORE_BYTES_TYPE, &[]), + Term::FloatType => self.make_term_apply(model::CORE_FLOAT_TYPE, &[]), + Term::ListType(item_type) => { + let item_type = self.export_term(item_type, None); self.make_term_apply(model::CORE_LIST_TYPE, &[item_type]) } - TypeParam::Tuple { params } => { - let parts = self.bump.alloc_slice_fill_iter( + Term::TupleType(params) => { + let item_types = self.bump.alloc_slice_fill_iter( params .iter() - .map(|param| table::SeqPart::Item(self.export_type_param(param, None))), + .map(|param| table::SeqPart::Item(self.export_term(param, None))), ); - let types = self.make_term(table::Term::List(parts)); + let types = self.make_term(table::Term::List(item_types)); self.make_term_apply(model::CORE_TUPLE_TYPE, &[types]) } + Term::Runtime(ty) => self.export_type(ty), + Term::BoundedNat(value) => self.make_term(model::Literal::Nat(*value).into()), + Term::String(value) => self.make_term(model::Literal::Str(value.into()).into()), + Term::Float(value) => self.make_term(model::Literal::Float(*value).into()), + Term::Bytes(value) => self.make_term(model::Literal::Bytes(value.clone()).into()), + Term::List(elems) => { + let parts = self.bump.alloc_slice_fill_iter( + elems + .iter() + .map(|elem| table::SeqPart::Item(self.export_term(elem, None))), + ); + self.make_term(table::Term::List(parts)) + } + Term::Tuple(elems) => { + let parts = self.bump.alloc_slice_fill_iter( + elems + .iter() + .map(|elem| table::SeqPart::Item(self.export_term(elem, None))), + ); + self.make_term(table::Term::Tuple(parts)) + } + Term::Variable(v) => self.export_type_arg_var(v), + Term::StaticType => self.make_term_apply(model::CORE_STATIC, &[]), } } diff --git a/hugr-core/src/extension.rs b/hugr-core/src/extension.rs index bb5034e1b1..925cc6d16f 100644 --- a/hugr-core/src/extension.rs +++ b/hugr-core/src/extension.rs @@ -22,9 +22,8 @@ use crate::hugr::IdentList; use crate::ops::custom::{ExtensionOp, OpaqueOp}; use crate::ops::{OpName, OpNameRef}; use crate::types::RowVariable; -use crate::types::type_param::{TypeArg, TypeArgError, TypeParam}; -use crate::types::{CustomType, TypeBound, TypeName}; -use crate::types::{Signature, TypeNameRef}; +use crate::types::type_param::{Term, TermTypeError, TypeArg, TypeParam}; +use crate::types::{CustomType, Signature, TypeBound, TypeName, TypeNameRef}; mod const_fold; mod op_def; @@ -387,7 +386,10 @@ pub enum SignatureError { ExtensionMismatch(ExtensionId, ExtensionId), /// When the type arguments of the node did not match the params declared by the `OpDef` #[error("Type arguments of node did not match params declared by definition: {0}")] - TypeArgMismatch(#[from] TypeArgError), + TypeArgMismatch(#[from] TermTypeError), + /// A [Term] was not a valid type parameter + #[error("Term {0} is not a valid parameter type")] + InvalidTypeParam(Term), /// Invalid type arguments #[error("Invalid type arguments for operation")] InvalidTypeArgs, diff --git a/hugr-core/src/extension/declarative/types.rs b/hugr-core/src/extension/declarative/types.rs index ebbf628d68..0f80cc1e0c 100644 --- a/hugr-core/src/extension/declarative/types.rs +++ b/hugr-core/src/extension/declarative/types.rs @@ -129,6 +129,6 @@ impl TypeParamDeclaration { _extension: &Extension, _ctx: DeclarationContext<'_>, ) -> Result { - Ok(TypeParam::String) + Ok(TypeParam::StringType) } } diff --git a/hugr-core/src/extension/op_def.rs b/hugr-core/src/extension/op_def.rs index 9c30cbdd47..e7e9f78ba8 100644 --- a/hugr-core/src/extension/op_def.rs +++ b/hugr-core/src/extension/op_def.rs @@ -14,7 +14,7 @@ use super::{ use crate::Hugr; use crate::envelope::serde_with::AsStringEnvelope; use crate::ops::{OpName, OpNameRef}; -use crate::types::type_param::{TypeArg, TypeParam, check_type_args}; +use crate::types::type_param::{TypeArg, TypeParam, check_term_types}; use crate::types::{FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature}; mod serialize_signature_func; @@ -239,7 +239,7 @@ impl SignatureFunc { let static_params = func.static_params(); let (static_args, other_args) = args.split_at(min(static_params.len(), args.len())); - check_type_args(static_args, static_params)?; + check_term_types(static_args, static_params)?; temp = func.compute_signature(static_args, def)?; (&temp, other_args) } @@ -347,7 +347,7 @@ impl OpDef { let (static_args, other_args) = args.split_at(min(custom.static_params().len(), args.len())); static_args.iter().try_for_each(|ta| ta.validate(&[]))?; - check_type_args(static_args, custom.static_params())?; + check_term_types(static_args, custom.static_params())?; temp = custom.compute_signature(static_args, self)?; (&temp, other_args) } @@ -357,7 +357,7 @@ impl OpDef { } }; args.iter().try_for_each(|ta| ta.validate(var_decls))?; - check_type_args(args, pf.params())?; + check_term_types(args, pf.params())?; Ok(()) } @@ -553,7 +553,7 @@ pub(super) mod test { use crate::extension::{ExtensionRegistry, ExtensionSet, PRELUDE}; use crate::ops::OpName; use crate::std_extensions::collections::list; - use crate::types::type_param::{TypeArgError, TypeParam}; + use crate::types::type_param::{TermTypeError, TypeParam}; use crate::types::{PolyFuncTypeRV, Signature, Type, TypeArg, TypeBound, TypeRV}; use crate::{Extension, const_extension_ids}; @@ -656,7 +656,7 @@ pub(super) mod test { const OP_NAME: OpName = OpName::new_inline("Reverse"); let ext = Extension::try_new_test_arc(EXT_ID, |ext, extension_ref| { - const TP: TypeParam = TypeParam::Type { b: TypeBound::Any }; + const TP: TypeParam = TypeParam::RuntimeType(TypeBound::Any); let list_of_var = Type::new_extension(list_def.instantiate(vec![TypeArg::new_var_use(0, TP)])?); let type_scheme = PolyFuncTypeRV::new(vec![TP], Signature::new_endo(vec![list_of_var])); @@ -678,11 +678,10 @@ pub(super) mod test { reg.validate()?; let e = reg.get(&EXT_ID).unwrap(); - let list_usize = - Type::new_extension(list_def.instantiate(vec![TypeArg::Type { ty: usize_t() }])?); + let list_usize = Type::new_extension(list_def.instantiate(vec![usize_t().into()])?); let mut dfg = DFGBuilder::new(endo_sig(vec![list_usize]))?; let rev = dfg.add_dataflow_op( - e.instantiate_extension_op(&OP_NAME, vec![TypeArg::Type { ty: usize_t() }]) + e.instantiate_extension_op(&OP_NAME, vec![usize_t().into()]) .unwrap(), dfg.input_wires(), )?; @@ -703,8 +702,8 @@ pub(super) mod test { &self, arg_values: &[TypeArg], ) -> Result { - const TP: TypeParam = TypeParam::Type { b: TypeBound::Any }; - let [TypeArg::BoundedNat { n }] = arg_values else { + const TP: TypeParam = TypeParam::RuntimeType(TypeBound::Any); + let [TypeArg::BoundedNat(n)] = arg_values else { return Err(SignatureError::InvalidTypeArgs); }; let n = *n as usize; @@ -718,7 +717,7 @@ pub(super) mod test { } fn static_params(&self) -> &[TypeParam] { - const MAX_NAT: &[TypeParam] = &[TypeParam::max_nat()]; + const MAX_NAT: &[TypeParam] = &[TypeParam::max_nat_type()]; MAX_NAT } } @@ -727,7 +726,7 @@ pub(super) mod test { ext.add_op("MyOp".into(), String::new(), SigFun(), extension_ref)?; // Base case, no type variables: - let args = [TypeArg::BoundedNat { n: 3 }, usize_t().into()]; + let args = [TypeArg::BoundedNat(3), usize_t().into()]; assert_eq!( def.compute_signature(&args), Ok(Signature::new( @@ -740,7 +739,7 @@ pub(super) mod test { // Second arg may be a variable (substitutable) let tyvar = Type::new_var_use(0, TypeBound::Copyable); let tyvars: Vec = vec![tyvar.clone(); 3]; - let args = [TypeArg::BoundedNat { n: 3 }, tyvar.clone().into()]; + let args = [TypeArg::BoundedNat(3), tyvar.clone().into()]; assert_eq!( def.compute_signature(&args), Ok(Signature::new( @@ -761,7 +760,7 @@ pub(super) mod test { ); // First arg must be concrete, not a variable - let kind = TypeParam::bounded_nat(NonZeroU64::new(5).unwrap()); + let kind = TypeParam::bounded_nat_type(NonZeroU64::new(5).unwrap()); let args = [TypeArg::new_var_use(0, kind.clone()), usize_t().into()]; // We can't prevent this from getting into our compute_signature implementation: assert_eq!( @@ -798,7 +797,7 @@ pub(super) mod test { extension_ref, )?; let tv = Type::new_var_use(0, TypeBound::Copyable); - let args = [TypeArg::Type { ty: tv.clone() }]; + let args = [tv.clone().into()]; let decls = [TypeBound::Copyable.into()]; def.validate_args(&args, &decls).unwrap(); assert_eq!(def.compute_signature(&args), Ok(Signature::new_endo(tv))); @@ -807,9 +806,9 @@ pub(super) mod test { assert_eq!( def.compute_signature(&[arg.clone()]), Err(SignatureError::TypeArgMismatch( - TypeArgError::TypeMismatch { - param: TypeBound::Any.into(), - arg + TermTypeError::TypeMismatch { + type_: TypeBound::Any.into(), + term: arg, } )) ); diff --git a/hugr-core/src/extension/prelude.rs b/hugr-core/src/extension/prelude.rs index aa96b380e1..c7879d068d 100644 --- a/hugr-core/src/extension/prelude.rs +++ b/hugr-core/src/extension/prelude.rs @@ -18,8 +18,8 @@ use crate::ops::constant::{CustomCheckFailure, CustomConst, ValueName}; use crate::ops::{NamedOp, Value}; use crate::types::type_param::{TypeArg, TypeParam}; use crate::types::{ - CustomType, FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature, SumType, Type, TypeBound, - TypeName, TypeRV, TypeRow, TypeRowRV, + CustomType, FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature, SumType, Term, Type, + TypeBound, TypeName, TypeRV, TypeRow, TypeRowRV, }; use crate::utils::sorted_consts; use crate::{Extension, type_row}; @@ -101,7 +101,7 @@ lazy_static! { PANIC_OP_ID, "Panic with input error".to_string(), PolyFuncTypeRV::new( - [TypeParam::new_list(TypeBound::Any), TypeParam::new_list(TypeBound::Any)], + [TypeParam::new_list_type(TypeBound::Any), TypeParam::new_list_type(TypeBound::Any)], FuncValueType::new( vec![TypeRV::new_extension(error_type.clone()), TypeRV::new_row_var_use(0, TypeBound::Any)], vec![TypeRV::new_row_var_use(1, TypeBound::Any)], @@ -115,7 +115,7 @@ lazy_static! { EXIT_OP_ID, "Exit with input error".to_string(), PolyFuncTypeRV::new( - [TypeParam::new_list(TypeBound::Any), TypeParam::new_list(TypeBound::Any)], + [TypeParam::new_list_type(TypeBound::Any), TypeParam::new_list_type(TypeBound::Any)], FuncValueType::new( vec![TypeRV::new_extension(error_type), TypeRV::new_row_var_use(0, TypeBound::Any)], vec![TypeRV::new_row_var_use(1, TypeBound::Any)], @@ -615,7 +615,7 @@ impl MakeOpDef for TupleOpDef { let rv = TypeRV::new_row_var_use(0, TypeBound::Any); let tuple_type = TypeRV::new_tuple(vec![rv.clone()]); - let param = TypeParam::new_list(TypeBound::Any); + let param = TypeParam::new_list_type(TypeBound::Any); match self { TupleOpDef::MakeTuple => { PolyFuncTypeRV::new([param], FuncValueType::new(rv, tuple_type)) @@ -678,13 +678,13 @@ impl MakeExtensionOp for MakeTuple { if def != TupleOpDef::MakeTuple { return Err(OpLoadError::NotMember(ext_op.unqualified_id().to_string()))?; } - let [TypeArg::List { elems }] = ext_op.args() else { + let [TypeArg::List(elems)] = ext_op.args() else { return Err(SignatureError::InvalidTypeArgs)?; }; let tys: Result, _> = elems .iter() .map(|a| match a { - TypeArg::Type { ty } => Ok(ty.clone()), + TypeArg::Runtime(ty) => Ok(ty.clone()), _ => Err(SignatureError::InvalidTypeArgs), }) .collect(); @@ -692,13 +692,7 @@ impl MakeExtensionOp for MakeTuple { } fn type_args(&self) -> Vec { - vec![TypeArg::List { - elems: self - .0 - .iter() - .map(|t| TypeArg::Type { ty: t.clone() }) - .collect(), - }] + vec![Term::new_list(self.0.iter().map(|t| t.clone().into()))] } } @@ -739,27 +733,21 @@ impl MakeExtensionOp for UnpackTuple { if def != TupleOpDef::UnpackTuple { return Err(OpLoadError::NotMember(ext_op.unqualified_id().to_string()))?; } - let [TypeArg::List { elems }] = ext_op.args() else { + let [Term::List(elems)] = ext_op.args() else { return Err(SignatureError::InvalidTypeArgs)?; }; let tys: Result, _> = elems .iter() .map(|a| match a { - TypeArg::Type { ty } => Ok(ty.clone()), + Term::Runtime(ty) => Ok(ty.clone()), _ => Err(SignatureError::InvalidTypeArgs), }) .collect(); Ok(Self(tys?.into())) } - fn type_args(&self) -> Vec { - vec![TypeArg::List { - elems: self - .0 - .iter() - .map(|t| TypeArg::Type { ty: t.clone() }) - .collect(), - }] + fn type_args(&self) -> Vec { + vec![Term::new_list(self.0.iter().map(|t| t.clone().into()))] } } @@ -863,14 +851,14 @@ impl MakeExtensionOp for Noop { Self: Sized, { let _def = NoopDef::from_def(ext_op.def())?; - let [TypeArg::Type { ty }] = ext_op.args() else { + let [TypeArg::Runtime(ty)] = ext_op.args() else { return Err(SignatureError::InvalidTypeArgs)?; }; Ok(Self(ty.clone())) } fn type_args(&self) -> Vec { - vec![TypeArg::Type { ty: self.0.clone() }] + vec![self.0.clone().into()] } } @@ -910,7 +898,7 @@ impl MakeOpDef for BarrierDef { fn init_signature(&self, _extension_ref: &Weak) -> SignatureFunc { PolyFuncTypeRV::new( - vec![TypeParam::new_list(TypeBound::Any)], + vec![TypeParam::new_list_type(TypeBound::Any)], FuncValueType::new_endo(TypeRV::new_row_var_use(0, TypeBound::Any)), ) .into() @@ -969,13 +957,13 @@ impl MakeExtensionOp for Barrier { { let _def = BarrierDef::from_def(ext_op.def())?; - let [TypeArg::List { elems }] = ext_op.args() else { + let [TypeArg::List(elems)] = ext_op.args() else { return Err(SignatureError::InvalidTypeArgs)?; }; let tys: Result, _> = elems .iter() .map(|a| match a { - TypeArg::Type { ty } => Ok(ty.clone()), + TypeArg::Runtime(ty) => Ok(ty.clone()), _ => Err(SignatureError::InvalidTypeArgs), }) .collect(); @@ -985,13 +973,9 @@ impl MakeExtensionOp for Barrier { } fn type_args(&self) -> Vec { - vec![TypeArg::List { - elems: self - .type_row - .iter() - .map(|t| TypeArg::Type { ty: t.clone() }) - .collect(), - }] + vec![TypeArg::new_list( + self.type_row.iter().map(|t| t.clone().into()), + )] } } @@ -1009,6 +993,7 @@ impl MakeRegisteredOp for Barrier { mod test { use crate::builder::inout_sig; use crate::std_extensions::arithmetic::float_types::{ConstF64, float64_type}; + use crate::types::Term; use crate::{ Hugr, Wire, builder::{DFGBuilder, Dataflow, DataflowHugr, endo_sig}, @@ -1132,9 +1117,8 @@ mod test { let err = b.add_load_value(error_val); - const TYPE_ARG_NONE: TypeArg = TypeArg::List { elems: vec![] }; let op = PRELUDE - .instantiate_extension_op(&EXIT_OP_ID, [TYPE_ARG_NONE, TYPE_ARG_NONE]) + .instantiate_extension_op(&EXIT_OP_ID, [Term::new_list([]), Term::new_list([])]) .unwrap(); b.add_dataflow_op(op, [err]).unwrap(); @@ -1146,10 +1130,8 @@ mod test { /// test the panic operation with input and output wires fn test_panic_with_io() { let error_val = ConstError::new(42, "PANIC"); - let type_arg_q: TypeArg = TypeArg::Type { ty: qb_t() }; - let type_arg_2q: TypeArg = TypeArg::List { - elems: vec![type_arg_q.clone(), type_arg_q], - }; + let type_arg_q: Term = qb_t().into(); + let type_arg_2q: Term = Term::new_list([type_arg_q.clone(), type_arg_q]); let panic_op = PRELUDE .instantiate_extension_op(&PANIC_OP_ID, [type_arg_2q.clone(), type_arg_2q.clone()]) .unwrap(); diff --git a/hugr-core/src/extension/prelude/generic.rs b/hugr-core/src/extension/prelude/generic.rs index 9ea231e1bb..ca00c713fd 100644 --- a/hugr-core/src/extension/prelude/generic.rs +++ b/hugr-core/src/extension/prelude/generic.rs @@ -74,7 +74,7 @@ impl MakeOpDef for LoadNatDef { fn init_signature(&self, _extension_ref: &Weak) -> SignatureFunc { let usize_t: Type = usize_custom_t(_extension_ref).into(); - let params = vec![TypeParam::max_nat()]; + let params = vec![TypeParam::max_nat_type()]; PolyFuncTypeRV::new(params, FuncValueType::new(type_row![], vec![usize_t])).into() } @@ -166,7 +166,7 @@ mod tests { extension::prelude::{ConstUsize, usize_t}, ops::{OpType, constant}, type_row, - types::TypeArg, + types::Term, }; use super::LoadNat; @@ -175,7 +175,7 @@ mod tests { fn test_load_nat() { let mut b = DFGBuilder::new(inout_sig(type_row![], vec![usize_t()])).unwrap(); - let arg = TypeArg::BoundedNat { n: 4 }; + let arg = Term::from(4u64); let op = LoadNat::new(arg); let out = b.add_dataflow_op(op.clone(), []).unwrap(); @@ -195,7 +195,7 @@ mod tests { #[test] fn test_load_nat_fold() { - let arg = TypeArg::BoundedNat { n: 5 }; + let arg = Term::from(5u64); let op = LoadNat::new(arg); let optype: OpType = op.into(); diff --git a/hugr-core/src/extension/resolution.rs b/hugr-core/src/extension/resolution.rs index 0e7bfbbab8..52f2c5dbf5 100644 --- a/hugr-core/src/extension/resolution.rs +++ b/hugr-core/src/extension/resolution.rs @@ -26,7 +26,7 @@ pub(crate) use ops::{collect_op_extension, resolve_op_extensions}; pub(crate) use types::{collect_op_types_extensions, collect_signature_exts, collect_type_exts}; pub(crate) use types_mut::resolve_op_types_extensions; use types_mut::{ - resolve_custom_type_exts, resolve_type_exts, resolve_typearg_exts, resolve_value_exts, + resolve_custom_type_exts, resolve_term_exts, resolve_type_exts, resolve_value_exts, }; use derive_more::{Display, Error, From}; @@ -63,7 +63,7 @@ pub fn resolve_typearg_extensions( extensions: &WeakExtensionRegistry, ) -> Result<(), ExtensionResolutionError> { let mut used_extensions = WeakExtensionRegistry::default(); - resolve_typearg_exts(None, arg, extensions, &mut used_extensions) + resolve_term_exts(None, arg, extensions, &mut used_extensions) } /// Update all weak Extension pointers inside a constant value. diff --git a/hugr-core/src/extension/resolution/test.rs b/hugr-core/src/extension/resolution/test.rs index e73dd54fbd..4ac8502ac8 100644 --- a/hugr-core/src/extension/resolution/test.rs +++ b/hugr-core/src/extension/resolution/test.rs @@ -25,7 +25,7 @@ use crate::std_extensions::arithmetic::int_types::{self, int_type}; use crate::std_extensions::collections::list::ListValue; use crate::std_extensions::std_reg; use crate::types::type_param::TypeParam; -use crate::types::{PolyFuncType, Signature, Type, TypeArg, TypeBound}; +use crate::types::{PolyFuncType, Signature, Type, TypeBound}; use crate::{Extension, Hugr, HugrView, type_row}; #[rstest] @@ -333,12 +333,12 @@ fn resolve_custom_const(#[case] custom_const: impl CustomConst) { #[rstest] fn resolve_call() { let dummy_fn_sig = PolyFuncType::new( - vec![TypeParam::Type { b: TypeBound::Any }], + vec![TypeParam::RuntimeType(TypeBound::Any)], Signature::new(vec![], vec![bool_t()]), ); - let generic_type_1 = TypeArg::Type { ty: float64_type() }; - let generic_type_2 = TypeArg::Type { ty: int_type(6) }; + let generic_type_1 = float64_type().into(); + let generic_type_2 = int_type(6).into(); let expected_exts = [ float_types::EXTENSION_ID.clone(), int_types::EXTENSION_ID.clone(), diff --git a/hugr-core/src/extension/resolution/types.rs b/hugr-core/src/extension/resolution/types.rs index bd16941099..6f5799790a 100644 --- a/hugr-core/src/extension/resolution/types.rs +++ b/hugr-core/src/extension/resolution/types.rs @@ -11,7 +11,7 @@ use crate::Node; use crate::extension::{ExtensionRegistry, ExtensionSet}; use crate::ops::{DataflowOpTrait, OpType, Value}; use crate::types::type_row::TypeRowBase; -use crate::types::{FuncTypeBase, MaybeRV, SumType, TypeArg, TypeBase, TypeEnum}; +use crate::types::{FuncTypeBase, MaybeRV, SumType, Term, TypeBase, TypeEnum}; /// Collects every extension used to define the types in an operation. /// @@ -38,7 +38,7 @@ pub(crate) fn collect_op_types_extensions( match op { OpType::ExtensionOp(ext) => { for arg in ext.args() { - collect_typearg_exts(arg, &mut used, &mut missing); + collect_term_exts(arg, &mut used, &mut missing); } collect_signature_exts(&ext.signature(), &mut used, &mut missing); } @@ -55,7 +55,7 @@ pub(crate) fn collect_op_types_extensions( collect_signature_exts(c.func_sig.body(), &mut used, &mut missing); collect_signature_exts(&c.instantiation, &mut used, &mut missing); for arg in &c.type_args { - collect_typearg_exts(arg, &mut used, &mut missing); + collect_term_exts(arg, &mut used, &mut missing); } } OpType::CallIndirect(c) => collect_signature_exts(&c.signature, &mut used, &mut missing), @@ -64,13 +64,13 @@ pub(crate) fn collect_op_types_extensions( collect_signature_exts(lf.func_sig.body(), &mut used, &mut missing); collect_signature_exts(&lf.instantiation, &mut used, &mut missing); for arg in &lf.type_args { - collect_typearg_exts(arg, &mut used, &mut missing); + collect_term_exts(arg, &mut used, &mut missing); } } OpType::DFG(dfg) => collect_signature_exts(&dfg.signature, &mut used, &mut missing), OpType::OpaqueOp(op) => { for arg in op.args() { - collect_typearg_exts(arg, &mut used, &mut missing); + collect_term_exts(arg, &mut used, &mut missing); } collect_signature_exts(&op.signature(), &mut used, &mut missing); } @@ -172,7 +172,7 @@ pub(crate) fn collect_type_exts( match typ.as_type_enum() { TypeEnum::Extension(custom) => { for arg in custom.args() { - collect_typearg_exts(arg, used_extensions, missing_extensions); + collect_term_exts(arg, used_extensions, missing_extensions); } let ext_ref = custom.extension_ref(); // Check if the extension reference is still valid. @@ -202,34 +202,50 @@ pub(crate) fn collect_type_exts( } } -/// Collect the Extension pointers in the [`CustomType`]s inside a type argument. +/// Collect the Extension pointers in the [`CustomType`]s inside a [`Term`]. /// /// # Attributes /// -/// - `arg`: The type argument to collect the extensions from. +/// - `term`: The term argument to collect the extensions from. /// - `used_extensions`: A The registry where to store the used extensions. /// - `missing_extensions`: A set of `ExtensionId`s of which the /// `Weak` pointer has been invalidated. -pub(super) fn collect_typearg_exts( - arg: &TypeArg, +pub(super) fn collect_term_exts( + term: &Term, used_extensions: &mut WeakExtensionRegistry, missing_extensions: &mut ExtensionSet, ) { - match arg { - TypeArg::Type { ty } => collect_type_exts(ty, used_extensions, missing_extensions), - TypeArg::List { elems } => { + match term { + Term::Runtime(ty) => collect_type_exts(ty, used_extensions, missing_extensions), + Term::List(elems) => { for elem in elems.iter() { - collect_typearg_exts(elem, used_extensions, missing_extensions); + collect_term_exts(elem, used_extensions, missing_extensions); } } - TypeArg::Tuple { elems } => { + Term::Tuple(elems) => { for elem in elems.iter() { - collect_typearg_exts(elem, used_extensions, missing_extensions); + collect_term_exts(elem, used_extensions, missing_extensions); } } - // We ignore the `TypeArg::Extension` case, as it is not required to - // **define** the hugr. - _ => {} + Term::ListType(item_type) => { + collect_term_exts(item_type, used_extensions, missing_extensions) + } + Term::TupleType(item_types) => { + for item_type in item_types { + collect_term_exts(item_type, used_extensions, missing_extensions); + } + } + Term::Variable(_) + | Term::RuntimeType(_) + | Term::StaticType + | Term::BoundedNatType(_) + | Term::StringType + | Term::BytesType + | Term::FloatType + | Term::BoundedNat(_) + | Term::String(_) + | Term::Bytes(_) + | Term::Float(_) => {} } } diff --git a/hugr-core/src/extension/resolution/types_mut.rs b/hugr-core/src/extension/resolution/types_mut.rs index 2840665f5e..e54a21e5ac 100644 --- a/hugr-core/src/extension/resolution/types_mut.rs +++ b/hugr-core/src/extension/resolution/types_mut.rs @@ -10,7 +10,7 @@ use super::{ExtensionResolutionError, WeakExtensionRegistry}; use crate::extension::ExtensionSet; use crate::ops::{OpType, Value}; use crate::types::type_row::TypeRowBase; -use crate::types::{CustomType, FuncTypeBase, MaybeRV, SumType, TypeArg, TypeBase, TypeEnum}; +use crate::types::{CustomType, FuncTypeBase, MaybeRV, SumType, Term, TypeBase, TypeEnum}; use crate::{Extension, Node}; /// Replace the dangling extension pointer in the [`CustomType`]s inside an @@ -30,7 +30,7 @@ pub fn resolve_op_types_extensions( match op { OpType::ExtensionOp(ext) => { for arg in ext.args_mut() { - resolve_typearg_exts(node, arg, extensions, used_extensions)?; + resolve_term_exts(node, arg, extensions, used_extensions)?; } resolve_signature_exts(node, ext.signature_mut(), extensions, used_extensions)?; } @@ -61,7 +61,7 @@ pub fn resolve_op_types_extensions( resolve_signature_exts(node, c.func_sig.body_mut(), extensions, used_extensions)?; resolve_signature_exts(node, &mut c.instantiation, extensions, used_extensions)?; for arg in &mut c.type_args { - resolve_typearg_exts(node, arg, extensions, used_extensions)?; + resolve_term_exts(node, arg, extensions, used_extensions)?; } } OpType::CallIndirect(c) => { @@ -74,7 +74,7 @@ pub fn resolve_op_types_extensions( resolve_signature_exts(node, lf.func_sig.body_mut(), extensions, used_extensions)?; resolve_signature_exts(node, &mut lf.instantiation, extensions, used_extensions)?; for arg in &mut lf.type_args { - resolve_typearg_exts(node, arg, extensions, used_extensions)?; + resolve_term_exts(node, arg, extensions, used_extensions)?; } } OpType::DFG(dfg) => { @@ -82,7 +82,7 @@ pub fn resolve_op_types_extensions( } OpType::OpaqueOp(op) => { for arg in op.args_mut() { - resolve_typearg_exts(node, arg, extensions, used_extensions)?; + resolve_term_exts(node, arg, extensions, used_extensions)?; } resolve_signature_exts(node, op.signature_mut(), extensions, used_extensions)?; } @@ -195,7 +195,7 @@ pub(super) fn resolve_custom_type_exts( used_extensions: &mut WeakExtensionRegistry, ) -> Result<(), ExtensionResolutionError> { for arg in custom.args_mut() { - resolve_typearg_exts(node, arg, extensions, used_extensions)?; + resolve_term_exts(node, arg, extensions, used_extensions)?; } let ext_id = custom.extension(); @@ -211,28 +211,46 @@ pub(super) fn resolve_custom_type_exts( Ok(()) } -/// Update all weak Extension pointers in the [`CustomType`]s inside a type arg. +/// Update all weak Extension pointers in the [`CustomType`]s inside a [`Term`]. /// /// Adds the extensions used in the type to the `used_extensions` registry. -pub(super) fn resolve_typearg_exts( +pub(super) fn resolve_term_exts( node: Option, - arg: &mut TypeArg, + term: &mut Term, extensions: &WeakExtensionRegistry, used_extensions: &mut WeakExtensionRegistry, ) -> Result<(), ExtensionResolutionError> { - match arg { - TypeArg::Type { ty } => resolve_type_exts(node, ty, extensions, used_extensions)?, - TypeArg::List { elems } => { + match term { + Term::Runtime(ty) => resolve_type_exts(node, ty, extensions, used_extensions)?, + Term::List(elems) => { for elem in elems.iter_mut() { - resolve_typearg_exts(node, elem, extensions, used_extensions)?; + resolve_term_exts(node, elem, extensions, used_extensions)?; } } - TypeArg::Tuple { elems } => { + Term::Tuple(elems) => { for elem in elems.iter_mut() { - resolve_typearg_exts(node, elem, extensions, used_extensions)?; + resolve_term_exts(node, elem, extensions, used_extensions)?; } } - _ => {} + Term::ListType(item_type) => { + resolve_term_exts(node, item_type, extensions, used_extensions)?; + } + Term::TupleType(item_types) => { + for item_type in item_types.iter_mut() { + resolve_term_exts(node, item_type, extensions, used_extensions)?; + } + } + Term::Variable(_) + | Term::RuntimeType(_) + | Term::StaticType + | Term::BoundedNatType(_) + | Term::StringType + | Term::BytesType + | Term::FloatType + | Term::BoundedNat(_) + | Term::String(_) + | Term::Bytes(_) + | Term::Float(_) => {} } Ok(()) } diff --git a/hugr-core/src/extension/simple_op.rs b/hugr-core/src/extension/simple_op.rs index bf013ba5dc..8685b63325 100644 --- a/hugr-core/src/extension/simple_op.rs +++ b/hugr-core/src/extension/simple_op.rs @@ -308,7 +308,10 @@ impl From for OpType { mod test { use std::sync::Arc; - use crate::{const_extension_ids, type_row, types::Signature}; + use crate::{ + const_extension_ids, type_row, + types::{Signature, Term}, + }; use super::*; use lazy_static::lazy_static; @@ -393,7 +396,7 @@ mod test { assert_eq!(o.instantiate(&[]), Ok(o.clone())); assert_eq!( - o.instantiate(&[TypeArg::BoundedNat { n: 1 }]), + o.instantiate(&[Term::from(1u64)]), Err(OpLoadError::InvalidArgs(SignatureError::InvalidTypeArgs)) ); } diff --git a/hugr-core/src/extension/type_def.rs b/hugr-core/src/extension/type_def.rs index fceb336b2f..42738badb8 100644 --- a/hugr-core/src/extension/type_def.rs +++ b/hugr-core/src/extension/type_def.rs @@ -6,7 +6,7 @@ use super::{Extension, ExtensionId, SignatureError}; use crate::types::{CustomType, TypeName, least_upper_bound}; -use crate::types::type_param::{TypeArg, check_type_args}; +use crate::types::type_param::{TypeArg, check_term_types}; use crate::types::type_param::TypeParam; @@ -79,7 +79,7 @@ pub struct TypeDef { impl TypeDef { /// Check provided type arguments are valid against parameters. pub fn check_args(&self, args: &[TypeArg]) -> Result<(), SignatureError> { - check_type_args(args, &self.params).map_err(SignatureError::TypeArgMismatch) + check_term_types(args, &self.params).map_err(SignatureError::TypeArgMismatch) } /// Check [`CustomType`] is a valid instantiation of this definition. @@ -102,7 +102,7 @@ impl TypeDef { )); } - check_type_args(custom.type_args(), &self.params)?; + check_term_types(custom.type_args(), &self.params)?; let calc_bound = self.bound(custom.args()); if calc_bound == custom.bound() { @@ -123,7 +123,7 @@ impl TypeDef { /// valid instances of the type parameters. pub fn instantiate(&self, args: impl Into>) -> Result { let args = args.into(); - check_type_args(&args, &self.params)?; + check_term_types(&args, &self.params)?; let bound = self.bound(&args); Ok(CustomType::new( self.name().clone(), @@ -147,7 +147,7 @@ impl TypeDef { least_upper_bound(indices.iter().map(|i| { let ta = args.get(*i); match ta { - Some(TypeArg::Type { ty: s }) => s.least_upper_bound(), + Some(TypeArg::Runtime(s)) => s.least_upper_bound(), _ => panic!("TypeArg index does not refer to a type."), } })) @@ -241,7 +241,7 @@ mod test { use crate::extension::SignatureError; use crate::extension::prelude::{qb_t, usize_t}; use crate::std_extensions::arithmetic::float_types::float64_type; - use crate::types::type_param::{TypeArg, TypeArgError, TypeParam}; + use crate::types::type_param::{TermTypeError, TypeParam}; use crate::types::{Signature, Type, TypeBound}; use super::{TypeDef, TypeDefBound}; @@ -250,9 +250,7 @@ mod test { fn test_instantiate_typedef() { let def = TypeDef { name: "MyType".into(), - params: vec![TypeParam::Type { - b: TypeBound::Copyable, - }], + params: vec![TypeParam::RuntimeType(TypeBound::Copyable)], extension: "MyRsrc".try_into().unwrap(), // Dummy extension. Will return `None` when trying to upgrade it into an `Arc`. extension_ref: Default::default(), @@ -260,9 +258,9 @@ mod test { bound: TypeDefBound::FromParams { indices: vec![0] }, }; let typ = Type::new_extension( - def.instantiate(vec![TypeArg::Type { - ty: Type::new_function(Signature::new(vec![], vec![])), - }]) + def.instantiate(vec![ + Type::new_function(Signature::new(vec![], vec![])).into(), + ]) .unwrap(), ); assert_eq!(typ.least_upper_bound(), TypeBound::Copyable); @@ -271,27 +269,24 @@ mod test { // And some bad arguments...firstly, wrong kind of TypeArg: assert_eq!( - def.instantiate([TypeArg::Type { ty: qb_t() }]), + def.instantiate([qb_t().into()]), Err(SignatureError::TypeArgMismatch( - TypeArgError::TypeMismatch { - arg: TypeArg::Type { ty: qb_t() }, - param: TypeBound::Copyable.into() + TermTypeError::TypeMismatch { + term: qb_t().into(), + type_: TypeBound::Copyable.into() } )) ); // Too few arguments: assert_eq!( def.instantiate([]).unwrap_err(), - SignatureError::TypeArgMismatch(TypeArgError::WrongNumberArgs(0, 1)) + SignatureError::TypeArgMismatch(TermTypeError::WrongNumberArgs(0, 1)) ); // Too many arguments: assert_eq!( - def.instantiate([ - TypeArg::Type { ty: float64_type() }, - TypeArg::Type { ty: float64_type() }, - ]) - .unwrap_err(), - SignatureError::TypeArgMismatch(TypeArgError::WrongNumberArgs(2, 1)) + def.instantiate([float64_type().into(), float64_type().into(),]) + .unwrap_err(), + SignatureError::TypeArgMismatch(TermTypeError::WrongNumberArgs(2, 1)) ); } } diff --git a/hugr-core/src/hugr/serialize/test.rs b/hugr-core/src/hugr/serialize/test.rs index 2b500ed038..13e766f38f 100644 --- a/hugr-core/src/hugr/serialize/test.rs +++ b/hugr-core/src/hugr/serialize/test.rs @@ -458,13 +458,13 @@ fn roundtrip_value(#[case] value: Value) { fn polyfunctype1() -> PolyFuncType { let function_type = Signature::new_endo(type_row![]); - PolyFuncType::new([TypeParam::max_nat()], function_type) + PolyFuncType::new([TypeParam::max_nat_type()], function_type) } fn polyfunctype2() -> PolyFuncTypeRV { let tv0 = TypeRV::new_row_var_use(0, TypeBound::Any); let tv1 = TypeRV::new_row_var_use(1, TypeBound::Copyable); - let params = [TypeBound::Any, TypeBound::Copyable].map(TypeParam::new_list); + let params = [TypeBound::Any, TypeBound::Copyable].map(TypeParam::new_list_type); let inputs = vec![ TypeRV::new_function(FuncValueType::new(tv0.clone(), tv1.clone())), tv0, @@ -479,12 +479,12 @@ fn polyfunctype2() -> PolyFuncTypeRV { #[rstest] #[case(Signature::new_endo(type_row![]).into())] #[case(polyfunctype1())] -#[case(PolyFuncType::new([TypeParam::String], Signature::new_endo(vec![Type::new_var_use(0, TypeBound::Copyable)])))] +#[case(PolyFuncType::new([TypeParam::StringType], Signature::new_endo(vec![Type::new_var_use(0, TypeBound::Copyable)])))] #[case(PolyFuncType::new([TypeBound::Copyable.into()], Signature::new_endo(vec![Type::new_var_use(0, TypeBound::Copyable)])))] -#[case(PolyFuncType::new([TypeParam::new_list(TypeBound::Any)], Signature::new_endo(type_row![])))] -#[case(PolyFuncType::new([TypeParam::Tuple { params: [TypeBound::Any.into(), TypeParam::bounded_nat(2.try_into().unwrap())].into() }], Signature::new_endo(type_row![])))] +#[case(PolyFuncType::new([TypeParam::new_list_type(TypeBound::Any)], Signature::new_endo(type_row![])))] +#[case(PolyFuncType::new([TypeParam::TupleType([TypeBound::Any.into(), TypeParam::bounded_nat_type(2.try_into().unwrap())].into())], Signature::new_endo(type_row![])))] #[case(PolyFuncType::new( - [TypeParam::new_list(TypeBound::Any)], + [TypeParam::new_list_type(TypeBound::Any)], Signature::new_endo(Type::new_tuple(TypeRV::new_row_var_use(0, TypeBound::Any)))))] fn roundtrip_polyfunctype_fixedlen(#[case] poly_func_type: PolyFuncType) { check_testing_roundtrip(poly_func_type); @@ -492,12 +492,12 @@ fn roundtrip_polyfunctype_fixedlen(#[case] poly_func_type: PolyFuncType) { #[rstest] #[case(FuncValueType::new_endo(type_row![]).into())] -#[case(PolyFuncTypeRV::new([TypeParam::String], FuncValueType::new_endo(vec![Type::new_var_use(0, TypeBound::Copyable)])))] +#[case(PolyFuncTypeRV::new([TypeParam::StringType], FuncValueType::new_endo(vec![Type::new_var_use(0, TypeBound::Copyable)])))] #[case(PolyFuncTypeRV::new([TypeBound::Copyable.into()], FuncValueType::new_endo(vec![Type::new_var_use(0, TypeBound::Copyable)])))] -#[case(PolyFuncTypeRV::new([TypeParam::new_list(TypeBound::Any)], FuncValueType::new_endo(type_row![])))] -#[case(PolyFuncTypeRV::new([TypeParam::Tuple { params: [TypeBound::Any.into(), TypeParam::bounded_nat(2.try_into().unwrap())].into() }], FuncValueType::new_endo(type_row![])))] +#[case(PolyFuncTypeRV::new([TypeParam::new_list_type(TypeBound::Any)], FuncValueType::new_endo(type_row![])))] +#[case(PolyFuncTypeRV::new([TypeParam::TupleType([TypeBound::Any.into(), TypeParam::bounded_nat_type(2.try_into().unwrap())].into())], FuncValueType::new_endo(type_row![])))] #[case(PolyFuncTypeRV::new( - [TypeParam::new_list(TypeBound::Any)], + [TypeParam::new_list_type(TypeBound::Any)], FuncValueType::new_endo(TypeRV::new_row_var_use(0, TypeBound::Any))))] #[case(polyfunctype2())] fn roundtrip_polyfunctype_varlen(#[case] poly_func_type: PolyFuncTypeRV) { @@ -514,7 +514,7 @@ fn roundtrip_polyfunctype_varlen(#[case] poly_func_type: PolyFuncTypeRV) { #[case(ops::Const::new(Value::function(crate::builder::test::simple_dfg_hugr()).unwrap()))] #[case(ops::Input::new(vec![Type::new_var_use(3,TypeBound::Copyable)]))] #[case(ops::Output::new(vec![Type::new_function(FuncValueType::new_endo(type_row![]))]))] -#[case(ops::Call::try_new(polyfunctype1(), [TypeArg::BoundedNat{n: 1}]).unwrap())] +#[case(ops::Call::try_new(polyfunctype1(), [TypeArg::BoundedNat(1)]).unwrap())] #[case(ops::CallIndirect { signature : Signature::new_endo(vec![bool_t()]) })] fn roundtrip_optype(#[case] optype: impl Into + std::fmt::Debug) { check_testing_roundtrip(NodeSer { diff --git a/hugr-core/src/hugr/validate/test.rs b/hugr-core/src/hugr/validate/test.rs index 0ac0d225da..2162a2ba3c 100644 --- a/hugr-core/src/hugr/validate/test.rs +++ b/hugr-core/src/hugr/validate/test.rs @@ -20,10 +20,10 @@ use crate::ops::handle::NodeHandle; use crate::ops::{self, OpType, Value}; use crate::std_extensions::logic::LogicOp; use crate::std_extensions::logic::test::{and_op, or_op}; -use crate::types::type_param::{TypeArg, TypeArgError}; +use crate::types::type_param::{TermTypeError, TypeArg}; use crate::types::{ - CustomType, FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature, Type, TypeBound, TypeRV, - TypeRow, + CustomType, FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature, Term, Type, TypeBound, + TypeRV, TypeRow, }; use crate::{Direction, Hugr, IncomingPort, Node, const_extension_ids, test_file, type_row}; @@ -318,7 +318,7 @@ fn invalid_types() { let valid = Type::new_extension(CustomType::new( "MyContainer", - vec![TypeArg::Type { ty: usize_t() }], + vec![usize_t().into()], EXT_ID, TypeBound::Any, &Arc::downgrade(&ext), @@ -330,22 +330,22 @@ fn invalid_types() { // valid is Any, so is not allowed as an element of an outer MyContainer. let element_outside_bound = CustomType::new( "MyContainer", - vec![TypeArg::Type { ty: valid.clone() }], + vec![valid.clone().into()], EXT_ID, TypeBound::Any, &Arc::downgrade(&ext), ); assert_eq!( validate_to_sig_error(element_outside_bound), - SignatureError::TypeArgMismatch(TypeArgError::TypeMismatch { - param: TypeBound::Copyable.into(), - arg: TypeArg::Type { ty: valid } + SignatureError::TypeArgMismatch(TermTypeError::TypeMismatch { + type_: TypeBound::Copyable.into(), + term: valid.into() }) ); let bad_bound = CustomType::new( "MyContainer", - vec![TypeArg::Type { ty: usize_t() }], + vec![usize_t().into()], EXT_ID, TypeBound::Copyable, &Arc::downgrade(&ext), @@ -361,9 +361,7 @@ fn invalid_types() { // bad_bound claims to be Copyable, which is valid as an element for the outer MyContainer. let nested = CustomType::new( "MyContainer", - vec![TypeArg::Type { - ty: Type::new_extension(bad_bound), - }], + vec![Type::new_extension(bad_bound).into()], EXT_ID, TypeBound::Any, &Arc::downgrade(&ext), @@ -378,17 +376,14 @@ fn invalid_types() { let too_many_type_args = CustomType::new( "MyContainer", - vec![ - TypeArg::Type { ty: usize_t() }, - TypeArg::BoundedNat { n: 3 }, - ], + vec![usize_t().into(), 3u64.into()], EXT_ID, TypeBound::Any, &Arc::downgrade(&ext), ); assert_eq!( validate_to_sig_error(too_many_type_args), - SignatureError::TypeArgMismatch(TypeArgError::WrongNumberArgs(2, 1)) + SignatureError::TypeArgMismatch(TermTypeError::WrongNumberArgs(2, 1)) ); } @@ -458,9 +453,7 @@ fn no_nested_funcdefns() -> Result<(), Box> { #[test] fn no_polymorphic_consts() -> Result<(), Box> { use crate::std_extensions::collections::list; - const BOUND: TypeParam = TypeParam::Type { - b: TypeBound::Copyable, - }; + const BOUND: TypeParam = TypeParam::RuntimeType(TypeBound::Copyable); let list_of_var = Type::new_extension( list::EXTENSION .get_type(&list::LIST_TYPENAME) @@ -493,7 +486,7 @@ fn no_polymorphic_consts() -> Result<(), Box> { } pub(crate) fn extension_with_eval_parallel() -> Arc { - let rowp = TypeParam::new_list(TypeBound::Any); + let rowp = TypeParam::new_list_type(TypeBound::Any); Extension::new_test_arc(EXT_ID, |ext, extension_ref| { let inputs = TypeRV::new_row_var_use(0, TypeBound::Any); let outputs = TypeRV::new_row_var_use(1, TypeBound::Any); @@ -523,8 +516,8 @@ pub(crate) fn extension_with_eval_parallel() -> Arc { #[test] fn instantiate_row_variables() -> Result<(), Box> { - fn uint_seq(i: usize) -> TypeArg { - vec![TypeArg::Type { ty: usize_t() }; i].into() + fn uint_seq(i: usize) -> Term { + vec![usize_t().into(); i].into() } let e = extension_with_eval_parallel(); let mut dfb = DFGBuilder::new(inout_sig( @@ -548,10 +541,8 @@ fn instantiate_row_variables() -> Result<(), Box> { Ok(()) } -fn list1ty(t: TypeRV) -> TypeArg { - TypeArg::List { - elems: vec![t.into()], - } +fn list1ty(t: TypeRV) -> Term { + Term::new_list([t.into()]) } #[test] @@ -563,7 +554,7 @@ fn row_variables() -> Result<(), Box> { let mut fb = FunctionBuilder::new( "id", PolyFuncType::new( - [TypeParam::new_list(TypeBound::Any)], + [TypeParam::new_list_type(TypeBound::Any)], Signature::new(inner_ft.clone(), ft_usz), ), )?; diff --git a/hugr-core/src/import.rs b/hugr-core/src/import.rs index e9c614c590..b8cda15463 100644 --- a/hugr-core/src/import.rs +++ b/hugr-core/src/import.rs @@ -24,8 +24,8 @@ use crate::{ }, types::{ CustomType, FuncTypeBase, MaybeRV, PolyFuncType, PolyFuncTypeBase, RowVariable, Signature, - Type, TypeArg, TypeBase, TypeBound, TypeEnum, TypeName, TypeRow, type_param::TypeParam, - type_row::TypeRowBase, + Term, Type, TypeArg, TypeBase, TypeBound, TypeEnum, TypeName, TypeRow, + type_param::TypeParam, type_row::TypeRowBase, }, }; use fxhash::FxHashMap; @@ -997,7 +997,7 @@ impl<'a> Context<'a> { let type_args = args .iter() - .map(|term| self.import_type_arg(*term)) + .map(|term| self.import_term(*term)) .collect::, _>>()?; self.static_edges.push((*symbol, node_id)); @@ -1020,7 +1020,7 @@ impl<'a> Context<'a> { let func_sig = self.get_func_signature(*symbol)?; let type_args = args .iter() - .map(|term| self.import_type_arg(*term)) + .map(|term| self.import_term(*term)) .collect::, _>>()?; self.static_edges.push((*symbol, node_id)); @@ -1097,7 +1097,7 @@ impl<'a> Context<'a> { let name = self.get_symbol_name(*node)?; let args = params .iter() - .map(|param| self.import_type_arg(*param)) + .map(|param| self.import_term(*param)) .collect::, _>>()?; let (extension, name) = self.import_custom_name(name)?; let signature = self.get_node_signature(node_id)?; @@ -1188,10 +1188,10 @@ impl<'a> Context<'a> { for (index, param) in symbol.params.iter().enumerate() { // NOTE: `PolyFuncType` only has explicit type parameters at present. let bound = self.local_vars[&table::VarId(node, index as _)].bound; - imported_params - .push(self.import_type_param(param.r#type, bound).map_err(|err| { - error_context!(err, "type of parameter `{}`", param.name) - })?); + imported_params.push( + self.import_term_with_bound(param.r#type, bound) + .map_err(|err| error_context!(err, "type of parameter `{}`", param.name))?, + ); } let body = self.import_func_type::(symbol.signature)?; @@ -1200,147 +1200,69 @@ impl<'a> Context<'a> { .map_err(|err| error_context!(err, "symbol `{}` defined by node {}", symbol.name, node)) } - /// Import a [`TypeParam`] from a term that represents a static type. - fn import_type_param( + /// Import a [`Term`] from a term that represents a static type or value. + fn import_term(&mut self, term_id: table::TermId) -> Result { + self.import_term_with_bound(term_id, TypeBound::Any) + } + + fn import_term_with_bound( &mut self, term_id: table::TermId, bound: TypeBound, - ) -> Result { + ) -> Result { (|| { if let Some([]) = self.match_symbol(term_id, model::CORE_STR_TYPE)? { - return Ok(TypeParam::String); + return Ok(Term::StringType); } if let Some([]) = self.match_symbol(term_id, model::CORE_NAT_TYPE)? { - return Ok(TypeParam::max_nat()); + return Ok(Term::max_nat_type()); } if let Some([]) = self.match_symbol(term_id, model::CORE_BYTES_TYPE)? { - return Ok(TypeParam::Bytes); + return Ok(Term::BytesType); } if let Some([]) = self.match_symbol(term_id, model::CORE_FLOAT_TYPE)? { - return Ok(TypeParam::Float); + return Ok(Term::FloatType); } if let Some([]) = self.match_symbol(term_id, model::CORE_TYPE)? { - return Ok(TypeParam::Type { b: bound }); + return Ok(TypeParam::RuntimeType(bound)); } - if let Some([]) = self.match_symbol(term_id, model::CORE_STATIC)? { - return Err(error_unsupported!( - "`{}` as `TypeParam`", - model::CORE_STATIC - )); + if let Some([]) = self.match_symbol(term_id, model::CORE_CONSTRAINT)? { + return Err(error_unsupported!("`{}`", model::CORE_CONSTRAINT)); } - if let Some([]) = self.match_symbol(term_id, model::CORE_CONSTRAINT)? { - return Err(error_unsupported!( - "`{}` as `TypeParam`", - model::CORE_CONSTRAINT - )); + if let Some([]) = self.match_symbol(term_id, model::CORE_STATIC)? { + return Ok(Term::StaticType); } if let Some([]) = self.match_symbol(term_id, model::CORE_CONST)? { - return Err(error_unsupported!("`{}` as `TypeParam`", model::CORE_CONST)); + return Err(error_unsupported!("`{}`", model::CORE_CONST)); } if let Some([item_type]) = self.match_symbol(term_id, model::CORE_LIST_TYPE)? { // At present `hugr-model` has no way to express that the item // type of a list must be copyable. Therefore we import it as `Any`. - let param = Box::new( - self.import_type_param(item_type, TypeBound::Any) - .map_err(|err| error_context!(err, "item type of list type"))?, - ); - return Ok(TypeParam::List { param }); + let item_type = self + .import_term(item_type) + .map_err(|err| error_context!(err, "item type of list type"))?; + return Ok(TypeParam::new_list_type(item_type)); } if let Some([item_types]) = self.match_symbol(term_id, model::CORE_TUPLE_TYPE)? { // At present `hugr-model` has no way to express that the item // types of a tuple must be copyable. Therefore we import it as `Any`. - let params = (|| { + let item_types = (|| { self.import_closed_list(item_types)? .into_iter() - .map(|param| self.import_type_param(param, TypeBound::Any)) + .map(|param| self.import_term(param)) .collect::>() })() .map_err(|err| error_context!(err, "item types of tuple type"))?; - return Ok(TypeParam::Tuple { params }); - } - - match self.get_term(term_id)? { - table::Term::Wildcard => Err(error_uninferred!("wildcard")), - - table::Term::Var { .. } => Err(error_unsupported!("type variable as `TypeParam`")), - table::Term::Apply(symbol, _) => { - let name = self.get_symbol_name(*symbol)?; - Err(error_unsupported!("custom type `{}` as `TypeParam`", name)) - } - - table::Term::Tuple(_) - | table::Term::List { .. } - | table::Term::Func { .. } - | table::Term::Literal(_) => Err(error_invalid!("expected a static type")), - } - })() - .map_err(|err| error_context!(err, "term {} as `TypeParam`", term_id)) - } - - /// Import a `TypeArg` from a term that represents a static type or value. - fn import_type_arg(&mut self, term_id: table::TermId) -> Result { - (|| { - if let Some([]) = self.match_symbol(term_id, model::CORE_STR_TYPE)? { - return Err(error_unsupported!( - "`{}` as `TypeArg`", - model::CORE_STR_TYPE - )); - } - - if let Some([]) = self.match_symbol(term_id, model::CORE_NAT_TYPE)? { - return Err(error_unsupported!( - "`{}` as `TypeArg`", - model::CORE_NAT_TYPE - )); - } - - if let Some([]) = self.match_symbol(term_id, model::CORE_BYTES_TYPE)? { - return Err(error_unsupported!( - "`{}` as `TypeArg`", - model::CORE_BYTES_TYPE - )); - } - - if let Some([]) = self.match_symbol(term_id, model::CORE_FLOAT_TYPE)? { - return Err(error_unsupported!( - "`{}` as `TypeArg`", - model::CORE_FLOAT_TYPE - )); - } - - if let Some([]) = self.match_symbol(term_id, model::CORE_TYPE)? { - return Err(error_unsupported!("`{}` as `TypeArg`", model::CORE_TYPE)); - } - - if let Some([]) = self.match_symbol(term_id, model::CORE_CONSTRAINT)? { - return Err(error_unsupported!( - "`{}` as `TypeArg`", - model::CORE_CONSTRAINT - )); - } - - if let Some([]) = self.match_symbol(term_id, model::CORE_STATIC)? { - return Err(error_unsupported!("`{}` as `TypeArg`", model::CORE_STATIC)); - } - - if let Some([]) = self.match_symbol(term_id, model::CORE_CONST)? { - return Err(error_unsupported!("`{}` as `TypeArg`", model::CORE_CONST)); - } - - if let Some([]) = self.match_symbol(term_id, model::CORE_LIST_TYPE)? { - return Err(error_unsupported!( - "`{}` as `TypeArg`", - model::CORE_LIST_TYPE - )); + return Ok(TypeParam::TupleType(item_types)); } match self.get_term(term_id)? { @@ -1351,59 +1273,53 @@ impl<'a> Context<'a> { .local_vars .get(var) .ok_or_else(|| error_invalid!("unknown variable {}", var))?; - let decl = self.import_type_param(var_info.r#type, var_info.bound)?; - Ok(TypeArg::new_var_use(var.1 as _, decl)) + let decl = self.import_term_with_bound(var_info.r#type, var_info.bound)?; + Ok(Term::new_var_use(var.1 as _, decl)) } table::Term::List { .. } => { let elems = (|| { self.import_closed_list(term_id)? .iter() - .map(|item| self.import_type_arg(*item)) + .map(|item| self.import_term(*item)) .collect::>() })() .map_err(|err| error_context!(err, "list items"))?; - Ok(TypeArg::List { elems }) + Ok(Term::List(elems)) } table::Term::Tuple { .. } => { let elems = (|| { self.import_closed_list(term_id)? .iter() - .map(|item| self.import_type_arg(*item)) + .map(|item| self.import_term(*item)) .collect::>() })() .map_err(|err| error_context!(err, "tuple items"))?; - Ok(TypeArg::Tuple { elems }) + Ok(Term::Tuple(elems)) } - table::Term::Literal(model::Literal::Str(value)) => Ok(TypeArg::String { - arg: value.to_string(), - }), - - table::Term::Literal(model::Literal::Nat(value)) => { - Ok(TypeArg::BoundedNat { n: *value }) + table::Term::Literal(model::Literal::Str(value)) => { + Ok(Term::String(value.to_string())) } - table::Term::Literal(model::Literal::Bytes(value)) => Ok(TypeArg::Bytes { - value: value.clone(), - }), - table::Term::Literal(model::Literal::Float(value)) => { - Ok(TypeArg::Float { value: *value }) - } - table::Term::Func { .. } => { - Err(error_unsupported!("function constant as `TypeArg`")) + table::Term::Literal(model::Literal::Nat(value)) => Ok(Term::BoundedNat(*value)), + + table::Term::Literal(model::Literal::Bytes(value)) => { + Ok(Term::Bytes(value.clone())) } + table::Term::Literal(model::Literal::Float(value)) => Ok(Term::Float(*value)), + table::Term::Func { .. } => Err(error_unsupported!("function constant")), table::Term::Apply { .. } => { - let ty = self.import_type(term_id)?; - Ok(TypeArg::Type { ty }) + let ty: Type = self.import_type(term_id)?; + Ok(ty.into()) } } })() - .map_err(|err| error_context!(err, "term {} as `TypeArg`", term_id)) + .map_err(|err| error_context!(err, "term {}", term_id)) } /// Import a `Type` from a term that represents a runtime type. @@ -1437,7 +1353,7 @@ impl<'a> Context<'a> { let args = args .iter() - .map(|arg| self.import_type_arg(*arg)) + .map(|arg| self.import_term(*arg)) .collect::, _>>() .map_err(|err| { error_context!(err, "type argument of custom type `{}`", name) diff --git a/hugr-core/src/ops/constant.rs b/hugr-core/src/ops/constant.rs index d27a4a0ad8..d84b754e93 100644 --- a/hugr-core/src/ops/constant.rs +++ b/hugr-core/src/ops/constant.rs @@ -860,7 +860,7 @@ pub(crate) mod test { let ex_id: ExtensionId = "my_extension".try_into().unwrap(); let typ_int = CustomType::new( "my_type", - vec![TypeArg::BoundedNat { n: 8 }], + vec![TypeArg::BoundedNat(8)], ex_id.clone(), TypeBound::Copyable, // Dummy extension reference. diff --git a/hugr-core/src/ops/controlflow.rs b/hugr-core/src/ops/controlflow.rs index 0c5d42d9b0..157e878fcd 100644 --- a/hugr-core/src/ops/controlflow.rs +++ b/hugr-core/src/ops/controlflow.rs @@ -378,9 +378,7 @@ mod test { outputs: vec![usize_t(), tv1].into(), }; let cond2 = cond.substitute(&Substitution::new(&[ - TypeArg::List { - elems: vec![usize_t().into(); 3], - }, + TypeArg::new_list([usize_t().into(), usize_t().into(), usize_t().into()]), qb_t().into(), ])); let st = Type::new_sum(vec![usize_t(), qb_t()]); //both single-element variants diff --git a/hugr-core/src/ops/custom.rs b/hugr-core/src/ops/custom.rs index f639584789..ac7629f583 100644 --- a/hugr-core/src/ops/custom.rs +++ b/hugr-core/src/ops/custom.rs @@ -406,11 +406,11 @@ mod test { let op = OpaqueOp::new( "res".try_into().unwrap(), "op", - vec![TypeArg::Type { ty: usize_t() }], + vec![usize_t().into()], sig.clone(), ); assert_eq!(op.name(), "OpaqueOp:res.op"); - assert_eq!(op.args(), &[TypeArg::Type { ty: usize_t() }]); + assert_eq!(op.args(), &[usize_t().into()]); assert_eq!(op.signature().as_ref(), &sig); } diff --git a/hugr-core/src/std_extensions/arithmetic/int_types.rs b/hugr-core/src/std_extensions/arithmetic/int_types.rs index 5db32d55ed..a7b260af6c 100644 --- a/hugr-core/src/std_extensions/arithmetic/int_types.rs +++ b/hugr-core/src/std_extensions/arithmetic/int_types.rs @@ -4,14 +4,14 @@ use std::num::NonZeroU64; use std::sync::{Arc, Weak}; use crate::ops::constant::ValueName; -use crate::types::TypeName; +use crate::types::{Term, TypeName}; use crate::{ Extension, extension::ExtensionId, ops::constant::CustomConst, types::{ ConstTypeError, CustomType, Type, TypeBound, - type_param::{TypeArg, TypeArgError, TypeParam}, + type_param::{TermTypeError, TypeArg, TypeParam}, }, }; use lazy_static::lazy_static; @@ -49,7 +49,7 @@ pub fn int_type(width_arg: impl Into) -> Type { lazy_static! { /// Array of valid integer types, indexed by log width of the integer. pub static ref INT_TYPES: [Type; LOG_WIDTH_BOUND as usize] = (0..LOG_WIDTH_BOUND) - .map(|i| int_type(TypeArg::BoundedNat { n: u64::from(i) })) + .map(|i| int_type(Term::from(u64::from(i)))) .collect::>() .try_into() .unwrap(); @@ -69,27 +69,25 @@ pub const LOG_WIDTH_BOUND: u8 = LOG_WIDTH_MAX + 1; /// Type parameter for the log width of the integer. #[allow(clippy::assertions_on_constants)] -pub const LOG_WIDTH_TYPE_PARAM: TypeParam = TypeParam::bounded_nat({ +pub const LOG_WIDTH_TYPE_PARAM: TypeParam = TypeParam::bounded_nat_type({ assert!(LOG_WIDTH_BOUND > 0); NonZeroU64::MIN.saturating_add(LOG_WIDTH_BOUND as u64 - 1) }); /// Get the log width of the specified type argument or error if the argument /// is invalid. -pub(super) fn get_log_width(arg: &TypeArg) -> Result { +pub(super) fn get_log_width(arg: &TypeArg) -> Result { match arg { - TypeArg::BoundedNat { n } if is_valid_log_width(*n as u8) => Ok(*n as u8), - _ => Err(TypeArgError::TypeMismatch { - arg: arg.clone(), - param: LOG_WIDTH_TYPE_PARAM, + TypeArg::BoundedNat(n) if is_valid_log_width(*n as u8) => Ok(*n as u8), + _ => Err(TermTypeError::TypeMismatch { + term: arg.clone(), + type_: LOG_WIDTH_TYPE_PARAM, }), } } const fn type_arg(log_width: u8) -> TypeArg { - TypeArg::BoundedNat { - n: log_width as u64, - } + TypeArg::BoundedNat(log_width as u64) } /// An integer (either signed or unsigned) @@ -239,13 +237,13 @@ mod test { #[test] fn test_int_widths() { - let type_arg_32 = TypeArg::BoundedNat { n: 5 }; + let type_arg_32 = TypeArg::BoundedNat(5); assert_matches!(get_log_width(&type_arg_32), Ok(5)); - let type_arg_128 = TypeArg::BoundedNat { n: 7 }; + let type_arg_128 = TypeArg::BoundedNat(7); assert_matches!( get_log_width(&type_arg_128), - Err(TypeArgError::TypeMismatch { .. }) + Err(TermTypeError::TypeMismatch { .. }) ); } diff --git a/hugr-core/src/std_extensions/arithmetic/mod.rs b/hugr-core/src/std_extensions/arithmetic/mod.rs index dc26ac4b0b..fbf3531ee7 100644 --- a/hugr-core/src/std_extensions/arithmetic/mod.rs +++ b/hugr-core/src/std_extensions/arithmetic/mod.rs @@ -20,7 +20,7 @@ mod test { for i in 0..LOG_WIDTH_BOUND { assert_eq!( INT_TYPES[i as usize], - int_type(TypeArg::BoundedNat { n: u64::from(i) }) + int_type(TypeArg::BoundedNat(u64::from(i))) ); } } diff --git a/hugr-core/src/std_extensions/collections/array.rs b/hugr-core/src/std_extensions/collections/array.rs index eb31441453..01e0ca73a6 100644 --- a/hugr-core/src/std_extensions/collections/array.rs +++ b/hugr-core/src/std_extensions/collections/array.rs @@ -96,7 +96,7 @@ lazy_static! { Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| { extension.add_type( ARRAY_TYPENAME, - vec![ TypeParam::max_nat(), TypeBound::Any.into()], + vec![ TypeParam::max_nat_type(), TypeBound::Any.into()], "Fixed-length array".into(), // Default array is linear, even if the elements are copyable TypeDefBound::any(), diff --git a/hugr-core/src/std_extensions/collections/array/array_clone.rs b/hugr-core/src/std_extensions/collections/array/array_clone.rs index 2a3de6d6d9..d522142a5b 100644 --- a/hugr-core/src/std_extensions/collections/array/array_clone.rs +++ b/hugr-core/src/std_extensions/collections/array/array_clone.rs @@ -51,8 +51,8 @@ impl FromStr for GenericArrayCloneDef { impl GenericArrayCloneDef { /// To avoid recursion when defining the extension, take the type definition as an argument. fn signature_from_def(&self, array_def: &TypeDef) -> SignatureFunc { - let params = vec![TypeParam::max_nat(), TypeBound::Copyable.into()]; - let size = TypeArg::new_var_use(0, TypeParam::max_nat()); + let params = vec![TypeParam::max_nat_type(), TypeBound::Copyable.into()]; + let size = TypeArg::new_var_use(0, TypeParam::max_nat_type()); let element_ty = Type::new_var_use(1, TypeBound::Copyable); let array_ty = AK::instantiate_ty(array_def, size, element_ty) .expect("Array type instantiation failed"); @@ -157,10 +157,7 @@ impl MakeExtensionOp for GenericArrayClone { } fn type_args(&self) -> Vec { - vec![ - TypeArg::BoundedNat { n: self.size }, - self.elem_ty.clone().into(), - ] + vec![self.size.into(), self.elem_ty.clone().into()] } } @@ -183,7 +180,7 @@ impl HasConcrete for GenericArrayCloneDef { fn instantiate(&self, type_args: &[TypeArg]) -> Result { match type_args { - [TypeArg::BoundedNat { n }, TypeArg::Type { ty }] if ty.copyable() => { + [TypeArg::BoundedNat(n), TypeArg::Runtime(ty)] if ty.copyable() => { Ok(GenericArrayClone::new(ty.clone(), *n).unwrap()) } _ => Err(SignatureError::InvalidTypeArgs.into()), diff --git a/hugr-core/src/std_extensions/collections/array/array_conversion.rs b/hugr-core/src/std_extensions/collections/array/array_conversion.rs index 21544dfd15..bbb79336a2 100644 --- a/hugr-core/src/std_extensions/collections/array/array_conversion.rs +++ b/hugr-core/src/std_extensions/collections/array/array_conversion.rs @@ -76,8 +76,8 @@ impl { /// To avoid recursion when defining the extension, take the type definition as an argument. fn signature_from_def(&self, array_def: &TypeDef) -> SignatureFunc { - let params = vec![TypeParam::max_nat(), TypeBound::Any.into()]; - let size = TypeArg::new_var_use(0, TypeParam::max_nat()); + let params = vec![TypeParam::max_nat_type(), TypeBound::Any.into()]; + let size = TypeArg::new_var_use(0, TypeParam::max_nat_type()); let element_ty = Type::new_var_use(1, TypeBound::Any); let this_ty = AK::instantiate_ty(array_def, size.clone(), element_ty.clone()) @@ -202,10 +202,7 @@ impl MakeExtensionOp } fn type_args(&self) -> Vec { - vec![ - TypeArg::BoundedNat { n: self.size }, - self.elem_ty.clone().into(), - ] + vec![TypeArg::BoundedNat(self.size), self.elem_ty.clone().into()] } } @@ -234,7 +231,7 @@ impl HasConcrete fn instantiate(&self, type_args: &[TypeArg]) -> Result { match type_args { - [TypeArg::BoundedNat { n }, TypeArg::Type { ty }] => { + [TypeArg::BoundedNat(n), TypeArg::Runtime(ty)] => { Ok(GenericArrayConvert::new(ty.clone(), *n)) } _ => Err(SignatureError::InvalidTypeArgs.into()), diff --git a/hugr-core/src/std_extensions/collections/array/array_discard.rs b/hugr-core/src/std_extensions/collections/array/array_discard.rs index 67e2281f72..a97ac35eaa 100644 --- a/hugr-core/src/std_extensions/collections/array/array_discard.rs +++ b/hugr-core/src/std_extensions/collections/array/array_discard.rs @@ -51,8 +51,8 @@ impl FromStr for GenericArrayDiscardDef { impl GenericArrayDiscardDef { /// To avoid recursion when defining the extension, take the type definition as an argument. fn signature_from_def(&self, array_def: &TypeDef) -> SignatureFunc { - let params = vec![TypeParam::max_nat(), TypeBound::Copyable.into()]; - let size = TypeArg::new_var_use(0, TypeParam::max_nat()); + let params = vec![TypeParam::max_nat_type(), TypeBound::Copyable.into()]; + let size = TypeArg::new_var_use(0, TypeParam::max_nat_type()); let element_ty = Type::new_var_use(1, TypeBound::Copyable); let array_ty = AK::instantiate_ty(array_def, size, element_ty) .expect("Array type instantiation failed"); @@ -141,10 +141,7 @@ impl MakeExtensionOp for GenericArrayDiscard { } fn type_args(&self) -> Vec { - vec![ - TypeArg::BoundedNat { n: self.size }, - self.elem_ty.clone().into(), - ] + vec![self.size.into(), self.elem_ty.clone().into()] } } @@ -167,7 +164,7 @@ impl HasConcrete for GenericArrayDiscardDef { fn instantiate(&self, type_args: &[TypeArg]) -> Result { match type_args { - [TypeArg::BoundedNat { n }, TypeArg::Type { ty }] if ty.copyable() => { + [TypeArg::BoundedNat(n), TypeArg::Runtime(ty)] if ty.copyable() => { Ok(GenericArrayDiscard::new(ty.clone(), *n).unwrap()) } _ => Err(SignatureError::InvalidTypeArgs.into()), diff --git a/hugr-core/src/std_extensions/collections/array/array_op.rs b/hugr-core/src/std_extensions/collections/array/array_op.rs index 915603c1da..deaac6eb58 100644 --- a/hugr-core/src/std_extensions/collections/array/array_op.rs +++ b/hugr-core/src/std_extensions/collections/array/array_op.rs @@ -16,7 +16,7 @@ use crate::extension::{ use crate::ops::{ExtensionOp, OpName}; use crate::type_row; use crate::types::type_param::{TypeArg, TypeParam}; -use crate::types::{FuncValueType, PolyFuncTypeRV, Type, TypeBound}; +use crate::types::{FuncValueType, PolyFuncTypeRV, Term, Type, TypeBound}; use crate::utils::Never; use super::array_kind::ArrayKind; @@ -65,11 +65,11 @@ pub enum GenericArrayOpDef { } /// Static parameters for array operations. Includes array size. Type is part of the type scheme. -const STATIC_SIZE_PARAM: &[TypeParam; 1] = &[TypeParam::max_nat()]; +const STATIC_SIZE_PARAM: &[TypeParam; 1] = &[TypeParam::max_nat_type()]; impl SignatureFromArgs for GenericArrayOpDef { fn compute_signature(&self, arg_values: &[TypeArg]) -> Result { - let [TypeArg::BoundedNat { n }] = *arg_values else { + let [TypeArg::BoundedNat(n)] = *arg_values else { return Err(SignatureError::InvalidTypeArgs); }; let elem_ty_var = Type::new_var_use(0, TypeBound::Any); @@ -139,11 +139,11 @@ impl GenericArrayOpDef { // signature computed dynamically, so can rely on type definition in extension. (*self).into() } else { - let size_var = TypeArg::new_var_use(0, TypeParam::max_nat()); + let size_var = TypeArg::new_var_use(0, TypeParam::max_nat_type()); let elem_ty_var = Type::new_var_use(1, TypeBound::Any); let array_ty = AK::instantiate_ty(array_def, size_var.clone(), elem_ty_var.clone()) .expect("Array type instantiation failed"); - let standard_params = vec![TypeParam::max_nat(), TypeBound::Any.into()]; + let standard_params = vec![TypeParam::max_nat_type(), TypeBound::Any.into()]; // We can assume that the prelude has ben loaded at this point, // since it doesn't depend on the array extension. @@ -151,7 +151,7 @@ impl GenericArrayOpDef { match self { get => { - let params = vec![TypeParam::max_nat(), TypeBound::Copyable.into()]; + let params = vec![TypeParam::max_nat_type(), TypeBound::Copyable.into()]; let copy_elem_ty = Type::new_var_use(1, TypeBound::Copyable); let copy_array_ty = AK::instantiate_ty(array_def, size_var, copy_elem_ty.clone()) @@ -282,13 +282,11 @@ impl MakeExtensionOp for GenericArrayOp { def.instantiate(ext_op.args()) } - fn type_args(&self) -> Vec { + fn type_args(&self) -> Vec { use GenericArrayOpDef::{ _phantom, discard_empty, get, new_array, pop_left, pop_right, set, swap, unpack, }; - let ty_arg = TypeArg::Type { - ty: self.elem_ty.clone(), - }; + let ty_arg = self.elem_ty.clone().into(); match self.def { discard_empty => { debug_assert_eq!( @@ -298,7 +296,7 @@ impl MakeExtensionOp for GenericArrayOp { vec![ty_arg] } new_array | unpack | pop_left | pop_right | get | set | swap => { - vec![TypeArg::BoundedNat { n: self.size }, ty_arg] + vec![self.size.into(), ty_arg] } _phantom(_, never) => match never {}, } @@ -322,10 +320,10 @@ impl HasDef for GenericArrayOp { impl HasConcrete for GenericArrayOpDef { type Concrete = GenericArrayOp; - fn instantiate(&self, type_args: &[TypeArg]) -> Result { + fn instantiate(&self, type_args: &[Term]) -> Result { let (ty, size) = match (self, type_args) { - (GenericArrayOpDef::discard_empty, [TypeArg::Type { ty }]) => (ty.clone(), 0), - (_, [TypeArg::BoundedNat { n }, TypeArg::Type { ty }]) => (ty.clone(), *n), + (GenericArrayOpDef::discard_empty, [Term::Runtime(ty)]) => (ty.clone(), 0), + (_, [Term::BoundedNat(n), Term::Runtime(ty)]) => (ty.clone(), *n), _ => return Err(SignatureError::InvalidTypeArgs.into()), }; diff --git a/hugr-core/src/std_extensions/collections/array/array_repeat.rs b/hugr-core/src/std_extensions/collections/array/array_repeat.rs index d3302d253a..b1b84e3521 100644 --- a/hugr-core/src/std_extensions/collections/array/array_repeat.rs +++ b/hugr-core/src/std_extensions/collections/array/array_repeat.rs @@ -52,8 +52,8 @@ impl FromStr for GenericArrayRepeatDef { impl GenericArrayRepeatDef { /// To avoid recursion when defining the extension, take the type definition as an argument. fn signature_from_def(&self, array_def: &TypeDef) -> SignatureFunc { - let params = vec![TypeParam::max_nat(), TypeBound::Any.into()]; - let n = TypeArg::new_var_use(0, TypeParam::max_nat()); + let params = vec![TypeParam::max_nat_type(), TypeBound::Any.into()]; + let n = TypeArg::new_var_use(0, TypeParam::max_nat_type()); let t = Type::new_var_use(1, TypeBound::Any); let func = Type::new_function(Signature::new(vec![], vec![t.clone()])); let array_ty = @@ -147,10 +147,7 @@ impl MakeExtensionOp for GenericArrayRepeat { } fn type_args(&self) -> Vec { - vec![ - TypeArg::BoundedNat { n: self.size }, - self.elem_ty.clone().into(), - ] + vec![self.size.into(), self.elem_ty.clone().into()] } } @@ -173,7 +170,7 @@ impl HasConcrete for GenericArrayRepeatDef { fn instantiate(&self, type_args: &[TypeArg]) -> Result { match type_args { - [TypeArg::BoundedNat { n }, TypeArg::Type { ty }] => { + [TypeArg::BoundedNat(n), TypeArg::Runtime(ty)] => { Ok(GenericArrayRepeat::new(ty.clone(), *n)) } _ => Err(SignatureError::InvalidTypeArgs.into()), diff --git a/hugr-core/src/std_extensions/collections/array/array_scan.rs b/hugr-core/src/std_extensions/collections/array/array_scan.rs index b29996bfe8..de39db3175 100644 --- a/hugr-core/src/std_extensions/collections/array/array_scan.rs +++ b/hugr-core/src/std_extensions/collections/array/array_scan.rs @@ -56,12 +56,12 @@ impl GenericArrayScanDef { fn signature_from_def(&self, array_def: &TypeDef) -> SignatureFunc { // array, (T1, *A -> T2, *A), *A, -> array, *A let params = vec![ - TypeParam::max_nat(), + TypeParam::max_nat_type(), TypeBound::Any.into(), TypeBound::Any.into(), - TypeParam::new_list(TypeBound::Any), + TypeParam::new_list_type(TypeBound::Any), ]; - let n = TypeArg::new_var_use(0, TypeParam::max_nat()); + let n = TypeArg::new_var_use(0, TypeParam::max_nat_type()); let t1 = Type::new_var_use(1, TypeBound::Any); let t2 = Type::new_var_use(2, TypeBound::Any); let s = TypeRV::new_row_var_use(3, TypeBound::Any); @@ -185,12 +185,10 @@ impl MakeExtensionOp for GenericArrayScan { fn type_args(&self) -> Vec { vec![ - TypeArg::BoundedNat { n: self.size }, + self.size.into(), self.src_ty.clone().into(), self.tgt_ty.clone().into(), - TypeArg::List { - elems: self.acc_tys.clone().into_iter().map_into().collect(), - }, + TypeArg::new_list(self.acc_tys.clone().into_iter().map_into()), ] } } @@ -215,15 +213,15 @@ impl HasConcrete for GenericArrayScanDef { fn instantiate(&self, type_args: &[TypeArg]) -> Result { match type_args { [ - TypeArg::BoundedNat { n }, - TypeArg::Type { ty: src_ty }, - TypeArg::Type { ty: tgt_ty }, - TypeArg::List { elems: acc_tys }, + TypeArg::BoundedNat(n), + TypeArg::Runtime(src_ty), + TypeArg::Runtime(tgt_ty), + TypeArg::List(acc_tys), ] => { let acc_tys: Result<_, OpLoadError> = acc_tys .iter() .map(|acc_ty| match acc_ty { - TypeArg::Type { ty } => Ok(ty.clone()), + TypeArg::Runtime(ty) => Ok(ty.clone()), _ => Err(SignatureError::InvalidTypeArgs.into()), }) .collect(); diff --git a/hugr-core/src/std_extensions/collections/array/array_value.rs b/hugr-core/src/std_extensions/collections/array/array_value.rs index 8828acd982..dfe0d0a9f2 100644 --- a/hugr-core/src/std_extensions/collections/array/array_value.rs +++ b/hugr-core/src/std_extensions/collections/array/array_value.rs @@ -94,9 +94,7 @@ impl GenericArrayValue { // constant can only hold classic type. let ty = match typ.args() { - [TypeArg::BoundedNat { n }, TypeArg::Type { ty }] - if *n as usize == self.values.len() => - { + [TypeArg::BoundedNat(n), TypeArg::Runtime(ty)] if *n as usize == self.values.len() => { ty } _ => { diff --git a/hugr-core/src/std_extensions/collections/list.rs b/hugr-core/src/std_extensions/collections/list.rs index 05d05048a6..8432ae48e2 100644 --- a/hugr-core/src/std_extensions/collections/list.rs +++ b/hugr-core/src/std_extensions/collections/list.rs @@ -21,7 +21,7 @@ use crate::extension::simple_op::{MakeOpDef, MakeRegisteredOp}; use crate::extension::{ExtensionBuildError, OpDef, SignatureFunc}; use crate::ops::constant::{TryHash, ValueName, maybe_hash_values}; use crate::ops::{OpName, Value}; -use crate::types::{TypeName, TypeRowRV}; +use crate::types::{Term, TypeName, TypeRowRV}; use crate::{ Extension, extension::{ @@ -112,7 +112,7 @@ impl CustomConst for ListValue { .map_err(|_| error())?; // constant can only hold classic type. - let [TypeArg::Type { ty }] = typ.args() else { + let [TypeArg::Runtime(ty)] = typ.args() else { return Err(error()); }; @@ -167,7 +167,7 @@ pub enum ListOp { impl ListOp { /// Type parameter used in the list types. - const TP: TypeParam = TypeParam::Type { b: TypeBound::Any }; + const TP: TypeParam = TypeParam::RuntimeType(TypeBound::Any); /// Instantiate a list operation with an `element_type`. #[must_use] @@ -325,9 +325,7 @@ pub fn list_type_def() -> &'static TypeDef { /// Get the type of a list of `elem_type` as a `CustomType`. #[must_use] pub fn list_custom_type(elem_type: Type) -> CustomType { - list_type_def() - .instantiate(vec![TypeArg::Type { ty: elem_type }]) - .unwrap() + list_type_def().instantiate(vec![elem_type.into()]).unwrap() } /// Get the `Type` of a list of `elem_type`. @@ -353,7 +351,7 @@ impl MakeExtensionOp for ListOpInst { fn from_extension_op( ext_op: &ExtensionOp, ) -> Result { - let [TypeArg::Type { ty }] = ext_op.args() else { + let [Term::Runtime(ty)] = ext_op.args() else { return Err(SignatureError::InvalidTypeArgs.into()); }; let name = ext_op.unqualified_id(); @@ -367,10 +365,8 @@ impl MakeExtensionOp for ListOpInst { }) } - fn type_args(&self) -> Vec { - vec![TypeArg::Type { - ty: self.elem_type.clone(), - }] + fn type_args(&self) -> Vec { + vec![self.elem_type.clone().into()] } } @@ -413,15 +409,9 @@ mod test { fn test_list() { let list_def = list_type_def(); - let list_type = list_def - .instantiate([TypeArg::Type { ty: usize_t() }]) - .unwrap(); + let list_type = list_def.instantiate([usize_t().into()]).unwrap(); - assert!( - list_def - .instantiate([TypeArg::BoundedNat { n: 3 }]) - .is_err() - ); + assert!(list_def.instantiate([3u64.into()]).is_err()); list_def.check_custom(&list_type).unwrap(); let list_value = ListValue(vec![ConstUsize::new(3).into()], usize_t()); diff --git a/hugr-core/src/std_extensions/collections/static_array.rs b/hugr-core/src/std_extensions/collections/static_array.rs index 6f3e889e68..fdaec40489 100644 --- a/hugr-core/src/std_extensions/collections/static_array.rs +++ b/hugr-core/src/std_extensions/collections/static_array.rs @@ -38,7 +38,7 @@ use crate::{ types::{ ConstTypeError, CustomCheckFailure, CustomType, PolyFuncType, Signature, Type, TypeArg, TypeBound, TypeName, - type_param::{TypeArgError, TypeParam}, + type_param::{TermTypeError, TypeParam}, }, }; @@ -309,12 +309,12 @@ impl HasConcrete for StaticArrayOpDef { match type_args { [arg] => { let elem_ty = arg - .as_type() + .as_runtime() .filter(|t| Copyable.contains(t.least_upper_bound())) .ok_or(SignatureError::TypeArgMismatch( - TypeArgError::TypeMismatch { - param: Copyable.into(), - arg: arg.clone(), + TermTypeError::TypeMismatch { + type_: Copyable.into(), + term: arg.clone(), }, ))?; @@ -324,7 +324,7 @@ impl HasConcrete for StaticArrayOpDef { }) } _ => Err( - SignatureError::TypeArgMismatch(TypeArgError::WrongNumberArgs(type_args.len(), 1)) + SignatureError::TypeArgMismatch(TermTypeError::WrongNumberArgs(type_args.len(), 1)) .into(), ), } diff --git a/hugr-core/src/std_extensions/collections/value_array.rs b/hugr-core/src/std_extensions/collections/value_array.rs index fe89824d77..33f4916a6c 100644 --- a/hugr-core/src/std_extensions/collections/value_array.rs +++ b/hugr-core/src/std_extensions/collections/value_array.rs @@ -102,7 +102,7 @@ lazy_static! { Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| { extension.add_type( VALUE_ARRAY_TYPENAME, - vec![ TypeParam::max_nat(), TypeBound::Any.into()], + vec![ TypeParam::max_nat_type(), TypeBound::Any.into()], "Fixed-length value array".into(), // Value arrays are copyable iff their elements are TypeDefBound::from_params(vec![1]), diff --git a/hugr-core/src/std_extensions/ptr.rs b/hugr-core/src/std_extensions/ptr.rs index 3955c3a972..74ecf63fc1 100644 --- a/hugr-core/src/std_extensions/ptr.rs +++ b/hugr-core/src/std_extensions/ptr.rs @@ -89,9 +89,7 @@ impl MakeOpDef for PtrOpDef { pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("ptr"); /// Name of pointer type. pub const PTR_TYPE_ID: TypeName = TypeName::new_inline("ptr"); -const TYPE_PARAMS: [TypeParam; 1] = [TypeParam::Type { - b: TypeBound::Copyable, -}]; +const TYPE_PARAMS: [TypeParam; 1] = [TypeParam::RuntimeType(TypeBound::Copyable)]; /// Extension version. pub const VERSION: semver::Version = semver::Version::new(0, 1, 0); @@ -209,7 +207,7 @@ impl HasConcrete for PtrOpDef { fn instantiate(&self, type_args: &[TypeArg]) -> Result { let ty = match type_args { - [TypeArg::Type { ty }] => ty.clone(), + [TypeArg::Runtime(ty)] => ty.clone(), _ => return Err(SignatureError::InvalidTypeArgs.into()), }; diff --git a/hugr-core/src/types.rs b/hugr-core/src/types.rs index 1d6fac620e..0c268d5e5e 100644 --- a/hugr-core/src/types.rs +++ b/hugr-core/src/types.rs @@ -15,14 +15,14 @@ use crate::extension::resolution::{ ExtensionCollectionError, WeakExtensionRegistry, collect_type_exts, }; pub use crate::ops::constant::{ConstTypeError, CustomCheckFailure}; -use crate::types::type_param::check_type_arg; +use crate::types::type_param::check_term_type; use crate::utils::display_list_with_separator; pub use check::SumTypeError; pub use custom::CustomType; pub use poly_func::{PolyFuncType, PolyFuncTypeRV}; pub use signature::{FuncTypeBase, FuncValueType, Signature}; use smol_str::SmolStr; -pub use type_param::TypeArg; +pub use type_param::{Term, TypeArg}; pub use type_row::{TypeRow, TypeRowRV}; pub(crate) use poly_func::PolyFuncTypeBase; @@ -490,7 +490,7 @@ impl TypeBase { /// New use (occurrence) of the type variable with specified index. /// `bound` must be exactly that with which the variable was declared - /// (i.e. as a [`TypeParam::Type`]`(bound)`), which may be narrower + /// (i.e. as a [`Term::RuntimeType`]`(bound)`), which may be narrower /// than required for the use. #[must_use] pub const fn new_var_use(idx: usize, bound: TypeBound) -> Self { @@ -575,7 +575,7 @@ impl TypeBase { TypeEnum::RowVar(rv) => rv.substitute(t), TypeEnum::Alias(_) | TypeEnum::Sum(SumType::Unit { .. }) => vec![self.clone()], TypeEnum::Variable(idx, bound) => { - let TypeArg::Type { ty } = t.apply_var(*idx, &((*bound).into())) else { + let TypeArg::Runtime(ty) = t.apply_var(*idx, &((*bound).into())) else { panic!("Variable was not a type - try validate() first") }; vec![ty.into_()] @@ -653,7 +653,7 @@ impl TypeRV { /// New use (occurrence) of the row variable with specified index. /// `bound` must match that with which the variable was declared - /// (i.e. as a [TypeParam::List]` of a `[TypeParam::Type]` of that bound). + /// (i.e. as a list of runtime types of that bound). /// For use in [OpDef], not [FuncDefn], type schemes only. /// /// [OpDef]: crate::extension::OpDef @@ -740,7 +740,7 @@ impl<'a> Substitution<'a> { .0 .get(idx) .expect("Undeclared type variable - call validate() ?"); - debug_assert_eq!(check_type_arg(arg, decl), Ok(())); + debug_assert_eq!(check_term_type(arg, decl), Ok(())); arg.clone() } @@ -749,14 +749,14 @@ impl<'a> Substitution<'a> { .0 .get(idx) .expect("Undeclared type variable - call validate() ?"); - debug_assert!(check_type_arg(arg, &TypeParam::new_list(bound)).is_ok()); + debug_assert!(check_term_type(arg, &TypeParam::new_list_type(bound)).is_ok()); match arg { - TypeArg::List { elems } => elems + TypeArg::List(elems) => elems .iter() .map(|ta| { match ta { - TypeArg::Type { ty } => return ty.clone().into(), - TypeArg::Variable { v } => { + Term::Runtime(ty) => return ty.clone().into(), + Term::Variable(v) => { if let Some(b) = v.bound_if_row_var() { return TypeRV::new_row_var_use(v.index(), b); } @@ -766,7 +766,7 @@ impl<'a> Substitution<'a> { panic!("Not a list of types - call validate() ?") }) .collect(), - TypeArg::Type { ty } if matches!(ty.0, TypeEnum::RowVar(_)) => { + Term::Runtime(ty) if matches!(ty.0, TypeEnum::RowVar(_)) => { // Standalone "Type" can be used iff its actually a Row Variable not an actual (single) Type vec![ty.clone().into()] } @@ -781,7 +781,7 @@ impl<'a> Substitution<'a> { /// and applies to arbitrary extension types rather than type variables. pub trait TypeTransformer { /// Error returned when a [`CustomType`] cannot be transformed, or a type - /// containing it (e.g. if changing a [`TypeArg::Type`] from copyable to + /// containing it (e.g. if changing a runtime type from copyable to /// linear invalidates a parameterized type). type Err: std::error::Error + From; @@ -857,7 +857,7 @@ pub(crate) mod test { use crate::extension::prelude::{option_type, qb_t, usize_t}; use crate::std_extensions::collections::array::{array_type, array_type_parametric}; use crate::std_extensions::collections::list::list_type; - use crate::types::type_param::TypeArgError; + use crate::types::type_param::TermTypeError; use crate::{Extension, hugr::IdentList, type_row}; #[test] @@ -977,7 +977,7 @@ pub(crate) mod test { |t| array_type(10, t), |t| { array_type_parametric( - TypeArg::new_var_use(0, TypeParam::bounded_nat(3.try_into().unwrap())), + TypeArg::new_var_use(0, TypeParam::bounded_nat_type(3.try_into().unwrap())), t, ) .unwrap() @@ -1001,7 +1001,7 @@ pub(crate) mod test { .unwrap(); e.add_type( COLN, - vec![TypeParam::new_list(TypeBound::Copyable)], + vec![TypeParam::new_list_type(TypeBound::Copyable)], String::new(), TypeDefBound::copyable(), w, @@ -1020,31 +1020,27 @@ pub(crate) mod test { let coln = e.get_type(&COLN).unwrap(); let c_of_cpy = coln - .instantiate([TypeArg::List { - elems: vec![Type::from(cpy.clone()).into()], - }]) + .instantiate([Term::new_list([Type::from(cpy.clone()).into()])]) .unwrap(); let mut t = Type::new_extension(c_of_cpy.clone()); assert_eq!( t.transform(&cpy_to_qb), - Err(SignatureError::from(TypeArgError::TypeMismatch { - param: TypeBound::Copyable.into(), - arg: qb_t().into() + Err(SignatureError::from(TermTypeError::TypeMismatch { + type_: TypeBound::Copyable.into(), + term: qb_t().into() })) ); let mut t = Type::new_extension( - coln.instantiate([TypeArg::List { - elems: vec![mk_opt(Type::from(cpy.clone())).into()], - }]) - .unwrap(), + coln.instantiate([Term::new_list([mk_opt(Type::from(cpy.clone())).into()])]) + .unwrap(), ); assert_eq!( t.transform(&cpy_to_qb), - Err(SignatureError::from(TypeArgError::TypeMismatch { - param: TypeBound::Copyable.into(), - arg: mk_opt(qb_t()).into() + Err(SignatureError::from(TermTypeError::TypeMismatch { + type_: TypeBound::Copyable.into(), + term: mk_opt(qb_t()).into() })) ); @@ -1054,19 +1050,15 @@ pub(crate) mod test { (ct == &c_of_cpy).then_some(usize_t()) }); let mut t = Type::new_extension( - coln.instantiate([TypeArg::List { - elems: vec![Type::from(c_of_cpy.clone()).into(); 2], - }]) - .unwrap(), + coln.instantiate([Term::new_list(vec![Type::from(c_of_cpy.clone()).into(); 2])]) + .unwrap(), ); assert_eq!(t.transform(&cpy_to_qb2), Ok(true)); assert_eq!( t, Type::new_extension( - coln.instantiate([TypeArg::List { - elems: vec![usize_t().into(); 2] - }]) - .unwrap() + coln.instantiate([Term::new_list([usize_t().into(), usize_t().into()])]) + .unwrap() ) ); } diff --git a/hugr-core/src/types/poly_func.rs b/hugr-core/src/types/poly_func.rs index 20f48907b4..70d054018d 100644 --- a/hugr-core/src/types/poly_func.rs +++ b/hugr-core/src/types/poly_func.rs @@ -13,7 +13,7 @@ use { }; use super::Substitution; -use super::type_param::{TypeArg, TypeParam, check_type_args}; +use super::type_param::{TypeArg, TypeParam, check_term_types}; use super::{MaybeRV, NoRV, RowVariable, signature::FuncTypeBase}; /// A polymorphic type scheme, i.e. of a [`FuncDecl`], [`FuncDefn`] or [`OpDef`]. @@ -122,13 +122,19 @@ impl PolyFuncTypeBase { pub fn instantiate(&self, args: &[TypeArg]) -> Result, SignatureError> { // Check that args are applicable, and that we have a value for each binder, // i.e. each possible free variable within the body. - check_type_args(args, &self.params)?; + check_term_types(args, &self.params)?; Ok(self.body.substitute(&Substitution(args))) } /// Validates this instance, checking that the types in the body are /// wellformed with respect to the registry, and the type variables declared. pub fn validate(&self) -> Result<(), SignatureError> { + for (i, p) in self.params.iter().enumerate() { + // This checks that variables have correct cached_decls, + // allowing them to refer to earlier-declared parameters only: + p.validate(&self.params[0..i])?; + p.validate_param()?; + } self.body.validate(&self.params) } @@ -166,9 +172,9 @@ pub(crate) mod test { use crate::std_extensions::collections::array::{self, array_type_parametric}; use crate::std_extensions::collections::list; use crate::types::signature::FuncTypeBase; - use crate::types::type_param::{TypeArg, TypeArgError, TypeParam}; + use crate::types::type_param::{TermTypeError, TypeArg, TypeParam}; use crate::types::{ - CustomType, FuncValueType, MaybeRV, Signature, Type, TypeBound, TypeName, TypeRV, + CustomType, FuncValueType, MaybeRV, Signature, Term, Type, TypeBound, TypeName, TypeRV, }; use super::PolyFuncTypeBase; @@ -199,14 +205,12 @@ pub(crate) mod test { Signature::new(vec![list_of_var], vec![usize_t()]), )?; - let t = list_len.instantiate(&[TypeArg::Type { ty: usize_t() }])?; + let t = list_len.instantiate(&[usize_t().into()])?; assert_eq!( t, Signature::new( vec![Type::new_extension( - list_def - .instantiate([TypeArg::Type { ty: usize_t() }]) - .unwrap() + list_def.instantiate([usize_t().into()]).unwrap() )], vec![usize_t()] ) @@ -217,9 +221,9 @@ pub(crate) mod test { #[test] fn test_mismatched_args() -> Result<(), SignatureError> { - let size_var = TypeArg::new_var_use(0, TypeParam::max_nat()); + let size_var = TypeArg::new_var_use(0, TypeParam::max_nat_type()); let ty_var = TypeArg::new_var_use(1, TypeBound::Any.into()); - let type_params = [TypeParam::max_nat(), TypeBound::Any.into()]; + let type_params = [TypeParam::max_nat_type(), TypeBound::Any.into()]; // Valid schema... let good_array = array_type_parametric(size_var.clone(), ty_var.clone())?; @@ -227,29 +231,23 @@ pub(crate) mod test { PolyFuncTypeBase::new_validated(type_params.clone(), Signature::new_endo(good_array))?; // Sanity check (good args) - good_ts.instantiate(&[ - TypeArg::BoundedNat { n: 5 }, - TypeArg::Type { ty: usize_t() }, - ])?; - - let wrong_args = good_ts.instantiate(&[ - TypeArg::Type { ty: usize_t() }, - TypeArg::BoundedNat { n: 5 }, - ]); + good_ts.instantiate(&[5u64.into(), usize_t().into()])?; + + let wrong_args = good_ts.instantiate(&[usize_t().into(), 5u64.into()]); assert_eq!( wrong_args, Err(SignatureError::TypeArgMismatch( - TypeArgError::TypeMismatch { - param: type_params[0].clone(), - arg: TypeArg::Type { ty: usize_t() } + TermTypeError::TypeMismatch { + type_: type_params[0].clone(), + term: usize_t().into(), } )) ); // (Try to) make a schema with the args in the wrong order - let arg_err = SignatureError::TypeArgMismatch(TypeArgError::TypeMismatch { - param: type_params[0].clone(), - arg: ty_var.clone(), + let arg_err = SignatureError::TypeArgMismatch(TermTypeError::TypeMismatch { + type_: type_params[0].clone(), + term: ty_var.clone(), }); assert_eq!( array_type_parametric(ty_var.clone(), size_var.clone()), @@ -277,13 +275,9 @@ pub(crate) mod test { let list_def = list::EXTENSION.get_type(&list::LIST_TYPENAME).unwrap(); let body_type = Signature::new_endo(Type::new_extension(list_def.instantiate([tv])?)); for decl in [ - TypeParam::List { - param: Box::new(TypeParam::max_nat()), - }, - TypeParam::String, - TypeParam::Tuple { - params: vec![TypeBound::Any.into(), TypeParam::max_nat()], - }, + Term::new_list_type(Term::max_nat_type()), + Term::StringType, + Term::TupleType(vec![TypeBound::Any.into(), Term::max_nat_type()]), ] { let invalid_ts = PolyFuncTypeBase::new_validated([decl.clone()], body_type.clone()); assert_eq!( @@ -348,9 +342,9 @@ pub(crate) mod test { assert_eq!( make_scheme(decl.clone()).err(), Some(SignatureError::TypeArgMismatch( - TypeArgError::TypeMismatch { - param: bound.clone(), - arg: TypeArg::new_var_use(0, decl.clone()) + TermTypeError::TypeMismatch { + type_: bound.clone(), + term: TypeArg::new_var_use(0, decl.clone()) } )) ); @@ -366,35 +360,30 @@ pub(crate) mod test { &[TypeBound::Any.into()], )?; - let list_of_tys = |b: TypeBound| TypeParam::List { - param: Box::new(b.into()), - }; decl_accepts_rejects_var( - list_of_tys(TypeBound::Copyable), - &[list_of_tys(TypeBound::Copyable)], - &[list_of_tys(TypeBound::Any)], + Term::new_list_type(TypeBound::Copyable), + &[Term::new_list_type(TypeBound::Copyable)], + &[Term::new_list_type(TypeBound::Any)], )?; decl_accepts_rejects_var( - TypeParam::max_nat(), - &[TypeParam::bounded_nat(NonZeroU64::new(5).unwrap())], + TypeParam::max_nat_type(), + &[TypeParam::bounded_nat_type(NonZeroU64::new(5).unwrap())], &[], )?; decl_accepts_rejects_var( - TypeParam::bounded_nat(NonZeroU64::new(10).unwrap()), - &[TypeParam::bounded_nat(NonZeroU64::new(5).unwrap())], - &[TypeParam::max_nat()], + TypeParam::bounded_nat_type(NonZeroU64::new(10).unwrap()), + &[TypeParam::bounded_nat_type(NonZeroU64::new(5).unwrap())], + &[TypeParam::max_nat_type()], )?; Ok(()) } - const TP_ANY: TypeParam = TypeParam::Type { b: TypeBound::Any }; + const TP_ANY: TypeParam = TypeParam::RuntimeType(TypeBound::Any); #[test] fn row_variables_bad_schema() { // Mismatched TypeBound (Copyable vs Any) - let decl = TypeParam::List { - param: Box::new(TP_ANY), - }; + let decl = Term::new_list_type(TP_ANY); let e = PolyFuncTypeBase::new_validated( [decl.clone()], FuncValueType::new( @@ -405,7 +394,7 @@ pub(crate) mod test { .unwrap_err(); assert_matches!(e, SignatureError::TypeVarDoesNotMatchDeclaration { actual, cached } => { assert_eq!(actual, decl); - assert_eq!(cached, TypeParam::List {param: Box::new(TypeParam::Type {b: TypeBound::Copyable})}); + assert_eq!(cached, TypeParam::new_list_type(TypeBound::Copyable)); }); // Declared as row variable, used as type variable let e = PolyFuncTypeBase::new_validated( @@ -423,7 +412,7 @@ pub(crate) mod test { fn row_variables() { let rty = TypeRV::new_row_var_use(0, TypeBound::Any); let pf = PolyFuncTypeBase::new_validated( - [TypeParam::new_list(TP_ANY)], + [TypeParam::new_list_type(TP_ANY)], FuncValueType::new( vec![usize_t().into(), rty.clone()], vec![TypeRV::new_tuple(rty)], @@ -434,14 +423,11 @@ pub(crate) mod test { fn seq2() -> Vec { vec![usize_t().into(), bool_t().into()] } - pf.instantiate(&[TypeArg::Type { ty: usize_t() }]) + pf.instantiate(&[usize_t().into()]).unwrap_err(); + pf.instantiate(&[Term::new_list([usize_t().into(), Term::new_list(seq2())])]) .unwrap_err(); - pf.instantiate(&[TypeArg::List { - elems: vec![usize_t().into(), TypeArg::List { elems: seq2() }], - }]) - .unwrap_err(); - let t2 = pf.instantiate(&[TypeArg::List { elems: seq2() }]).unwrap(); + let t2 = pf.instantiate(&[Term::new_list(seq2())]).unwrap(); assert_eq!( t2, Signature::new( @@ -458,20 +444,18 @@ pub(crate) mod test { TypeBound::Copyable, ))); let pf = PolyFuncTypeBase::new_validated( - [TypeParam::List { - param: Box::new(TypeParam::Type { - b: TypeBound::Copyable, - }), - }], + [Term::new_list_type(TypeBound::Copyable)], Signature::new(vec![usize_t(), inner_fty.clone()], vec![inner_fty]), ) .unwrap(); let inner3 = Type::new_function(Signature::new_endo(vec![usize_t(), bool_t(), usize_t()])); let t3 = pf - .instantiate(&[TypeArg::List { - elems: vec![usize_t().into(), bool_t().into(), usize_t().into()], - }]) + .instantiate(&[Term::new_list([ + usize_t().into(), + bool_t().into(), + usize_t().into(), + ])]) .unwrap(); assert_eq!( t3, diff --git a/hugr-core/src/types/row_var.rs b/hugr-core/src/types/row_var.rs index 106870003b..086ab7b076 100644 --- a/hugr-core/src/types/row_var.rs +++ b/hugr-core/src/types/row_var.rs @@ -6,7 +6,7 @@ use crate::extension::SignatureError; #[cfg(test)] use proptest::prelude::{BoxedStrategy, Strategy, any}; -/// Describes a row variable - a type variable bound with a [`TypeParam::List`] of [`TypeParam::Type`] +/// Describes a row variable - a type variable bound with a list of runtime types /// of the specified bound (checked in validation) // The serde derives here are not used except as markers // so that other types containing this can also #derive-serde the same way. @@ -70,7 +70,7 @@ impl MaybeRV for RowVariable { } fn validate(&self, var_decls: &[TypeParam]) -> Result<(), SignatureError> { - check_typevar_decl(var_decls, self.0, &TypeParam::new_list(self.1)) + check_typevar_decl(var_decls, self.0, &TypeParam::new_list_type(self.1)) } #[allow(private_interfaces)] diff --git a/hugr-core/src/types/serialize.rs b/hugr-core/src/types/serialize.rs index 198c0c1eda..f36f0d3081 100644 --- a/hugr-core/src/types/serialize.rs +++ b/hugr-core/src/types/serialize.rs @@ -1,3 +1,7 @@ +use std::sync::Arc; + +use ordered_float::OrderedFloat; + use super::{FuncValueType, MaybeRV, RowVariable, SumType, TypeBase, TypeBound, TypeEnum}; use super::custom::CustomType; @@ -5,6 +9,8 @@ use super::custom::CustomType; use crate::extension::SignatureError; use crate::extension::prelude::{qb_t, usize_t}; use crate::ops::AliasDecl; +use crate::types::type_param::{TermVar, UpperBound}; +use crate::types::{Term, Type}; #[derive(serde::Serialize, serde::Deserialize, Clone, Debug)] #[serde(tag = "t")] @@ -60,3 +66,129 @@ impl TryFrom for TypeBase { }) } } + +#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)] +#[non_exhaustive] +#[serde(tag = "tp")] +pub(super) enum TypeParamSer { + Type { b: TypeBound }, + BoundedNat { bound: UpperBound }, + String, + Bytes, + Float, + StaticType, + List { param: Box }, + Tuple { params: Vec }, +} + +#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)] +#[non_exhaustive] +#[serde(tag = "tya")] +pub(super) enum TypeArgSer { + Type { + ty: Type, + }, + BoundedNat { + n: u64, + }, + String { + arg: String, + }, + Bytes { + #[serde(with = "base64")] + value: Arc<[u8]>, + }, + Float { + value: OrderedFloat, + }, + List { + elems: Vec, + }, + Tuple { + elems: Vec, + }, + Variable { + #[serde(flatten)] + v: TermVar, + }, +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +#[serde(untagged)] +pub(super) enum TermSer { + TypeArg(TypeArgSer), + TypeParam(TypeParamSer), +} + +impl From for TermSer { + fn from(value: Term) -> Self { + match value { + Term::RuntimeType(b) => TermSer::TypeParam(TypeParamSer::Type { b }), + Term::StaticType => TermSer::TypeParam(TypeParamSer::StaticType), + Term::BoundedNatType(bound) => TermSer::TypeParam(TypeParamSer::BoundedNat { bound }), + Term::StringType => TermSer::TypeParam(TypeParamSer::String), + Term::BytesType => TermSer::TypeParam(TypeParamSer::Bytes), + Term::FloatType => TermSer::TypeParam(TypeParamSer::Float), + Term::ListType(param) => TermSer::TypeParam(TypeParamSer::List { param }), + Term::TupleType(params) => TermSer::TypeParam(TypeParamSer::Tuple { params }), + Term::Runtime(ty) => TermSer::TypeArg(TypeArgSer::Type { ty }), + Term::BoundedNat(n) => TermSer::TypeArg(TypeArgSer::BoundedNat { n }), + Term::String(arg) => TermSer::TypeArg(TypeArgSer::String { arg }), + Term::Bytes(value) => TermSer::TypeArg(TypeArgSer::Bytes { value }), + Term::Float(value) => TermSer::TypeArg(TypeArgSer::Float { value }), + Term::List(elems) => TermSer::TypeArg(TypeArgSer::List { elems }), + Term::Tuple(elems) => TermSer::TypeArg(TypeArgSer::Tuple { elems }), + Term::Variable(v) => TermSer::TypeArg(TypeArgSer::Variable { v }), + } + } +} + +impl From for Term { + fn from(value: TermSer) -> Self { + match value { + TermSer::TypeParam(param) => match param { + TypeParamSer::Type { b } => Term::RuntimeType(b), + TypeParamSer::StaticType => Term::StaticType, + TypeParamSer::BoundedNat { bound } => Term::BoundedNatType(bound), + TypeParamSer::String => Term::StringType, + TypeParamSer::Bytes => Term::BytesType, + TypeParamSer::Float => Term::FloatType, + TypeParamSer::List { param } => Term::ListType(param), + TypeParamSer::Tuple { params } => Term::TupleType(params), + }, + TermSer::TypeArg(arg) => match arg { + TypeArgSer::Type { ty } => Term::Runtime(ty), + TypeArgSer::BoundedNat { n } => Term::BoundedNat(n), + TypeArgSer::String { arg } => Term::String(arg), + TypeArgSer::Bytes { value } => Term::Bytes(value), + TypeArgSer::Float { value } => Term::Float(value), + TypeArgSer::List { elems } => Term::List(elems), + TypeArgSer::Tuple { elems } => Term::Tuple(elems), + TypeArgSer::Variable { v } => Term::Variable(v), + }, + } + } +} + +/// Helper for to serialize and deserialize the byte string in [`TypeArg::Bytes`] via base64. +mod base64 { + use std::sync::Arc; + + use base64::Engine as _; + use base64::prelude::BASE64_STANDARD; + use serde::{Deserialize, Serialize}; + use serde::{Deserializer, Serializer}; + + pub fn serialize(v: &Arc<[u8]>, s: S) -> Result { + let base64 = BASE64_STANDARD.encode(v); + base64.serialize(s) + } + + pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result, D::Error> { + let base64 = String::deserialize(d)?; + BASE64_STANDARD + .decode(base64.as_bytes()) + .map(|v| v.into()) + .map_err(serde::de::Error::custom) + } +} diff --git a/hugr-core/src/types/type_param.rs b/hugr-core/src/types/type_param.rs index 75e60e2de5..df40f1a173 100644 --- a/hugr-core/src/types/type_param.rs +++ b/hugr-core/src/types/type_param.rs @@ -50,312 +50,333 @@ impl UpperBound { } } -/// A *kind* of [`TypeArg`]. Thus, a parameter declared by a [`PolyFuncType`] or [`PolyFuncTypeRV`], -/// specifying a value that must be provided statically in order to instantiate it. -/// -/// [`PolyFuncType`]: super::PolyFuncType -/// [`PolyFuncTypeRV`]: super::PolyFuncTypeRV +/// A [`Term`] that is a static argument to an operation or constructor. +pub type TypeArg = Term; + +/// A [`Term`] that is the static type of an operation or constructor parameter. +pub type TypeParam = Term; + +/// A term in the language of static parameters in HUGR. #[derive( Clone, Debug, PartialEq, Eq, Hash, derive_more::Display, serde::Deserialize, serde::Serialize, )] #[non_exhaustive] -#[serde(tag = "tp")] -pub enum TypeParam { - /// Argument is a [`TypeArg::Type`]. - #[display("Type{}", match b { +#[serde( + from = "crate::types::serialize::TermSer", + into = "crate::types::serialize::TermSer" +)] +pub enum Term { + /// The type of runtime types. + #[display("Type{}", match _0 { TypeBound::Any => String::new(), - _ => format!("[{b}]") + _ => format!("[{_0}]") })] - Type { - /// Bound for the type parameter. - b: TypeBound, - }, - /// Argument is a [`TypeArg::BoundedNat`] that is less than the upper bound. - #[display("{}", match bound.value() { + RuntimeType(TypeBound), + /// The type of static data. + StaticType, + /// The type of static natural numbers up to a given bound. + #[display("{}", match _0.value() { Some(v) => format!("BoundedNat[{v}]"), None => "Nat".to_string() })] - BoundedNat { - /// Upper bound for the Nat parameter. - bound: UpperBound, - }, - /// Argument is a [`TypeArg::String`]. - String, - /// Argument is a [`TypeArg::Bytes`]. - Bytes, - /// Argument is a [`TypeArg::Float`]. - Float, - /// Argument is a [`TypeArg::List`]. A list of indeterminate size containing - /// parameters all of the (same) specified element type. - #[display("List[{param}]")] - List { - /// The [`TypeParam`] describing each element of the list. - param: Box, - }, - /// Argument is a [`TypeArg::Tuple`]. A tuple of parameters. - #[display("Tuple[{}]", params.iter().map(std::string::ToString::to_string).join(", "))] - Tuple { - /// The [`TypeParam`]s contained in the tuple. - params: Vec, - }, + BoundedNatType(UpperBound), + /// The type of static strings. See [`Term::String`]. + StringType, + /// The type of static byte strings. See [`Term::Bytes`]. + BytesType, + /// The type of static floating point numbers. See [`Term::Float`]. + FloatType, + /// The type of static lists of indeterminate size containing terms of the + /// specified static type. + #[display("ListType[{_0}]")] + ListType(Box), + /// The type of static tuples. + #[display("TupleType[{}]", _0.iter().map(std::string::ToString::to_string).join(", "))] + TupleType(Vec), + /// A runtime type as a term. Instance of [`Term::RuntimeType`]. + #[display("{_0}")] + Runtime(Type), + /// A 64bit unsigned integer literal. Instance of [`Term::BoundedNatType`]. + #[display("{_0}")] + BoundedNat(u64), + /// UTF-8 encoded string literal. Instance of [`Term::StringType`]. + #[display("\"{_0}\"")] + String(String), + /// Byte string literal. Instance of [`Term::BytesType`]. + #[display("bytes")] + Bytes(Arc<[u8]>), + /// A 64-bit floating point number. Instance of [`Term::FloatType`]. + #[display("{}", _0.into_inner())] + Float(OrderedFloat), + /// A list of static terms. Instance of [`Term::ListType`]. + #[display("[{}]", { + use itertools::Itertools as _; + _0.iter().map(|t|t.to_string()).join(",") + })] + List(Vec), + /// A tuple of static terms. Instance of [`Term::TupleType`]. + #[display("({})", { + use itertools::Itertools as _; + _0.iter().map(std::string::ToString::to_string).join(",") + })] + Tuple(Vec), + /// Variable (used in type schemes or inside polymorphic functions), + /// but not a runtime type (not even a row variable i.e. list of runtime types) + /// - see [`Term::new_var_use`] + #[display("{_0}")] + Variable(TermVar), } -impl TypeParam { - /// [`TypeParam::BoundedNat`] with the maximum bound (`u64::MAX` + 1) +impl Term { + /// Creates a [`Term::BoundedNatType`] with the maximum bound (`u64::MAX` + 1). #[must_use] - pub const fn max_nat() -> Self { - Self::BoundedNat { - bound: UpperBound(None), - } + pub const fn max_nat_type() -> Self { + Self::BoundedNatType(UpperBound(None)) } - /// [`TypeParam::BoundedNat`] with the stated upper bound (non-exclusive) + /// Creates a [`Term::BoundedNatType`] with the stated upper bound (non-exclusive). #[must_use] - pub const fn bounded_nat(upper_bound: NonZeroU64) -> Self { - Self::BoundedNat { - bound: UpperBound(Some(upper_bound)), - } + pub const fn bounded_nat_type(upper_bound: NonZeroU64) -> Self { + Self::BoundedNatType(UpperBound(Some(upper_bound))) } - /// Make a new `TypeParam::List` (an arbitrary-length homogeneous list) - pub fn new_list(elem: impl Into) -> Self { - Self::List { - param: Box::new(elem.into()), - } + /// Creates a new [`Term::List`] given a sequence of its items. + pub fn new_list(items: impl IntoIterator) -> Self { + Self::List(items.into_iter().collect()) } - fn contains(&self, other: &TypeParam) -> bool { + /// Creates a new [`Term::ListType`] given the type of its elements. + pub fn new_list_type(elem: impl Into) -> Self { + Self::ListType(Box::new(elem.into())) + } + + /// Creates a new [`Term::TupleType`] given the types of its elements. + pub fn new_tuple_type(item_types: impl IntoIterator) -> Self { + Self::TupleType(item_types.into_iter().collect()) + } + + /// Checks if this term is a supertype of another. + fn is_supertype(&self, other: &Term) -> bool { match (self, other) { - (TypeParam::Type { b: b1 }, TypeParam::Type { b: b2 }) => b1.contains(*b2), - (TypeParam::BoundedNat { bound: b1 }, TypeParam::BoundedNat { bound: b2 }) => { - b1.contains(b2) + (Term::RuntimeType(b1), Term::RuntimeType(b2)) => b1.contains(*b2), + (Term::BoundedNatType(b1), Term::BoundedNatType(b2)) => b1.contains(b2), + (Term::StringType, Term::StringType) => true, + (Term::StaticType, Term::StaticType) => true, + (Term::ListType(e1), Term::ListType(e2)) => e1.is_supertype(e2), + (Term::TupleType(es1), Term::TupleType(es2)) => { + es1.len() == es2.len() && es1.iter().zip(es2).all(|(e1, e2)| e1.is_supertype(e2)) } - (TypeParam::String, TypeParam::String) => true, - (TypeParam::List { param: e1 }, TypeParam::List { param: e2 }) => e1.contains(e2), - (TypeParam::Tuple { params: es1 }, TypeParam::Tuple { params: es2 }) => { - es1.len() == es2.len() && es1.iter().zip(es2).all(|(e1, e2)| e1.contains(e2)) + (Term::BytesType, Term::BytesType) => true, + (Term::FloatType, Term::FloatType) => true, + (Term::Variable(v1), Term::Variable(v2)) => v1 == v2 && cached_is_static(v1), + ( + Term::Runtime(_) + | Term::BoundedNat(_) + | Term::String(_) + | Term::Bytes(_) + | Term::Float(_) + | Term::List(_) + | Term::Tuple(_), + _, + ) => { + // This is not a type at all, so it's not a supertype of anything. + false } _ => false, } } } -impl From for TypeParam { - fn from(bound: TypeBound) -> Self { - Self::Type { b: bound } +fn cached_is_static(tv: &TermVar) -> bool { + match &*tv.cached_decl { + Term::Variable(tv) => cached_is_static(&*tv), + Term::StaticType => true, + _ => false, } } -impl From for TypeParam { - fn from(bound: UpperBound) -> Self { - Self::BoundedNat { bound } +impl From for Term { + fn from(bound: TypeBound) -> Self { + Self::RuntimeType(bound) } } -/// A statically-known argument value to an operation. -#[derive( - Clone, Debug, PartialEq, Eq, Hash, serde::Deserialize, serde::Serialize, derive_more::Display, -)] -#[non_exhaustive] -#[serde(tag = "tya")] -pub enum TypeArg { - /// Where the (Type/Op)Def declares that an argument is a [`TypeParam::Type`] - #[display("{ty}")] - Type { - /// The concrete type for the parameter. - ty: Type, - }, - /// Instance of [`TypeParam::BoundedNat`]. 64-bit unsigned integer. - #[display("{n}")] - BoundedNat { - /// The integer value for the parameter. - n: u64, - }, - ///Instance of [`TypeParam::String`]. UTF-8 encoded string argument. - #[display("\"{arg}\"")] - String { - /// The string value for the parameter. - arg: String, - }, - /// Instance of [`TypeParam::Bytes`]. Byte string. - #[display("bytes")] - Bytes { - /// The value of the bytes parameter. - #[serde(with = "base64")] - value: Arc<[u8]>, - }, - /// Instance of [`TypeParam::Float`]. 64-bit floating point number. - #[display("{}", value.into_inner())] - Float { - /// The value of the float parameter. - value: OrderedFloat, - }, - /// Instance of [`TypeParam::List`] defined by a sequence of elements of the same type. - #[display("[{}]", { - use itertools::Itertools as _; - elems.iter().map(|t|t.to_string()).join(",") - })] - List { - /// List of elements - elems: Vec, - }, - /// Instance of [`TypeParam::Tuple`] defined by a sequence of elements of varying type. - #[display("({})", { - use itertools::Itertools as _; - elems.iter().map(std::string::ToString::to_string).join(",") - })] - Tuple { - /// List of elements - elems: Vec, - }, - /// Variable (used in type schemes or inside polymorphic functions), - /// but not a [`TypeArg::Type`] (not even a row variable i.e. [`TypeParam::List`] of type) - /// - see [`TypeArg::new_var_use`] - #[display("{v}")] - Variable { - #[allow(missing_docs)] - #[serde(flatten)] - v: TypeArgVariable, - }, +impl From for Term { + fn from(bound: UpperBound) -> Self { + Self::BoundedNatType(bound) + } } -impl From> for TypeArg { +impl From> for Term { fn from(value: TypeBase) -> Self { match value.try_into_type() { - Ok(ty) => TypeArg::Type { ty }, - Err(RowVariable(idx, bound)) => TypeArg::new_var_use(idx, TypeParam::new_list(bound)), + Ok(ty) => Term::Runtime(ty), + Err(RowVariable(idx, bound)) => Term::new_var_use(idx, TypeParam::new_list_type(bound)), } } } -impl From for TypeArg { +impl From for Term { fn from(n: u64) -> Self { - Self::BoundedNat { n } + Self::BoundedNat(n) } } -impl From for TypeArg { +impl From for Term { fn from(arg: String) -> Self { - TypeArg::String { arg } + Term::String(arg) } } -impl From<&str> for TypeArg { +impl From<&str> for Term { fn from(arg: &str) -> Self { - TypeArg::String { - arg: arg.to_string(), - } + Term::String(arg.to_string()) } } -impl From> for TypeArg { - fn from(elems: Vec) -> Self { - Self::List { elems } +impl From> for Term { + fn from(elems: Vec) -> Self { + Self::new_list(elems) } } -/// Variable in a `TypeArg`, that is not a single [`TypeArg::Type`] (i.e. not a [`Type::new_var_use`] +/// Variable in a [`Term`], that is not a single runtime type (i.e. not a [`Type::new_var_use`] /// - it might be a [`Type::new_row_var_use`]). #[derive( Clone, Debug, PartialEq, Eq, Hash, serde::Deserialize, serde::Serialize, derive_more::Display, )] #[display("#{idx}")] -pub struct TypeArgVariable { +pub struct TermVar { idx: usize, - cached_decl: TypeParam, + cached_decl: Box, } -impl TypeArg { - /// [`Type::UNIT`] as a [`TypeArg::Type`] - pub const UNIT: Self = Self::Type { ty: Type::UNIT }; +impl Term { + /// [`Type::UNIT`] as a [`Term::Runtime`] + pub const UNIT: Self = Self::Runtime(Type::UNIT); /// Makes a `TypeArg` representing a use (occurrence) of the type variable /// with the specified index. /// `decl` must be exactly that with which the variable was declared. #[must_use] - pub fn new_var_use(idx: usize, decl: TypeParam) -> Self { + pub fn new_var_use(idx: usize, decl: Term) -> Self { match decl { // Note a TypeParam::List of TypeParam::Type *cannot* be represented // as a TypeArg::Type because the latter stores a Type i.e. only a single type, // not a RowVariable. - TypeParam::Type { b } => Type::new_var_use(idx, b).into(), - _ => TypeArg::Variable { - v: TypeArgVariable { - idx, - cached_decl: decl, - }, - }, + Term::RuntimeType(b) => Type::new_var_use(idx, b).into(), + _ => Term::Variable(TermVar { + idx, + cached_decl: Box::new(decl), + }), } } - /// Returns an integer if the `TypeArg` is an instance of `BoundedNat`. + /// Returns an integer if the [`Term`] is a natural number literal. #[must_use] pub fn as_nat(&self) -> Option { match self { - TypeArg::BoundedNat { n } => Some(*n), + TypeArg::BoundedNat(n) => Some(*n), _ => None, } } - /// Returns a type if the `TypeArg` is an instance of Type. + /// Returns a [`Type`] if the [`Term`] is a runtime type. #[must_use] - pub fn as_type(&self) -> Option> { + pub fn as_runtime(&self) -> Option> { match self { - TypeArg::Type { ty } => Some(ty.clone()), + TypeArg::Runtime(ty) => Some(ty.clone()), _ => None, } } - /// Returns a string if the `TypeArg` is an instance of String. + /// Returns a string if the [`Term`] is a string literal. #[must_use] pub fn as_string(&self) -> Option { match self { - TypeArg::String { arg } => Some(arg.clone()), + TypeArg::String(arg) => Some(arg.clone()), _ => None, } } + /// Check that this is a valid bound on/type for a parameter. + /// Assumes [TermVar::cached_decl] and that in [TypeEnum::Variable] are correct + /// (call [Self::validate] first to confirm). + pub(crate) fn validate_param(&self) -> Result<(), SignatureError> { + match self { + Term::RuntimeType(_) + | Term::StaticType + | Term::BoundedNatType(_) + | Term::StringType + | Term::BytesType + | Term::FloatType => Ok(()), + Term::ListType(term) => term.validate_param(), + Term::TupleType(terms) => terms.iter().try_for_each(Term::validate_param), + // Variables are allowed as long as they could be a static type; + // since StaticType is itself a StaticType, we must loop through chains + // like `(param &b &a) (param ?c ?b) ...` arbitrarily: these could be + // legal if enough of the first params are instantiated with `StaticType` + Term::Variable(tv) => { + if cached_is_static(tv) { + Ok(()) + } else { + Err(SignatureError::InvalidTypeParam(self.clone())) + } + } + // The remainder are not static types + Term::Runtime(_) + | Term::BoundedNat(_) + | Term::String(_) + | Term::Bytes(_) + | Term::Float(_) + | Term::List(_) + | Term::Tuple(_) => Err(SignatureError::InvalidTypeParam(self.clone())), + } + } + /// Much as [`Type::validate`], also checks that the type of any [`TypeArg::Opaque`] /// is valid and closed. pub(crate) fn validate(&self, var_decls: &[TypeParam]) -> Result<(), SignatureError> { match self { - TypeArg::Type { ty } => ty.validate(var_decls), - TypeArg::List { elems } => { + Term::Runtime(ty) => ty.validate(var_decls), + Term::List(elems) => { // TODO: Full validation would check that the type of the elements agrees elems.iter().try_for_each(|a| a.validate(var_decls)) } - TypeArg::Tuple { elems } => elems.iter().try_for_each(|a| a.validate(var_decls)), - TypeArg::BoundedNat { .. } - | TypeArg::String { .. } - | TypeArg::Float { .. } - | TypeArg::Bytes { .. } => Ok(()), - TypeArg::Variable { - v: TypeArgVariable { idx, cached_decl }, - } => { + Term::Tuple(elems) => elems.iter().try_for_each(|a| a.validate(var_decls)), + Term::BoundedNat(_) | Term::String { .. } | Term::Float(_) | Term::Bytes(_) => Ok(()), + Term::Variable(TermVar { idx, cached_decl }) => { assert!( - !matches!(cached_decl, TypeParam::Type { .. }), + !matches!(&**cached_decl, TypeParam::RuntimeType { .. }), "Malformed TypeArg::Variable {cached_decl} - should be inconstructible" ); check_typevar_decl(var_decls, *idx, cached_decl) } + Term::RuntimeType(_) => Ok(()), + Term::BoundedNatType { .. } => Ok(()), + Term::StringType => Ok(()), + Term::BytesType => Ok(()), + Term::FloatType => Ok(()), + Term::ListType(item_type) => item_type.validate(var_decls), + Term::TupleType(params) => params.iter().try_for_each(|p| p.validate(var_decls)), + Term::StaticType => Ok(()), } } pub(crate) fn substitute(&self, t: &Substitution) -> Self { match self { - TypeArg::Type { ty } => { - // RowVariables are represented as TypeArg::Variable + Term::Runtime(ty) => { + // RowVariables are represented as Term::Variable ty.substitute1(t).into() } - TypeArg::BoundedNat { .. } - | TypeArg::String { .. } - | TypeArg::Bytes { .. } - | TypeArg::Float { .. } => self.clone(), // We do not allow variables as bounds on BoundedNat's - TypeArg::List { elems } => { + Term::BoundedNat(_) | Term::String { .. } | Term::Bytes(_) | Term::Float(_) => { + self.clone() + } + Term::List(elems) => { let mut are_types = elems.iter().map(|ta| match ta { - TypeArg::Type { .. } => true, - TypeArg::Variable { v } => v.bound_if_row_var().is_some(), + Term::Runtime { .. } => true, + Term::Variable(v) => v.bound_if_row_var().is_some(), _ => false, }); let elems = match are_types.next() { @@ -365,8 +386,8 @@ impl TypeArg { elems .iter() .flat_map(|ta| match ta.substitute(t) { - ty @ TypeArg::Type { .. } => vec![ty], - TypeArg::List { elems } => elems, + ty @ Term::Runtime { .. } => vec![ty], + Term::List(elems) => elems, _ => panic!("Expected Type or row of Types"), }) .collect() @@ -376,34 +397,50 @@ impl TypeArg { elems.iter().map(|ta| ta.substitute(t)).collect() } }; - TypeArg::List { elems } + Term::List(elems) + } + Term::Tuple(elems) => { + Term::Tuple(elems.iter().map(|elem| elem.substitute(t)).collect()) } - TypeArg::Tuple { elems } => TypeArg::Tuple { - elems: elems.iter().map(|elem| elem.substitute(t)).collect(), - }, - TypeArg::Variable { - v: TypeArgVariable { idx, cached_decl }, - } => t.apply_var(*idx, cached_decl), + Term::Variable(TermVar { idx, cached_decl }) => t.apply_var(*idx, cached_decl), + Term::RuntimeType { .. } => self.clone(), + Term::BoundedNatType { .. } => self.clone(), + Term::StringType => self.clone(), + Term::BytesType => self.clone(), + Term::FloatType => self.clone(), + Term::ListType(item_type) => Term::new_list_type(item_type.substitute(t)), + Term::TupleType(params) => { + Term::TupleType(params.iter().map(|p| p.substitute(t)).collect()) + } + Term::StaticType => self.clone(), } } } -impl Transformable for TypeArg { +impl Transformable for Term { fn transform(&mut self, tr: &T) -> Result { match self { - TypeArg::Type { ty } => ty.transform(tr), - TypeArg::List { elems } => elems.transform(tr), - TypeArg::Tuple { elems } => elems.transform(tr), - TypeArg::BoundedNat { .. } - | TypeArg::String { .. } - | TypeArg::Variable { .. } - | TypeArg::Float { .. } - | TypeArg::Bytes { .. } => Ok(false), + Term::Runtime(ty) => ty.transform(tr), + Term::List(elems) => elems.transform(tr), + Term::Tuple(elems) => elems.transform(tr), + Term::BoundedNat(_) + | Term::String(_) + | Term::Variable(_) + | Term::Float(_) + | Term::Bytes(_) => Ok(false), + Term::RuntimeType { .. } => Ok(false), + Term::BoundedNatType { .. } => Ok(false), + Term::StringType => Ok(false), + Term::BytesType => Ok(false), + Term::FloatType => Ok(false), + Term::ListType(item_type) => item_type.transform(tr), + Term::TupleType(item_types) => item_types.transform(tr), + Term::StaticType => Ok(false), } } } -impl TypeArgVariable { +impl TermVar { /// Return the index. #[must_use] pub fn index(&self) -> usize { @@ -414,8 +451,8 @@ impl TypeArgVariable { /// the [`TypeBound`] of the individual types it might stand for. #[must_use] pub fn bound_if_row_var(&self) -> Option { - if let TypeParam::List { param } = &self.cached_decl { - if let TypeParam::Type { b } = **param { + if let Term::ListType(item_type) = &*self.cached_decl { + if let Term::RuntimeType(b) = **item_type { return Some(b); } } @@ -423,82 +460,84 @@ impl TypeArgVariable { } } -/// Checks a [`TypeArg`] is as expected for a [`TypeParam`] -pub fn check_type_arg(arg: &TypeArg, param: &TypeParam) -> Result<(), TypeArgError> { - match (arg, param) { - ( - TypeArg::Variable { - v: TypeArgVariable { cached_decl, .. }, - }, - _, - ) if param.contains(cached_decl) => Ok(()), - (TypeArg::Type { ty }, TypeParam::Type { b: bound }) - if bound.contains(ty.least_upper_bound()) => - { +/// Checks that a [`Term`] is valid for a given type. +pub fn check_term_type(term: &Term, type_: &Term) -> Result<(), TermTypeError> { + match (term, type_) { + (Term::Variable(TermVar { cached_decl, .. }), _) if type_.is_supertype(cached_decl) => { + Ok(()) + } + (Term::Runtime(ty), Term::RuntimeType(bound)) if bound.contains(ty.least_upper_bound()) => { Ok(()) } - (TypeArg::List { elems }, TypeParam::List { param }) => { - elems.iter().try_for_each(|arg| { + (Term::List(elems), Term::ListType(item_type)) => { + elems.iter().try_for_each(|term| { // Also allow elements that are RowVars if fitting into a List of Types - if let (TypeArg::Variable { v }, TypeParam::Type { b: param_bound }) = - (arg, &**param) - { + if let (Term::Variable(v), Term::RuntimeType(param_bound)) = (term, &**item_type) { if v.bound_if_row_var() .is_some_and(|arg_bound| param_bound.contains(arg_bound)) { return Ok(()); } } - check_type_arg(arg, param) + check_term_type(term, item_type) }) } - (TypeArg::Tuple { elems: items }, TypeParam::Tuple { params: types }) => { - if items.len() != types.len() { - return Err(TypeArgError::WrongNumberTuple(items.len(), types.len())); + (Term::Tuple(items), Term::TupleType(item_types)) => { + if items.len() != item_types.len() { + return Err(TermTypeError::WrongNumberTuple( + items.len(), + item_types.len(), + )); } items .iter() - .zip(types.iter()) - .try_for_each(|(arg, param)| check_type_arg(arg, param)) + .zip(item_types.iter()) + .try_for_each(|(term, type_)| check_term_type(term, type_)) } - (TypeArg::BoundedNat { n: val }, TypeParam::BoundedNat { bound }) - if bound.valid_value(*val) => - { - Ok(()) - } - - (TypeArg::String { .. }, TypeParam::String) => Ok(()), - (TypeArg::Bytes { .. }, TypeParam::Bytes) => Ok(()), - (TypeArg::Float { .. }, TypeParam::Float) => Ok(()), - _ => Err(TypeArgError::TypeMismatch { - arg: arg.clone(), - param: param.clone(), + (Term::BoundedNat(val), Term::BoundedNatType(bound)) if bound.valid_value(*val) => Ok(()), + (Term::String { .. }, Term::StringType) => Ok(()), + (Term::Bytes(_), Term::BytesType) => Ok(()), + (Term::Float(_), Term::FloatType) => Ok(()), + + // Static types + (Term::RuntimeType(_), Term::StaticType) => Ok(()), + (Term::StaticType, Term::StaticType) => Ok(()), + (Term::StringType, Term::StaticType) => Ok(()), + (Term::BytesType, Term::StaticType) => Ok(()), + (Term::BoundedNatType { .. }, Term::StaticType) => Ok(()), + (Term::FloatType, Term::StaticType) => Ok(()), + (Term::ListType { .. }, Term::StaticType) => Ok(()), + (Term::TupleType(_), Term::StaticType) => Ok(()), + + _ => Err(TermTypeError::TypeMismatch { + term: term.clone(), + type_: type_.clone(), }), } } -/// Check a list of type arguments match a list of required type parameters -pub fn check_type_args(args: &[TypeArg], params: &[TypeParam]) -> Result<(), TypeArgError> { - if args.len() != params.len() { - return Err(TypeArgError::WrongNumberArgs(args.len(), params.len())); +/// Check a list of [`Term`]s is valid for a list of types. +pub fn check_term_types(terms: &[Term], types: &[Term]) -> Result<(), TermTypeError> { + if terms.len() != types.len() { + return Err(TermTypeError::WrongNumberArgs(terms.len(), types.len())); } - for (a, p) in args.iter().zip(params.iter()) { - check_type_arg(a, p)?; + for (term, type_) in terms.iter().zip(types.iter()) { + check_term_type(term, type_)?; } Ok(()) } -/// Errors that can occur fitting a [`TypeArg`] into a [`TypeParam`] +/// Errors that can occur when checking that a [`Term`] has an expected type. #[derive(Clone, Debug, PartialEq, Eq, Error)] #[non_exhaustive] -pub enum TypeArgError { +pub enum TermTypeError { #[allow(missing_docs)] - /// For now, general case of a type arg not fitting a param. + /// For now, general case of a term not fitting a type. /// We'll have more cases when we allow general Containers. // TODO It may become possible to combine this with ConstTypeError. - #[error("Type argument {arg} does not fit declared parameter {param}")] - TypeMismatch { param: TypeParam, arg: TypeArg }, + #[error("Term {term} does not fit declared type {type_}")] + TypeMismatch { term: Term, type_: Term }, /// Wrong number of type arguments (actual vs expected). // For now this only happens at the top level (TypeArgs of op/type vs TypeParams of Op/TypeDef). // However in the future it may be applicable to e.g. contents of Tuples too. @@ -518,53 +557,31 @@ pub enum TypeArgError { InvalidValue(TypeArg), } -/// Helper for to serialize and deserialize the byte string in `TypeArg::Bytes` via base64. -mod base64 { - use std::sync::Arc; - - use base64::Engine as _; - use base64::prelude::BASE64_STANDARD; - use serde::{Deserialize, Serialize}; - use serde::{Deserializer, Serializer}; - - pub fn serialize(v: &Arc<[u8]>, s: S) -> Result { - let base64 = BASE64_STANDARD.encode(v); - base64.serialize(s) - } - - pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result, D::Error> { - let base64 = String::deserialize(d)?; - BASE64_STANDARD - .decode(base64.as_bytes()) - .map(|v| v.into()) - .map_err(serde::de::Error::custom) - } -} - #[cfg(test)] mod test { use itertools::Itertools; - use super::{Substitution, TypeArg, TypeParam, check_type_arg}; + use super::{Substitution, TypeArg, TypeParam, check_term_type}; use crate::extension::prelude::{bool_t, usize_t}; - use crate::types::{TypeBound, TypeRV, type_param::TypeArgError}; + use crate::types::Term; + use crate::types::{TypeBound, TypeRV, type_param::TermTypeError}; #[test] fn type_arg_fits_param() { let rowvar = TypeRV::new_row_var_use; - fn check(arg: impl Into, param: &TypeParam) -> Result<(), TypeArgError> { - check_type_arg(&arg.into(), param) + fn check(arg: impl Into, param: &TypeParam) -> Result<(), TermTypeError> { + check_term_type(&arg.into(), param) } fn check_seq>( args: &[T], param: &TypeParam, - ) -> Result<(), TypeArgError> { + ) -> Result<(), TermTypeError> { let arg = args.iter().cloned().map_into().collect_vec().into(); - check_type_arg(&arg, param) + check_term_type(&arg, param) } - // Simple cases: a TypeArg::Type is a TypeParam::Type but singleton sequences are lists + // Simple cases: a Term::Type is a Term::RuntimeType but singleton sequences are lists check(usize_t(), &TypeBound::Copyable.into()).unwrap(); - let seq_param = TypeParam::new_list(TypeBound::Copyable); + let seq_param = TypeParam::new_list_type(TypeBound::Copyable); check(usize_t(), &seq_param).unwrap_err(); check_seq(&[usize_t()], &TypeBound::Any.into()).unwrap_err(); @@ -579,7 +596,7 @@ mod test { usize_t().into(), rowvar(0, TypeBound::Copyable), ], - &TypeParam::new_list(TypeBound::Any), + &TypeParam::new_list_type(TypeBound::Any), ) .unwrap(); // Next one fails because a list of Eq is required @@ -600,9 +617,9 @@ mod test { .unwrap_err(); // Similar for nats (but no equivalent of fancy row vars) - check(5, &TypeParam::max_nat()).unwrap(); - check_seq(&[5], &TypeParam::max_nat()).unwrap_err(); - let list_of_nat = TypeParam::new_list(TypeParam::max_nat()); + check(5, &TypeParam::max_nat_type()).unwrap(); + check_seq(&[5], &TypeParam::max_nat_type()).unwrap_err(); + let list_of_nat = TypeParam::new_list_type(TypeParam::max_nat_type()); check(5, &list_of_nat).unwrap_err(); check_seq(&[5], &list_of_nat).unwrap(); check(TypeArg::new_var_use(0, list_of_nat.clone()), &list_of_nat).unwrap(); @@ -613,27 +630,20 @@ mod test { ) .unwrap_err(); - // TypeParam::Tuples require a TypeArg::Tuple of the same number of elems - let usize_and_ty = TypeParam::Tuple { - params: vec![TypeParam::max_nat(), TypeBound::Copyable.into()], - }; + // `Term::TupleType` requires a `Term::Tuple` of the same number of elems + let usize_and_ty = + TypeParam::TupleType(vec![TypeParam::max_nat_type(), TypeBound::Copyable.into()]); check( - TypeArg::Tuple { - elems: vec![5.into(), usize_t().into()], - }, + TypeArg::Tuple(vec![5.into(), usize_t().into()]), &usize_and_ty, ) .unwrap(); check( - TypeArg::Tuple { - elems: vec![usize_t().into(), 5.into()], - }, + TypeArg::Tuple(vec![usize_t().into(), 5.into()]), &usize_and_ty, ) .unwrap_err(); // Wrong way around - let two_types = TypeParam::Tuple { - params: vec![TypeBound::Any.into(), TypeBound::Any.into()], - }; + let two_types = TypeParam::TupleType(vec![TypeBound::Any.into(), TypeBound::Any.into()]); check(TypeArg::new_var_use(0, two_types.clone()), &two_types).unwrap(); // not a Row Var which could have any number of elems check(TypeArg::new_var_use(0, seq_param), &two_types).unwrap_err(); @@ -641,84 +651,76 @@ mod test { #[test] fn type_arg_subst_row() { - let row_param = TypeParam::new_list(TypeBound::Copyable); - let row_arg: TypeArg = vec![bool_t().into(), TypeArg::UNIT].into(); - check_type_arg(&row_arg, &row_param).unwrap(); + let row_param = Term::new_list_type(TypeBound::Copyable); + let row_arg: Term = vec![bool_t().into(), Term::UNIT].into(); + check_term_type(&row_arg, &row_param).unwrap(); // Now say a row variable referring to *that* row was used // to instantiate an outer "row parameter" (list of type). - let outer_param = TypeParam::new_list(TypeBound::Any); - let outer_arg = TypeArg::List { - elems: vec![ - TypeRV::new_row_var_use(0, TypeBound::Copyable).into(), - usize_t().into(), - ], - }; - check_type_arg(&outer_arg, &outer_param).unwrap(); + let outer_param = Term::new_list_type(TypeBound::Any); + let outer_arg = Term::new_list([ + TypeRV::new_row_var_use(0, TypeBound::Copyable).into(), + usize_t().into(), + ]); + check_term_type(&outer_arg, &outer_param).unwrap(); let outer_arg2 = outer_arg.substitute(&Substitution(&[row_arg])); assert_eq!( outer_arg2, - vec![bool_t().into(), TypeArg::UNIT, usize_t().into()].into() + vec![bool_t().into(), Term::UNIT, usize_t().into()].into() ); // Of course this is still valid (as substitution is guaranteed to preserve validity) - check_type_arg(&outer_arg2, &outer_param).unwrap(); + check_term_type(&outer_arg2, &outer_param).unwrap(); } #[test] fn subst_list_list() { - let outer_param = TypeParam::new_list(TypeParam::new_list(TypeBound::Any)); - let row_var_decl = TypeParam::new_list(TypeBound::Copyable); - let row_var_use = TypeArg::new_var_use(0, row_var_decl.clone()); - let good_arg = TypeArg::List { - elems: vec![ - // The row variables here refer to `row_var_decl` above - vec![usize_t().into()].into(), - row_var_use.clone(), - vec![row_var_use, usize_t().into()].into(), - ], - }; - check_type_arg(&good_arg, &outer_param).unwrap(); + let outer_param = Term::new_list_type(Term::new_list_type(TypeBound::Any)); + let row_var_decl = Term::new_list_type(TypeBound::Copyable); + let row_var_use = Term::new_var_use(0, row_var_decl.clone()); + let good_arg = Term::new_list([ + // The row variables here refer to `row_var_decl` above + vec![usize_t().into()].into(), + row_var_use.clone(), + vec![row_var_use, usize_t().into()].into(), + ]); + check_term_type(&good_arg, &outer_param).unwrap(); // Outer list cannot include single types: - let TypeArg::List { mut elems } = good_arg.clone() else { + let Term::List(mut elems) = good_arg.clone() else { panic!() }; elems.push(usize_t().into()); assert_eq!( - check_type_arg(&TypeArg::List { elems }, &outer_param), - Err(TypeArgError::TypeMismatch { - arg: usize_t().into(), + check_term_type(&Term::new_list(elems), &outer_param), + Err(TermTypeError::TypeMismatch { + term: usize_t().into(), // The error reports the type expected for each element of the list: - param: TypeParam::new_list(TypeBound::Any) + type_: TypeParam::new_list_type(TypeBound::Any) }) ); // Now substitute a list of two types for that row-variable let row_var_arg = vec![usize_t().into(), bool_t().into()].into(); - check_type_arg(&row_var_arg, &row_var_decl).unwrap(); + check_term_type(&row_var_arg, &row_var_decl).unwrap(); let subst_arg = good_arg.substitute(&Substitution(&[row_var_arg.clone()])); - check_type_arg(&subst_arg, &outer_param).unwrap(); // invariance of substitution + check_term_type(&subst_arg, &outer_param).unwrap(); // invariance of substitution assert_eq!( subst_arg, - TypeArg::List { - elems: vec![ - vec![usize_t().into()].into(), - row_var_arg, - vec![usize_t().into(), bool_t().into(), usize_t().into()].into() - ] - } + Term::new_list([ + Term::new_list([usize_t().into()]), + row_var_arg, + Term::new_list([usize_t().into(), bool_t().into(), usize_t().into()]) + ]) ); } #[test] fn bytes_json_roundtrip() { - let bytes_arg = TypeArg::Bytes { - value: vec![0, 1, 2, 3, 255, 254, 253, 252].into(), - }; + let bytes_arg = Term::Bytes(vec![0, 1, 2, 3, 255, 254, 253, 252].into()); let serialized = serde_json::to_string(&bytes_arg).unwrap(); - let deserialized: TypeArg = serde_json::from_str(&serialized).unwrap(); + let deserialized: Term = serde_json::from_str(&serialized).unwrap(); assert_eq!(deserialized, bytes_arg); } @@ -726,44 +728,66 @@ mod test { use proptest::prelude::*; - use super::super::{TypeArg, TypeArgVariable, TypeParam, UpperBound}; + use super::super::{TermVar, UpperBound}; use crate::proptest::RecursionDepth; - use crate::types::{Type, TypeBound}; + use crate::types::{Term, Type, TypeBound}; - impl Arbitrary for TypeArgVariable { + impl Arbitrary for TermVar { type Parameters = RecursionDepth; type Strategy = BoxedStrategy; fn arbitrary_with(depth: Self::Parameters) -> Self::Strategy { - (any::(), any_with::(depth)) - .prop_map(|(idx, cached_decl)| Self { idx, cached_decl }) + (any::(), any_with::(depth)) + .prop_map(|(idx, cached_decl)| Self { + idx, + cached_decl: Box::new(cached_decl), + }) .boxed() } } - impl Arbitrary for TypeParam { + impl Arbitrary for Term { type Parameters = RecursionDepth; type Strategy = BoxedStrategy; fn arbitrary_with(depth: Self::Parameters) -> Self::Strategy { use prop::collection::vec; use prop::strategy::Union; let mut strat = Union::new([ - Just(Self::String).boxed(), - Just(Self::Bytes).boxed(), - Just(Self::Float).boxed(), - Just(Self::String).boxed(), - any::().prop_map(|b| Self::Type { b }).boxed(), - any::() - .prop_map(|bound| Self::BoundedNat { bound }) + Just(Self::StringType).boxed(), + Just(Self::BytesType).boxed(), + Just(Self::FloatType).boxed(), + Just(Self::StringType).boxed(), + any::().prop_map(Self::from).boxed(), + any::().prop_map(Self::from).boxed(), + any::().prop_map(Self::from).boxed(), + any::().prop_map(Self::from).boxed(), + any::>() + .prop_map(|bytes| Self::Bytes(bytes.into())) + .boxed(), + any::() + .prop_map(|value| Self::Float(value.into())) .boxed(), + any_with::(depth).prop_map(Self::from).boxed(), ]); if !depth.leaf() { - // we descend here because we these constructors contain TypeParams + // we descend here because we these constructors contain Terms strat = strat + .or( + // TODO this is a bit dodgy, TypeArgVariables are supposed + // to be constructed from TypeArg::new_var_use. We are only + // using this instance for serialization now, but if we want + // to generate valid TypeArgs this will need to change. + any_with::(depth.descend()) + .prop_map(Self::Variable) + .boxed(), + ) .or(any_with::(depth.descend()) - .prop_map(|x| Self::List { param: Box::new(x) }) + .prop_map(Self::new_list_type) .boxed()) .or(vec(any_with::(depth.descend()), 0..3) - .prop_map(|params| Self::Tuple { params }) + .prop_map(Self::new_tuple_type) + .boxed()) + .or(vec(any_with::(depth.descend()), 0..3) + .prop_map(Term::new_list) .boxed()); } @@ -771,43 +795,11 @@ mod test { } } - impl Arbitrary for TypeArg { - type Parameters = RecursionDepth; - type Strategy = BoxedStrategy; - fn arbitrary_with(depth: Self::Parameters) -> Self::Strategy { - use prop::collection::vec; - use prop::strategy::Union; - let mut strat = Union::new([ - any::().prop_map(|n| Self::BoundedNat { n }).boxed(), - any::().prop_map(|arg| Self::String { arg }).boxed(), - any::>() - .prop_map(|bytes| Self::Bytes { - value: bytes.into(), - }) - .boxed(), - any::() - .prop_map(|value| Self::Float { - value: value.into(), - }) - .boxed(), - any_with::(depth) - .prop_map(|ty| Self::Type { ty }) - .boxed(), - // TODO this is a bit dodgy, TypeArgVariables are supposed - // to be constructed from TypeArg::new_var_use. We are only - // using this instance for serialization now, but if we want - // to generate valid TypeArgs this will need to change. - any_with::(depth) - .prop_map(|v| Self::Variable { v }) - .boxed(), - ]); - if !depth.leaf() { - // We descend here because this constructor contains TypeArg> - strat = strat.or(vec(any_with::(depth.descend()), 0..3) - .prop_map(|elems| Self::List { elems }) - .boxed()); - } - strat.boxed() + proptest! { + #[test] + fn type_term_contains_itself(term: Term) { + let is_type = term.validate_param().is_ok(); + assert_eq!(is_type, term.is_supertype(&term)); } } } diff --git a/hugr-llvm/src/extension/collections/array.rs b/hugr-llvm/src/extension/collections/array.rs index 2f9e49026c..da5141d72f 100644 --- a/hugr-llvm/src/extension/collections/array.rs +++ b/hugr-llvm/src/extension/collections/array.rs @@ -214,7 +214,7 @@ impl CodegenExtension for ArrayCodegenExtension { .custom_type((array::EXTENSION_ID, array::ARRAY_TYPENAME), { let ccg = self.0.clone(); move |ts, hugr_type| { - let [TypeArg::BoundedNat { n }, TypeArg::Type { ty }] = hugr_type.args() else { + let [TypeArg::BoundedNat(n), TypeArg::Runtime(ty)] = hugr_type.args() else { return Err(anyhow!("Invalid type args for array type")); }; let elem_ty = ts.llvm_type(ty)?; diff --git a/hugr-llvm/src/extension/collections/list.rs b/hugr-llvm/src/extension/collections/list.rs index 2bddff687e..746ce11bc0 100644 --- a/hugr-llvm/src/extension/collections/list.rs +++ b/hugr-llvm/src/extension/collections/list.rs @@ -203,7 +203,7 @@ fn emit_list_op<'c, H: HugrView>( op: ListOp, ) -> Result<()> { let hugr_elem_ty = match args.node().args() { - [TypeArg::Type { ty }] => ty.clone(), + [TypeArg::Runtime(ty)] => ty.clone(), _ => { bail!("Collections: invalid type args for list op"); } diff --git a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__emit_static_array_of_static_array@llvm14.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__emit_static_array_of_static_array@llvm14.snap index 1af774422e..ad17a2c59f 100644 --- a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__emit_static_array_of_static_array@llvm14.snap +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__emit_static_array_of_static_array@llvm14.snap @@ -15,14 +15,14 @@ source_filename = "test_context" @sa.inner.7f5d5e16.0 = constant { i64, [7 x i64] } { i64 7, [7 x i64] [i64 7, i64 7, i64 7, i64 7, i64 7, i64 7, i64 7] } @sa.inner.a0bc9c53.0 = constant { i64, [8 x i64] } { i64 8, [8 x i64] [i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8] } @sa.inner.1e8aada3.0 = constant { i64, [9 x i64] } { i64 9, [9 x i64] [i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9] } -@sa.outer.c4a5911a.0 = constant { i64, [10 x { i64, [0 x i64] }*] } { i64 10, [10 x { i64, [0 x i64] }*] [{ i64, [0 x i64] }* @sa.inner.6acc1b76.0, { i64, [0 x i64] }* bitcast ({ i64, [1 x i64] }* @sa.inner.e637bb5.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [2 x i64] }* @sa.inner.2b6593f.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [3 x i64] }* @sa.inner.1b9ad7c.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [4 x i64] }* @sa.inner.e67fbfa4.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [5 x i64] }* @sa.inner.15dc27f6.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [6 x i64] }* @sa.inner.c43a2bb2.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [7 x i64] }* @sa.inner.7f5d5e16.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [8 x i64] }* @sa.inner.a0bc9c53.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [9 x i64] }* @sa.inner.1e8aada3.0 to { i64, [0 x i64] }*)] } +@sa.outer.e55b610a.0 = constant { i64, [10 x { i64, [0 x i64] }*] } { i64 10, [10 x { i64, [0 x i64] }*] [{ i64, [0 x i64] }* @sa.inner.6acc1b76.0, { i64, [0 x i64] }* bitcast ({ i64, [1 x i64] }* @sa.inner.e637bb5.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [2 x i64] }* @sa.inner.2b6593f.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [3 x i64] }* @sa.inner.1b9ad7c.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [4 x i64] }* @sa.inner.e67fbfa4.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [5 x i64] }* @sa.inner.15dc27f6.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [6 x i64] }* @sa.inner.c43a2bb2.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [7 x i64] }* @sa.inner.7f5d5e16.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [8 x i64] }* @sa.inner.a0bc9c53.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [9 x i64] }* @sa.inner.1e8aada3.0 to { i64, [0 x i64] }*)] } define i64 @_hl.main.1() { alloca_block: br label %entry_block entry_block: ; preds = %alloca_block - %0 = getelementptr inbounds { i64, [0 x { i64, [0 x i64] }*] }, { i64, [0 x { i64, [0 x i64] }*] }* bitcast ({ i64, [10 x { i64, [0 x i64] }*] }* @sa.outer.c4a5911a.0 to { i64, [0 x { i64, [0 x i64] }*] }*), i32 0, i32 0 + %0 = getelementptr inbounds { i64, [0 x { i64, [0 x i64] }*] }, { i64, [0 x { i64, [0 x i64] }*] }* bitcast ({ i64, [10 x { i64, [0 x i64] }*] }* @sa.outer.e55b610a.0 to { i64, [0 x { i64, [0 x i64] }*] }*), i32 0, i32 0 %1 = load i64, i64* %0, align 4 ret i64 %1 } diff --git a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__emit_static_array_of_static_array@pre-mem2reg@llvm14.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__emit_static_array_of_static_array@pre-mem2reg@llvm14.snap index be8b63018c..b0f0741226 100644 --- a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__emit_static_array_of_static_array@pre-mem2reg@llvm14.snap +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__emit_static_array_of_static_array@pre-mem2reg@llvm14.snap @@ -15,7 +15,7 @@ source_filename = "test_context" @sa.inner.7f5d5e16.0 = constant { i64, [7 x i64] } { i64 7, [7 x i64] [i64 7, i64 7, i64 7, i64 7, i64 7, i64 7, i64 7] } @sa.inner.a0bc9c53.0 = constant { i64, [8 x i64] } { i64 8, [8 x i64] [i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8] } @sa.inner.1e8aada3.0 = constant { i64, [9 x i64] } { i64 9, [9 x i64] [i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9, i64 9] } -@sa.outer.c4a5911a.0 = constant { i64, [10 x { i64, [0 x i64] }*] } { i64 10, [10 x { i64, [0 x i64] }*] [{ i64, [0 x i64] }* @sa.inner.6acc1b76.0, { i64, [0 x i64] }* bitcast ({ i64, [1 x i64] }* @sa.inner.e637bb5.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [2 x i64] }* @sa.inner.2b6593f.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [3 x i64] }* @sa.inner.1b9ad7c.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [4 x i64] }* @sa.inner.e67fbfa4.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [5 x i64] }* @sa.inner.15dc27f6.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [6 x i64] }* @sa.inner.c43a2bb2.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [7 x i64] }* @sa.inner.7f5d5e16.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [8 x i64] }* @sa.inner.a0bc9c53.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [9 x i64] }* @sa.inner.1e8aada3.0 to { i64, [0 x i64] }*)] } +@sa.outer.e55b610a.0 = constant { i64, [10 x { i64, [0 x i64] }*] } { i64 10, [10 x { i64, [0 x i64] }*] [{ i64, [0 x i64] }* @sa.inner.6acc1b76.0, { i64, [0 x i64] }* bitcast ({ i64, [1 x i64] }* @sa.inner.e637bb5.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [2 x i64] }* @sa.inner.2b6593f.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [3 x i64] }* @sa.inner.1b9ad7c.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [4 x i64] }* @sa.inner.e67fbfa4.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [5 x i64] }* @sa.inner.15dc27f6.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [6 x i64] }* @sa.inner.c43a2bb2.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [7 x i64] }* @sa.inner.7f5d5e16.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [8 x i64] }* @sa.inner.a0bc9c53.0 to { i64, [0 x i64] }*), { i64, [0 x i64] }* bitcast ({ i64, [9 x i64] }* @sa.inner.1e8aada3.0 to { i64, [0 x i64] }*)] } define i64 @_hl.main.1() { alloca_block: @@ -25,7 +25,7 @@ alloca_block: br label %entry_block entry_block: ; preds = %alloca_block - store { i64, [0 x { i64, [0 x i64] }*] }* bitcast ({ i64, [10 x { i64, [0 x i64] }*] }* @sa.outer.c4a5911a.0 to { i64, [0 x { i64, [0 x i64] }*] }*), { i64, [0 x { i64, [0 x i64] }*] }** %"5_0", align 8 + store { i64, [0 x { i64, [0 x i64] }*] }* bitcast ({ i64, [10 x { i64, [0 x i64] }*] }* @sa.outer.e55b610a.0 to { i64, [0 x { i64, [0 x i64] }*] }*), { i64, [0 x { i64, [0 x i64] }*] }** %"5_0", align 8 %"5_01" = load { i64, [0 x { i64, [0 x i64] }*] }*, { i64, [0 x { i64, [0 x i64] }*] }** %"5_0", align 8 %0 = getelementptr inbounds { i64, [0 x { i64, [0 x i64] }*] }, { i64, [0 x { i64, [0 x i64] }*] }* %"5_01", i32 0, i32 0 %1 = load i64, i64* %0, align 4 diff --git a/hugr-llvm/src/extension/collections/stack_array.rs b/hugr-llvm/src/extension/collections/stack_array.rs index eaa3151ac3..297f539511 100644 --- a/hugr-llvm/src/extension/collections/stack_array.rs +++ b/hugr-llvm/src/extension/collections/stack_array.rs @@ -126,7 +126,7 @@ impl CodegenExtension for ArrayCodegenExtension { .custom_type((array::EXTENSION_ID, array::ARRAY_TYPENAME), { let ccg = self.0.clone(); move |ts, hugr_type| { - let [TypeArg::BoundedNat { n }, TypeArg::Type { ty }] = hugr_type.args() else { + let [TypeArg::BoundedNat(n), TypeArg::Runtime(ty)] = hugr_type.args() else { return Err(anyhow!("Invalid type args for array type")); }; let elem_ty = ts.llvm_type(ty)?; diff --git a/hugr-llvm/src/extension/collections/static_array.rs b/hugr-llvm/src/extension/collections/static_array.rs index 1d9bfd8147..50ac99b723 100644 --- a/hugr-llvm/src/extension/collections/static_array.rs +++ b/hugr-llvm/src/extension/collections/static_array.rs @@ -370,7 +370,7 @@ impl CodegenExtension for StaticArrayCodegenE let sac = self.0.clone(); move |ts, custom_type| { let element_type = custom_type.args()[0] - .as_type() + .as_runtime() .expect("Type argument for static array must be a type"); sac.static_array_type(ts, &element_type) } diff --git a/hugr-llvm/src/extension/int.rs b/hugr-llvm/src/extension/int.rs index 7f8932f00d..bea508d774 100644 --- a/hugr-llvm/src/extension/int.rs +++ b/hugr-llvm/src/extension/int.rs @@ -668,7 +668,7 @@ fn emit_int_op<'c, H: HugrView>( ]) }), IntOpDef::inarrow_s => { - let Some(TypeArg::BoundedNat { n: out_log_width }) = args.node().args().last().cloned() + let Some(TypeArg::BoundedNat(out_log_width)) = args.node().args().last().cloned() else { bail!("Type arg to inarrow_s wasn't a Nat"); }; @@ -686,7 +686,7 @@ fn emit_int_op<'c, H: HugrView>( }) } IntOpDef::inarrow_u => { - let Some(TypeArg::BoundedNat { n: out_log_width }) = args.node().args().last().cloned() + let Some(TypeArg::BoundedNat(out_log_width)) = args.node().args().last().cloned() else { bail!("Type arg to inarrow_u wasn't a Nat"); }; @@ -756,7 +756,7 @@ pub(crate) fn get_width_arg>( args: &EmitOpArgs<'_, '_, ExtensionOp, H>, op: &impl MakeExtensionOp, ) -> Result { - let [TypeArg::BoundedNat { n: log_width }] = args.node.args() else { + let [TypeArg::BoundedNat(log_width)] = args.node.args() else { bail!( "Expected exactly one BoundedNat parameter to {}", op.op_id() @@ -1094,7 +1094,7 @@ fn llvm_type<'c>( context: TypingSession<'c, '_>, hugr_type: &CustomType, ) -> Result> { - if let [TypeArg::BoundedNat { n }] = hugr_type.args() { + if let [TypeArg::BoundedNat(n)] = hugr_type.args() { let m = *n as usize; if m < int_types::INT_TYPES.len() && int_types::INT_TYPES[m] == hugr_type.clone().into() { return Ok(match m { diff --git a/hugr-llvm/src/extension/prelude.rs b/hugr-llvm/src/extension/prelude.rs index bd52f7b514..3a411b4897 100644 --- a/hugr-llvm/src/extension/prelude.rs +++ b/hugr-llvm/src/extension/prelude.rs @@ -389,7 +389,7 @@ pub fn add_prelude_extensions<'a, H: HugrView + 'a>( move |context, args| { let load_nat = LoadNat::from_extension_op(args.node().as_ref())?; let v = match load_nat.get_nat() { - TypeArg::BoundedNat { n } => pcg + TypeArg::BoundedNat(n) => pcg .usize_type(&context.typing_session()) .const_int(n, false), arg => bail!("Unexpected type arg for LoadNat: {}", arg), @@ -408,7 +408,7 @@ mod test { use hugr_core::builder::{Dataflow, DataflowHugr}; use hugr_core::extension::PRELUDE; use hugr_core::extension::prelude::{EXIT_OP_ID, Noop}; - use hugr_core::types::{Type, TypeArg}; + use hugr_core::types::{Term, Type}; use hugr_core::{Hugr, type_row}; use prelude::{PANIC_OP_ID, PRINT_OP_ID, bool_t, qb_t, usize_t}; use rstest::{fixture, rstest}; @@ -559,10 +559,8 @@ mod test { #[rstest] fn prelude_panic(prelude_llvm_ctx: TestContext) { let error_val = ConstError::new(42, "PANIC"); - let type_arg_q: TypeArg = TypeArg::Type { ty: qb_t() }; - let type_arg_2q: TypeArg = TypeArg::List { - elems: vec![type_arg_q.clone(), type_arg_q], - }; + let type_arg_q: Term = qb_t().into(); + let type_arg_2q = Term::new_list([type_arg_q.clone(), type_arg_q]); let panic_op = PRELUDE .instantiate_extension_op(&PANIC_OP_ID, [type_arg_2q.clone(), type_arg_2q.clone()]) .unwrap(); @@ -587,10 +585,8 @@ mod test { #[rstest] fn prelude_exit(prelude_llvm_ctx: TestContext) { let error_val = ConstError::new(42, "EXIT"); - let type_arg_q: TypeArg = TypeArg::Type { ty: qb_t() }; - let type_arg_2q: TypeArg = TypeArg::List { - elems: vec![type_arg_q.clone(), type_arg_q], - }; + let type_arg_q: Term = qb_t().into(); + let type_arg_2q = Term::new_list([type_arg_q.clone(), type_arg_q]); let exit_op = PRELUDE .instantiate_extension_op(&EXIT_OP_ID, [type_arg_2q.clone(), type_arg_2q.clone()]) .unwrap(); @@ -635,7 +631,7 @@ mod test { .with_extensions(prelude::PRELUDE_REGISTRY.to_owned()) .finish(|mut builder| { let v = builder - .add_dataflow_op(LoadNat::new(TypeArg::BoundedNat { n: 42 }), vec![]) + .add_dataflow_op(LoadNat::new(42u64.into()), vec![]) .unwrap() .out_wire(0); builder.finish_hugr_with_outputs([v]).unwrap() diff --git a/hugr-passes/src/linearize_array.rs b/hugr-passes/src/linearize_array.rs index 4f8da110c9..64daf21f4d 100644 --- a/hugr-passes/src/linearize_array.rs +++ b/hugr-passes/src/linearize_array.rs @@ -66,7 +66,7 @@ impl Default for LinearizeArrayPass { // error out and make sure we're not emitting `get`s for nested value // arrays. assert!( - op_def != ArrayOpDef::get || args[1].as_type().unwrap().copyable(), + op_def != ArrayOpDef::get || args[1].as_runtime().unwrap().copyable(), "Cannot linearise arrays in this Hugr: \ Contains a `get` operation on nested value arrays" ); diff --git a/hugr-passes/src/monomorphize.rs b/hugr-passes/src/monomorphize.rs index c505cd9779..a23d152962 100644 --- a/hugr-passes/src/monomorphize.rs +++ b/hugr-passes/src/monomorphize.rs @@ -232,14 +232,14 @@ fn escape_dollar(str: impl AsRef) -> String { fn write_type_arg_str(arg: &TypeArg, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match arg { - TypeArg::Type { ty } => f.write_fmt(format_args!("t({})", escape_dollar(ty.to_string()))), - TypeArg::BoundedNat { n } => f.write_fmt(format_args!("n({n})")), - TypeArg::String { arg } => f.write_fmt(format_args!("s({})", escape_dollar(arg))), - TypeArg::List { elems } => f.write_fmt(format_args!("list({})", TypeArgsSeq(elems))), - TypeArg::Tuple { elems } => f.write_fmt(format_args!("tuple({})", TypeArgsSeq(elems))), + TypeArg::Runtime(ty) => f.write_fmt(format_args!("t({})", escape_dollar(ty.to_string()))), + TypeArg::BoundedNat(n) => f.write_fmt(format_args!("n({n})")), + TypeArg::String(arg) => f.write_fmt(format_args!("s({})", escape_dollar(arg))), + TypeArg::List(elems) => f.write_fmt(format_args!("list({})", TypeArgsSeq(elems))), + TypeArg::Tuple(elems) => f.write_fmt(format_args!("tuple({})", TypeArgsSeq(elems))), // We are monomorphizing. We will never monomorphize to a signature // containing a variable. - TypeArg::Variable { .. } => panic!("type_arg_str variable: {arg}"), + TypeArg::Variable(_) => panic!("type_arg_str variable: {arg}"), _ => panic!("unknown type arg: {arg}"), } } @@ -410,8 +410,8 @@ mod test { //pf1 contains pf2 contains mono_func -> pf1 and pf1 share pf2's and they share mono_func let tv = |i| Type::new_var_use(i, TypeBound::Copyable); - let sv = |i| TypeArg::new_var_use(i, TypeParam::max_nat()); - let sa = |n| TypeArg::BoundedNat { n }; + let sv = |i| TypeArg::new_var_use(i, TypeParam::max_nat_type()); + let sa = |n| TypeArg::BoundedNat(n); let n: u64 = 5; let mut outer = FunctionBuilder::new( "mainish", @@ -440,7 +440,7 @@ mod test { let pf2 = { let pf2t = PolyFuncType::new( - [TypeParam::max_nat(), TypeBound::Copyable.into()], + [TypeParam::max_nat_type(), TypeBound::Copyable.into()], Signature::new(ValueArray::ty_parametric(sv(0), tv(1)).unwrap(), tv(1)), ); let mut pf2 = mb.define_function("pf2", pf2t).unwrap(); @@ -457,7 +457,7 @@ mod test { }; let pf1t = PolyFuncType::new( - [TypeParam::max_nat()], + [TypeParam::max_nat_type()], Signature::new( ValueArray::ty_parametric(sv(0), arr2u()).unwrap(), usize_t(), @@ -472,7 +472,7 @@ mod test { let elem = pf1 .call( pf2.handle(), - &[TypeArg::BoundedNat { n: 2 }, usize_t().into()], + &[TypeArg::BoundedNat(2), usize_t().into()], inner.outputs(), ) .unwrap(); @@ -509,11 +509,11 @@ mod test { assert_eq!( funcs.keys().copied().sorted().collect_vec(), vec![ - &mangle_name("pf1", &[TypeArg::BoundedNat { n: 5 }]), - &mangle_name("pf1", &[TypeArg::BoundedNat { n: 4 }]), - &mangle_name("pf2", &[TypeArg::BoundedNat { n: 5 }, arr2u().into()]), // from pf1<5> - &mangle_name("pf2", &[TypeArg::BoundedNat { n: 4 }, arr2u().into()]), // from pf1<4> - &mangle_name("pf2", &[TypeArg::BoundedNat { n: 2 }, usize_t().into()]), // from both pf1<4> and <5> + &mangle_name("pf1", &[TypeArg::BoundedNat(5)]), + &mangle_name("pf1", &[TypeArg::BoundedNat(4)]), + &mangle_name("pf2", &[TypeArg::BoundedNat(5), arr2u().into()]), // from pf1<5> + &mangle_name("pf2", &[TypeArg::BoundedNat(4), arr2u().into()]), // from pf1<4> + &mangle_name("pf2", &[TypeArg::BoundedNat(2), usize_t().into()]), // from both pf1<4> and <5> "get_usz", "pf2", "mainish", @@ -594,9 +594,9 @@ mod test { #[case::string(vec!["arg".into()], "$foo$$s(arg)")] #[case::dollar_string(vec!["$arg".into()], "$foo$$s(\\$arg)")] #[case::sequence(vec![vec![0.into(), Type::UNIT.into()].into()], "$foo$$list($n(0)$t(Unit))")] - #[case::sequence(vec![TypeArg::Tuple { elems: vec![0.into(), Type::UNIT.into()] }], "$foo$$tuple($n(0)$t(Unit))")] + #[case::sequence(vec![TypeArg::Tuple(vec![0.into(),Type::UNIT.into()])], "$foo$$tuple($n(0)$t(Unit))")] #[should_panic] - #[case::typeargvariable(vec![TypeArg::new_var_use(1, TypeParam::String)], + #[case::typeargvariable(vec![TypeArg::new_var_use(1, TypeParam::StringType)], "$foo$$v(1)")] #[case::multiple(vec![0.into(), "arg".into()], "$foo$$n(0)$s(arg)")] fn test_mangle_name(#[case] args: Vec, #[case] expected: String) { diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index ac19094c19..934e1d5b2b 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -641,7 +641,7 @@ mod test { } fn just_elem_type(args: &[TypeArg]) -> &Type { - let [TypeArg::Type { ty }] = args else { + let [TypeArg::Runtime(ty)] = args else { panic!("Expected just elem type") }; ty diff --git a/hugr-passes/src/replace_types/handlers.rs b/hugr-passes/src/replace_types/handlers.rs index 25abb846bc..6bdb05b6f1 100644 --- a/hugr-passes/src/replace_types/handlers.rs +++ b/hugr-passes/src/replace_types/handlers.rs @@ -106,7 +106,7 @@ pub fn linearize_generic_array( ) -> Result { // Require known length i.e. usable only after monomorphization, due to no-variables limitation // restriction on NodeTemplate::CompoundOp - let [TypeArg::BoundedNat { n }, TypeArg::Type { ty }] = args else { + let [TypeArg::BoundedNat(n), TypeArg::Runtime(ty)] = args else { panic!("Illegal TypeArgs to array: {args:?}") }; if num_outports == 0 { @@ -307,7 +307,7 @@ pub fn copy_discard_array( ) -> Result { // Require known length i.e. usable only after monomorphization, due to no-variables limitation // restriction on NodeTemplate::CompoundOp - let [TypeArg::BoundedNat { n }, TypeArg::Type { ty }] = args else { + let [TypeArg::BoundedNat(n), TypeArg::Runtime(ty)] = args else { panic!("Illegal TypeArgs to array: {args:?}") }; if ty.copyable() { diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index a682d754bb..3c69138863 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -404,7 +404,7 @@ mod test { arg_values: &[TypeArg], _def: &'o OpDef, ) -> Result { - let [TypeArg::BoundedNat { n }] = arg_values else { + let [TypeArg::BoundedNat(n)] = arg_values else { panic!() }; let outs = vec![self.0.clone(); *n as usize]; @@ -412,7 +412,7 @@ mod test { } fn static_params(&self) -> &[TypeParam] { - const JUST_NAT: &[TypeParam] = &[TypeParam::max_nat()]; + const JUST_NAT: &[TypeParam] = &[TypeParam::max_nat_type()]; JUST_NAT } }