Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion hugr-core/src/builder/dataflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -627,7 +627,7 @@ pub(crate) mod test {
FunctionBuilder::new(
"bad_eval",
PolyFuncType::new(
[TypeParam::new_list(TypeBound::Copyable)],
[TypeParam::new_list_type(TypeBound::Copyable)],
Signature::new(
Type::new_function(FuncValueType::new(usize_t(), tv.clone())),
vec![],
Expand Down
101 changes: 47 additions & 54 deletions hugr-core/src/export.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
//! Exporting HUGR graphs to their `hugr-model` representation.
use crate::extension::ExtensionRegistry;
use crate::hugr::internal::HugrInternals;
use crate::types::type_param::Term;
use crate::{
Direction, Hugr, HugrView, IncomingPort, Node, NodeIndex as _, Port,
extension::{ExtensionId, OpDef, SignatureFunc},
Expand All @@ -14,9 +15,7 @@ use crate::{
},
types::{
CustomType, EdgeKind, FuncTypeBase, MaybeRV, PolyFuncTypeBase, RowVariable, SumType,
TypeArg, TypeBase, TypeBound, TypeEnum,
type_param::{TypeArgVariable, TypeParam},
type_row::TypeRowBase,
TypeBase, TypeBound, TypeEnum, type_param::TermVar, type_row::TypeRowBase,
},
};

Expand Down Expand Up @@ -385,7 +384,7 @@ impl<'a> Context<'a> {
let node = self.connected_function(node).unwrap();
let symbol = self.node_to_id[&node];
let mut args = BumpVec::new_in(self.bump);
args.extend(call.type_args.iter().map(|arg| self.export_type_arg(arg)));
args.extend(call.type_args.iter().map(|arg| self.export_term(arg, None)));
let args = args.into_bump_slice();
let func = self.make_term(table::Term::Apply(symbol, args));

Expand All @@ -401,7 +400,7 @@ impl<'a> Context<'a> {
let node = self.connected_function(node).unwrap();
let symbol = self.node_to_id[&node];
let mut args = BumpVec::new_in(self.bump);
args.extend(load.type_args.iter().map(|arg| self.export_type_arg(arg)));
args.extend(load.type_args.iter().map(|arg| self.export_term(arg, None)));
let args = args.into_bump_slice();
let func = self.make_term(table::Term::Apply(symbol, args));
let runtime_type = self.make_term(table::Term::Wildcard);
Expand Down Expand Up @@ -464,7 +463,7 @@ impl<'a> Context<'a> {
let node = self.export_opdef(op.def());
let params = self
.bump
.alloc_slice_fill_iter(op.args().iter().map(|arg| self.export_type_arg(arg)));
.alloc_slice_fill_iter(op.args().iter().map(|arg| self.export_term(arg, None)));
let operation = self.make_term(table::Term::Apply(node, params));
table::Operation::Custom(operation)
}
Expand All @@ -473,7 +472,7 @@ impl<'a> Context<'a> {
let node = self.make_named_global_ref(op.extension(), op.unqualified_id());
let params = self
.bump
.alloc_slice_fill_iter(op.args().iter().map(|arg| self.export_type_arg(arg)));
.alloc_slice_fill_iter(op.args().iter().map(|arg| self.export_term(arg, None)));
let operation = self.make_term(table::Term::Apply(node, params));
table::Operation::Custom(operation)
}
Expand Down Expand Up @@ -806,7 +805,7 @@ impl<'a> Context<'a> {

for (i, param) in t.params().iter().enumerate() {
let name = self.bump.alloc_str(&i.to_string());
let r#type = self.export_type_param(param, Some((scope, i as _)));
let r#type = self.export_term(param, Some((scope, i as _)));
let param = table::Param { name, r#type };
params.push(param);
}
Expand Down Expand Up @@ -854,40 +853,12 @@ impl<'a> Context<'a> {

let args = self
.bump
.alloc_slice_fill_iter(t.args().iter().map(|p| self.export_type_arg(p)));
.alloc_slice_fill_iter(t.args().iter().map(|p| self.export_term(p, None)));
let term = table::Term::Apply(symbol, args);
self.make_term(term)
}

pub fn export_type_arg(&mut self, t: &TypeArg) -> table::TermId {
match t {
TypeArg::Type { ty } => self.export_type(ty),
TypeArg::BoundedNat { n } => self.make_term(model::Literal::Nat(*n).into()),
TypeArg::String { arg } => self.make_term(model::Literal::Str(arg.into()).into()),
TypeArg::Float { value } => self.make_term(model::Literal::Float(*value).into()),
TypeArg::Bytes { value } => self.make_term(model::Literal::Bytes(value.clone()).into()),
TypeArg::List { elems } => {
// For now we assume that the sequence is meant to be a list.
let parts = self.bump.alloc_slice_fill_iter(
elems
.iter()
.map(|elem| table::SeqPart::Item(self.export_type_arg(elem))),
);
self.make_term(table::Term::List(parts))
}
TypeArg::Tuple { elems } => {
let parts = self.bump.alloc_slice_fill_iter(
elems
.iter()
.map(|elem| table::SeqPart::Item(self.export_type_arg(elem))),
);
self.make_term(table::Term::Tuple(parts))
}
TypeArg::Variable { v } => self.export_type_arg_var(v),
}
}

pub fn export_type_arg_var(&mut self, var: &TypeArgVariable) -> table::TermId {
pub fn export_type_arg_var(&mut self, var: &TermVar) -> table::TermId {
let node = self.local_scope.expect("local variable out of scope");
self.make_term(table::Term::Var(table::VarId(node, var.index() as _)))
}
Expand Down Expand Up @@ -953,19 +924,19 @@ impl<'a> Context<'a> {
self.make_term(table::Term::List(parts))
}

/// Exports a `TypeParam` to a term.
/// Exports a term.
///
/// The `var` argument is set when the type parameter being exported is the
/// The `var` argument is set when the term being exported is the
/// type of a parameter to a polymorphic definition. In that case we can
/// generate a `nonlinear` constraint for the type of runtime types marked as
/// `TypeBound::Copyable`.
pub fn export_type_param(
pub fn export_term(
&mut self,
t: &TypeParam,
t: &Term,
var: Option<(table::NodeId, table::VarIndex)>,
) -> table::TermId {
match t {
TypeParam::Type { b } => {
Term::RuntimeType(b) => {
if let (Some((node, index)), TypeBound::Copyable) = (var, b) {
let term = self.make_term(table::Term::Var(table::VarId(node, index)));
let non_linear = self.make_term_apply(model::CORE_NON_LINEAR, &[term]);
Expand All @@ -974,24 +945,46 @@ impl<'a> Context<'a> {

self.make_term_apply(model::CORE_TYPE, &[])
}
// This ignores the bound on the natural for now.
TypeParam::BoundedNat { .. } => self.make_term_apply(model::CORE_NAT_TYPE, &[]),
TypeParam::String => self.make_term_apply(model::CORE_STR_TYPE, &[]),
TypeParam::Bytes => self.make_term_apply(model::CORE_BYTES_TYPE, &[]),
TypeParam::Float => self.make_term_apply(model::CORE_FLOAT_TYPE, &[]),
TypeParam::List { param } => {
let item_type = self.export_type_param(param, None);
Term::BoundedNatType { .. } => self.make_term_apply(model::CORE_NAT_TYPE, &[]),
Term::StringType => self.make_term_apply(model::CORE_STR_TYPE, &[]),
Term::BytesType => self.make_term_apply(model::CORE_BYTES_TYPE, &[]),
Term::FloatType => self.make_term_apply(model::CORE_FLOAT_TYPE, &[]),
Term::ListType(item_type) => {
let item_type = self.export_term(item_type, None);
self.make_term_apply(model::CORE_LIST_TYPE, &[item_type])
}
TypeParam::Tuple { params } => {
let parts = self.bump.alloc_slice_fill_iter(
Term::TupleType(params) => {
let item_types = self.bump.alloc_slice_fill_iter(
params
.iter()
.map(|param| table::SeqPart::Item(self.export_type_param(param, None))),
.map(|param| table::SeqPart::Item(self.export_term(param, None))),
);
let types = self.make_term(table::Term::List(parts));
let types = self.make_term(table::Term::List(item_types));
self.make_term_apply(model::CORE_TUPLE_TYPE, &[types])
}
Term::Runtime(ty) => self.export_type(ty),
Term::BoundedNat(value) => self.make_term(model::Literal::Nat(*value).into()),
Term::String(value) => self.make_term(model::Literal::Str(value.into()).into()),
Term::Float(value) => self.make_term(model::Literal::Float(*value).into()),
Term::Bytes(value) => self.make_term(model::Literal::Bytes(value.clone()).into()),
Term::List(elems) => {
let parts = self.bump.alloc_slice_fill_iter(
elems
.iter()
.map(|elem| table::SeqPart::Item(self.export_term(elem, None))),
);
self.make_term(table::Term::List(parts))
}
Term::Tuple(elems) => {
let parts = self.bump.alloc_slice_fill_iter(
elems
.iter()
.map(|elem| table::SeqPart::Item(self.export_term(elem, None))),
);
self.make_term(table::Term::Tuple(parts))
}
Term::Variable(v) => self.export_type_arg_var(v),
Term::StaticType => self.make_term_apply(model::CORE_STATIC, &[]),
}
}

Expand Down
4 changes: 2 additions & 2 deletions hugr-core/src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use crate::hugr::IdentList;
use crate::ops::custom::{ExtensionOp, OpaqueOp};
use crate::ops::{OpName, OpNameRef};
use crate::types::RowVariable;
use crate::types::type_param::{TypeArg, TypeArgError, TypeParam};
use crate::types::type_param::{TermTypeError, TypeArg, TypeParam};
use crate::types::{CustomType, TypeBound, TypeName};
use crate::types::{Signature, TypeNameRef};

Expand Down Expand Up @@ -387,7 +387,7 @@ pub enum SignatureError {
ExtensionMismatch(ExtensionId, ExtensionId),
/// When the type arguments of the node did not match the params declared by the `OpDef`
#[error("Type arguments of node did not match params declared by definition: {0}")]
TypeArgMismatch(#[from] TypeArgError),
TypeArgMismatch(#[from] TermTypeError),
/// Invalid type arguments
#[error("Invalid type arguments for operation")]
InvalidTypeArgs,
Expand Down
2 changes: 1 addition & 1 deletion hugr-core/src/extension/declarative/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,6 @@ impl TypeParamDeclaration {
_extension: &Extension,
_ctx: DeclarationContext<'_>,
) -> Result<TypeParam, ExtensionDeclarationError> {
Ok(TypeParam::String)
Ok(TypeParam::StringType)
}
}
37 changes: 18 additions & 19 deletions hugr-core/src/extension/op_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use super::{
use crate::Hugr;
use crate::envelope::serde_with::AsStringEnvelope;
use crate::ops::{OpName, OpNameRef};
use crate::types::type_param::{TypeArg, TypeParam, check_type_args};
use crate::types::type_param::{TypeArg, TypeParam, check_term_types};
use crate::types::{FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature};
mod serialize_signature_func;

Expand Down Expand Up @@ -239,7 +239,7 @@ impl SignatureFunc {
let static_params = func.static_params();
let (static_args, other_args) = args.split_at(min(static_params.len(), args.len()));

check_type_args(static_args, static_params)?;
check_term_types(static_args, static_params)?;
temp = func.compute_signature(static_args, def)?;
(&temp, other_args)
}
Expand Down Expand Up @@ -347,7 +347,7 @@ impl OpDef {
let (static_args, other_args) =
args.split_at(min(custom.static_params().len(), args.len()));
static_args.iter().try_for_each(|ta| ta.validate(&[]))?;
check_type_args(static_args, custom.static_params())?;
check_term_types(static_args, custom.static_params())?;
temp = custom.compute_signature(static_args, self)?;
(&temp, other_args)
}
Expand All @@ -357,7 +357,7 @@ impl OpDef {
}
};
args.iter().try_for_each(|ta| ta.validate(var_decls))?;
check_type_args(args, pf.params())?;
check_term_types(args, pf.params())?;
Ok(())
}

Expand Down Expand Up @@ -553,7 +553,7 @@ pub(super) mod test {
use crate::extension::{ExtensionRegistry, ExtensionSet, PRELUDE};
use crate::ops::OpName;
use crate::std_extensions::collections::list;
use crate::types::type_param::{TypeArgError, TypeParam};
use crate::types::type_param::{TermTypeError, TypeParam};
use crate::types::{PolyFuncTypeRV, Signature, Type, TypeArg, TypeBound, TypeRV};
use crate::{Extension, const_extension_ids};

Expand Down Expand Up @@ -656,7 +656,7 @@ pub(super) mod test {
const OP_NAME: OpName = OpName::new_inline("Reverse");

let ext = Extension::try_new_test_arc(EXT_ID, |ext, extension_ref| {
const TP: TypeParam = TypeParam::Type { b: TypeBound::Any };
const TP: TypeParam = TypeParam::RuntimeType(TypeBound::Any);
let list_of_var =
Type::new_extension(list_def.instantiate(vec![TypeArg::new_var_use(0, TP)])?);
let type_scheme = PolyFuncTypeRV::new(vec![TP], Signature::new_endo(vec![list_of_var]));
Expand All @@ -678,11 +678,10 @@ pub(super) mod test {
reg.validate()?;
let e = reg.get(&EXT_ID).unwrap();

let list_usize =
Type::new_extension(list_def.instantiate(vec![TypeArg::Type { ty: usize_t() }])?);
let list_usize = Type::new_extension(list_def.instantiate(vec![usize_t().into()])?);
let mut dfg = DFGBuilder::new(endo_sig(vec![list_usize]))?;
let rev = dfg.add_dataflow_op(
e.instantiate_extension_op(&OP_NAME, vec![TypeArg::Type { ty: usize_t() }])
e.instantiate_extension_op(&OP_NAME, vec![usize_t().into()])
.unwrap(),
dfg.input_wires(),
)?;
Expand All @@ -703,8 +702,8 @@ pub(super) mod test {
&self,
arg_values: &[TypeArg],
) -> Result<PolyFuncTypeRV, SignatureError> {
const TP: TypeParam = TypeParam::Type { b: TypeBound::Any };
let [TypeArg::BoundedNat { n }] = arg_values else {
const TP: TypeParam = TypeParam::RuntimeType(TypeBound::Any);
let [TypeArg::BoundedNat(n)] = arg_values else {
return Err(SignatureError::InvalidTypeArgs);
};
let n = *n as usize;
Expand All @@ -718,7 +717,7 @@ pub(super) mod test {
}

fn static_params(&self) -> &[TypeParam] {
const MAX_NAT: &[TypeParam] = &[TypeParam::max_nat()];
const MAX_NAT: &[TypeParam] = &[TypeParam::max_nat_type()];
MAX_NAT
}
}
Expand All @@ -727,7 +726,7 @@ pub(super) mod test {
ext.add_op("MyOp".into(), String::new(), SigFun(), extension_ref)?;

// Base case, no type variables:
let args = [TypeArg::BoundedNat { n: 3 }, usize_t().into()];
let args = [TypeArg::BoundedNat(3), usize_t().into()];
assert_eq!(
def.compute_signature(&args),
Ok(Signature::new(
Expand All @@ -740,7 +739,7 @@ pub(super) mod test {
// Second arg may be a variable (substitutable)
let tyvar = Type::new_var_use(0, TypeBound::Copyable);
let tyvars: Vec<Type> = vec![tyvar.clone(); 3];
let args = [TypeArg::BoundedNat { n: 3 }, tyvar.clone().into()];
let args = [TypeArg::BoundedNat(3), tyvar.clone().into()];
assert_eq!(
def.compute_signature(&args),
Ok(Signature::new(
Expand All @@ -761,7 +760,7 @@ pub(super) mod test {
);

// First arg must be concrete, not a variable
let kind = TypeParam::bounded_nat(NonZeroU64::new(5).unwrap());
let kind = TypeParam::bounded_nat_type(NonZeroU64::new(5).unwrap());
let args = [TypeArg::new_var_use(0, kind.clone()), usize_t().into()];
// We can't prevent this from getting into our compute_signature implementation:
assert_eq!(
Expand Down Expand Up @@ -798,7 +797,7 @@ pub(super) mod test {
extension_ref,
)?;
let tv = Type::new_var_use(0, TypeBound::Copyable);
let args = [TypeArg::Type { ty: tv.clone() }];
let args = [tv.clone().into()];
let decls = [TypeBound::Copyable.into()];
def.validate_args(&args, &decls).unwrap();
assert_eq!(def.compute_signature(&args), Ok(Signature::new_endo(tv)));
Expand All @@ -807,9 +806,9 @@ pub(super) mod test {
assert_eq!(
def.compute_signature(&[arg.clone()]),
Err(SignatureError::TypeArgMismatch(
TypeArgError::TypeMismatch {
param: TypeBound::Any.into(),
arg
TermTypeError::TypeMismatch {
type_: TypeBound::Any.into(),
term: arg,
}
))
);
Expand Down
Loading
Loading