From 6bd094f6dcf4607ef9f58423ed45d06604728b92 Mon Sep 17 00:00:00 2001 From: Lukas Heidemann Date: Wed, 20 Nov 2024 15:20:46 +0000 Subject: [PATCH] feat: Emulate `TypeBound`s on parameters via constraints. (#1624) This PR translates some `TypeBound`s in `hugr-core` to the `nonlinear` constraint in `hugr-model`. This translation only occurs on parameters that take a runtime type directly. As a driveby change before the model stabilises, this PR also moves constraints out of the parameter lists into their own list. In its previous form this could have led to confusions about which parameter a local variable index refers to when a constraint is situated between two parameters in the list. We also remove constraints from aliases for now. Closes https://github.com/CQCL/hugr/issues/1637. --- hugr-core/src/export.rs | 85 +++++++++++--- hugr-core/src/import.rs | 107 ++++++++++++------ hugr-core/tests/model.rs | 7 ++ .../model__roundtrip_constraints.snap | 16 +++ hugr-model/capnp/hugr-v0.capnp | 34 +++--- hugr-model/src/v0/binary/read.rs | 40 +++---- hugr-model/src/v0/binary/write.rs | 27 ++--- hugr-model/src/v0/mod.rs | 54 ++++----- hugr-model/src/v0/text/hugr.pest | 16 +-- hugr-model/src/v0/text/parse.rs | 41 +++++-- hugr-model/src/v0/text/print.rs | 100 ++++++++-------- hugr-model/tests/binary.rs | 5 + .../tests/fixtures/model-constraints.edn | 13 +++ 13 files changed, 349 insertions(+), 196 deletions(-) create mode 100644 hugr-core/tests/snapshots/model__roundtrip_constraints.snap create mode 100644 hugr-model/tests/fixtures/model-constraints.edn diff --git a/hugr-core/src/export.rs b/hugr-core/src/export.rs index 32460ce5d..093368b60 100644 --- a/hugr-core/src/export.rs +++ b/hugr-core/src/export.rs @@ -7,7 +7,7 @@ use crate::{ type_param::{TypeArgVariable, TypeParam}, type_row::TypeRowBase, CustomType, FuncTypeBase, MaybeRV, PolyFuncTypeBase, RowVariable, SumType, TypeArg, - TypeBase, TypeEnum, + TypeBase, TypeBound, TypeEnum, }, Direction, Hugr, HugrView, IncomingPort, Node, Port, }; @@ -44,8 +44,21 @@ struct Context<'a> { bump: &'a Bump, /// Stores the terms that we have already seen to avoid duplicates. term_map: FxHashMap, model::TermId>, + /// The current scope for local variables. + /// + /// This is set to the id of the smallest enclosing node that defines a polymorphic type. + /// We use this when exporting local variables in terms. local_scope: Option, + + /// Constraints to be added to the local scope. + /// + /// When exporting a node that defines a polymorphic type, we use this field + /// to collect the constraints that need to be added to that polymorphic + /// type. Currently this is used to record `nonlinear` constraints on uses + /// of `TypeParam::Type` with a `TypeBound::Copyable` bound. + local_constraints: Vec, + /// Mapping from extension operations to their declarations. decl_operations: FxHashMap<(ExtensionId, OpName), model::NodeId>, } @@ -63,6 +76,7 @@ impl<'a> Context<'a> { term_map: FxHashMap::default(), local_scope: None, decl_operations: FxHashMap::default(), + local_constraints: Vec::new(), } } @@ -173,9 +187,11 @@ impl<'a> Context<'a> { } fn with_local_scope(&mut self, node: model::NodeId, f: impl FnOnce(&mut Self) -> T) -> T { - let old_scope = self.local_scope.replace(node); + let prev_local_scope = self.local_scope.replace(node); + let prev_local_constraints = std::mem::take(&mut self.local_constraints); let result = f(self); - self.local_scope = old_scope; + self.local_scope = prev_local_scope; + self.local_constraints = prev_local_constraints; result } @@ -232,10 +248,11 @@ impl<'a> Context<'a> { OpType::FuncDefn(func) => self.with_local_scope(node_id, |this| { let name = this.get_func_name(node).unwrap(); - let (params, signature) = this.export_poly_func_type(&func.signature); + let (params, constraints, signature) = this.export_poly_func_type(&func.signature); let decl = this.bump.alloc(model::FuncDecl { name, params, + constraints, signature, }); let extensions = this.export_ext_set(&func.signature.body().extension_reqs); @@ -247,10 +264,11 @@ impl<'a> Context<'a> { OpType::FuncDecl(func) => self.with_local_scope(node_id, |this| { let name = this.get_func_name(node).unwrap(); - let (params, func) = this.export_poly_func_type(&func.signature); + let (params, constraints, func) = this.export_poly_func_type(&func.signature); let decl = this.bump.alloc(model::FuncDecl { name, params, + constraints, signature: func, }); model::Operation::DeclareFunc { decl } @@ -450,10 +468,11 @@ impl<'a> Context<'a> { let decl = self.with_local_scope(node, |this| { let name = this.make_qualified_name(opdef.extension(), opdef.name()); - let (params, r#type) = this.export_poly_func_type(poly_func_type); + let (params, constraints, r#type) = this.export_poly_func_type(poly_func_type); let decl = this.bump.alloc(model::OperationDecl { name, params, + constraints, r#type, }); decl @@ -671,22 +690,36 @@ impl<'a> Context<'a> { regions.into_bump_slice() } + /// Exports a polymorphic function type. + /// + /// The returned triple consists of: + /// - The static parameters of the polymorphic function type. + /// - The constraints of the polymorphic function type. + /// - The function type itself. pub fn export_poly_func_type( &mut self, t: &PolyFuncTypeBase, - ) -> (&'a [model::Param<'a>], model::TermId) { + ) -> (&'a [model::Param<'a>], &'a [model::TermId], model::TermId) { let mut params = BumpVec::with_capacity_in(t.params().len(), self.bump); + let scope = self + .local_scope + .expect("exporting poly func type outside of local scope"); for (i, param) in t.params().iter().enumerate() { let name = self.bump.alloc_str(&i.to_string()); - let r#type = self.export_type_param(param); - let param = model::Param::Implicit { name, r#type }; + let r#type = self.export_type_param(param, Some(model::LocalRef::Index(scope, i as _))); + let param = model::Param { + name, + r#type, + sort: model::ParamSort::Implicit, + }; params.push(param) } + let constraints = self.bump.alloc_slice_copy(&self.local_constraints); let body = self.export_func_type(t.body()); - (params.into_bump_slice(), body) + (params.into_bump_slice(), constraints, body) } pub fn export_type(&mut self, t: &TypeBase) -> model::TermId { @@ -703,7 +736,6 @@ impl<'a> Context<'a> { } TypeEnum::Function(func) => self.export_func_type(func), TypeEnum::Variable(index, _) => { - // This ignores the type bound for now let node = self.local_scope.expect("local variable out of scope"); self.make_term(model::Term::Var(model::LocalRef::Index(node, *index as _))) } @@ -794,20 +826,39 @@ impl<'a> Context<'a> { self.make_term(model::Term::List { items, tail: None }) } - pub fn export_type_param(&mut self, t: &TypeParam) -> model::TermId { + /// Exports a `TypeParam` to a term. + /// + /// The `var` argument is set when the type parameter being exported is the + /// type of a parameter to a polymorphic definition. In that case we can + /// generate a `nonlinear` constraint for the type of runtime types marked as + /// `TypeBound::Copyable`. + pub fn export_type_param( + &mut self, + t: &TypeParam, + var: Option>, + ) -> model::TermId { match t { - // This ignores the type bound for now. - TypeParam::Type { .. } => self.make_term(model::Term::Type), - // This ignores the type bound for now. + TypeParam::Type { b } => { + if let (Some(var), TypeBound::Copyable) = (var, b) { + let term = self.make_term(model::Term::Var(var)); + let non_linear = self.make_term(model::Term::NonLinearConstraint { term }); + self.local_constraints.push(non_linear); + } + + self.make_term(model::Term::Type) + } + // This ignores the bound on the natural for now. TypeParam::BoundedNat { .. } => self.make_term(model::Term::NatType), TypeParam::String => self.make_term(model::Term::StrType), TypeParam::List { param } => { - let item_type = self.export_type_param(param); + let item_type = self.export_type_param(param, None); self.make_term(model::Term::ListType { item_type }) } TypeParam::Tuple { params } => { let items = self.bump.alloc_slice_fill_iter( - params.iter().map(|param| self.export_type_param(param)), + params + .iter() + .map(|param| self.export_type_param(param, None)), ); let types = self.make_term(model::Term::List { items, tail: None }); self.make_term(model::Term::ApplyFull { diff --git a/hugr-core/src/import.rs b/hugr-core/src/import.rs index e6a53cb2f..7619ad44a 100644 --- a/hugr-core/src/import.rs +++ b/hugr-core/src/import.rs @@ -115,8 +115,8 @@ struct Context<'a> { /// A map from `NodeId` to the imported `Node`. nodes: FxHashMap, - /// The types of the local variables that are currently in scope. - local_variables: FxIndexMap<&'a str, model::TermId>, + /// The local variables that are currently in scope. + local_variables: FxIndexMap<&'a str, LocalVar>, custom_name_cache: FxHashMap<&'a str, (ExtensionId, SmolStr)>, } @@ -155,20 +155,20 @@ impl<'a> Context<'a> { .ok_or_else(|| model::ModelError::RegionNotFound(region_id).into()) } - /// Looks up a [`LocalRef`] within the current scope and returns its index and type. + /// Looks up a [`LocalRef`] within the current scope. fn resolve_local_ref( &self, local_ref: &model::LocalRef, - ) -> Result<(usize, model::TermId), ImportError> { + ) -> Result<(usize, LocalVar), ImportError> { let term = match local_ref { model::LocalRef::Index(_, index) => self .local_variables .get_index(*index as usize) - .map(|(_, term)| (*index as usize, *term)), + .map(|(_, v)| (*index as usize, *v)), model::LocalRef::Named(name) => self .local_variables .get_full(name) - .map(|(index, _, term)| (index, *term)), + .map(|(index, _, v)| (index, *v)), }; term.ok_or_else(|| model::ModelError::InvalidLocal(local_ref.to_string()).into()) @@ -898,41 +898,49 @@ impl<'a> Context<'a> { self.with_local_socpe(|ctx| { let mut imported_params = Vec::with_capacity(decl.params.len()); - for param in decl.params { - // TODO: `PolyFuncType` should be able to handle constraints - // and distinguish between implicit and explicit parameters. - match param { - model::Param::Implicit { name, r#type } => { - imported_params.push(ctx.import_type_param(*r#type)?); - ctx.local_variables.insert(name, *r#type); - } - model::Param::Explicit { name, r#type } => { - imported_params.push(ctx.import_type_param(*r#type)?); - ctx.local_variables.insert(name, *r#type); - } - model::Param::Constraint { constraint: _ } => { - return Err(error_unsupported!("constraints")); + ctx.local_variables.extend( + decl.params + .iter() + .map(|param| (param.name, LocalVar::new(param.r#type))), + ); + + for constraint in decl.constraints { + match ctx.get_term(*constraint)? { + model::Term::NonLinearConstraint { term } => { + let model::Term::Var(var) = ctx.get_term(*term)? else { + return Err(error_unsupported!( + "constraint on term that is not a variable" + )); + }; + + let var = ctx.resolve_local_ref(var)?.0; + ctx.local_variables[var].bound = TypeBound::Copyable; } + _ => return Err(error_unsupported!("constraint other than copy or discard")), } } + for (index, param) in decl.params.iter().enumerate() { + // NOTE: `PolyFuncType` only has explicit type parameters at present. + let bound = ctx.local_variables[index].bound; + imported_params.push(ctx.import_type_param(param.r#type, bound)?); + } + let body = ctx.import_func_type::(decl.signature)?; in_scope(ctx, PolyFuncTypeBase::new(imported_params, body)) }) } /// Import a [`TypeParam`] from a term that represents a static type. - fn import_type_param(&mut self, term_id: model::TermId) -> Result { + fn import_type_param( + &mut self, + term_id: model::TermId, + bound: TypeBound, + ) -> Result { match self.get_term(term_id)? { model::Term::Wildcard => Err(error_uninferred!("wildcard")), - model::Term::Type => { - // As part of the migration from `TypeBound`s to constraints, we pretend that all - // `TypeBound`s are copyable. - Ok(TypeParam::Type { - b: TypeBound::Copyable, - }) - } + model::Term::Type => Ok(TypeParam::Type { b: bound }), model::Term::StaticType => Err(error_unsupported!("`type` as `TypeParam`")), model::Term::Constraint => Err(error_unsupported!("`constraint` as `TypeParam`")), @@ -944,7 +952,9 @@ impl<'a> Context<'a> { model::Term::FuncType { .. } => Err(error_unsupported!("`(fn ...)` as `TypeParam`")), model::Term::ListType { item_type } => { - let param = Box::new(self.import_type_param(*item_type)?); + // At present `hugr-model` has no way to express that the item + // type of a list must be copyable. Therefore we import it as `Any`. + let param = Box::new(self.import_type_param(*item_type, TypeBound::Any)?); Ok(TypeParam::List { param }) } @@ -958,7 +968,10 @@ impl<'a> Context<'a> { | model::Term::List { .. } | model::Term::ExtSet { .. } | model::Term::Adt { .. } - | model::Term::Control { .. } => Err(model::ModelError::TypeError(term_id).into()), + | model::Term::Control { .. } + | model::Term::NonLinearConstraint { .. } => { + Err(model::ModelError::TypeError(term_id).into()) + } model::Term::ControlType => { Err(error_unsupported!("type of control types as `TypeParam`")) @@ -966,7 +979,7 @@ impl<'a> Context<'a> { } } - /// Import a `TypeArg` froma term that represents a static type or value. + /// Import a `TypeArg` from a term that represents a static type or value. fn import_type_arg(&mut self, term_id: model::TermId) -> Result { match self.get_term(term_id)? { model::Term::Wildcard => Err(error_uninferred!("wildcard")), @@ -975,8 +988,8 @@ impl<'a> Context<'a> { } model::Term::Var(var) => { - let (index, var_type) = self.resolve_local_ref(var)?; - let decl = self.import_type_param(var_type)?; + let (index, var) = self.resolve_local_ref(var)?; + let decl = self.import_type_param(var.r#type, var.bound)?; Ok(TypeArg::new_var_use(index, decl)) } @@ -1014,7 +1027,10 @@ impl<'a> Context<'a> { model::Term::FuncType { .. } | model::Term::Adt { .. } - | model::Term::Control { .. } => Err(model::ModelError::TypeError(term_id).into()), + | model::Term::Control { .. } + | model::Term::NonLinearConstraint { .. } => { + Err(model::ModelError::TypeError(term_id).into()) + } } } @@ -1115,7 +1131,10 @@ impl<'a> Context<'a> { | model::Term::List { .. } | model::Term::Control { .. } | model::Term::ControlType - | model::Term::Nat(_) => Err(model::ModelError::TypeError(term_id).into()), + | model::Term::Nat(_) + | model::Term::NonLinearConstraint { .. } => { + Err(model::ModelError::TypeError(term_id).into()) + } } } @@ -1291,3 +1310,21 @@ impl<'a> Names<'a> { Ok(Self { items }) } } + +/// Information about a local variable. +#[derive(Debug, Clone, Copy)] +struct LocalVar { + /// The type of the variable. + r#type: model::TermId, + /// The type bound of the variable. + bound: TypeBound, +} + +impl LocalVar { + pub fn new(r#type: model::TermId) -> Self { + Self { + r#type, + bound: TypeBound::Any, + } + } +} diff --git a/hugr-core/tests/model.rs b/hugr-core/tests/model.rs index 611eda660..d9ef0d2c9 100644 --- a/hugr-core/tests/model.rs +++ b/hugr-core/tests/model.rs @@ -58,3 +58,10 @@ pub fn test_roundtrip_params() { "../../hugr-model/tests/fixtures/model-params.edn" ))); } + +#[test] +pub fn test_roundtrip_constraints() { + insta::assert_snapshot!(roundtrip(include_str!( + "../../hugr-model/tests/fixtures/model-constraints.edn" + ))); +} diff --git a/hugr-core/tests/snapshots/model__roundtrip_constraints.snap b/hugr-core/tests/snapshots/model__roundtrip_constraints.snap new file mode 100644 index 000000000..f085c4785 --- /dev/null +++ b/hugr-core/tests/snapshots/model__roundtrip_constraints.snap @@ -0,0 +1,16 @@ +--- +source: hugr-core/tests/model.rs +expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-constraints.edn\"))" +--- +(hugr 0) + +(declare-func array.replicate + (forall ?0 type) + (forall ?1 nat) + (where (nonlinear ?0)) + [?0] [(@ array.Array ?0 ?1)] (ext)) + +(declare-func array.copy + (forall ?0 type) + (where (nonlinear ?0)) + [(@ array.Array ?0)] [(@ array.Array ?0) (@ array.Array ?0)] (ext)) diff --git a/hugr-model/capnp/hugr-v0.capnp b/hugr-model/capnp/hugr-v0.capnp index 95db81205..94341beba 100644 --- a/hugr-model/capnp/hugr-v0.capnp +++ b/hugr-model/capnp/hugr-v0.capnp @@ -56,13 +56,15 @@ struct Operation { struct FuncDefn { name @0 :Text; params @1 :List(Param); - signature @2 :TermId; + constraints @2 :List(TermId); + signature @3 :TermId; } struct FuncDecl { name @0 :Text; params @1 :List(Param); - signature @2 :TermId; + constraints @2 :List(TermId); + signature @3 :TermId; } struct AliasDefn { @@ -81,13 +83,15 @@ struct Operation { struct ConstructorDecl { name @0 :Text; params @1 :List(Param); - type @2 :TermId; + constraints @2 :List(TermId); + type @3 :TermId; } struct OperationDecl { name @0 :Text; params @1 :List(Param); - type @2 :TermId; + constraints @2 :List(TermId); + type @3 :TermId; } } @@ -157,6 +161,7 @@ struct Term { funcType @17 :FuncType; control @18 :TermId; controlType @19 :Void; + nonLinearConstraint @20 :TermId; } struct Apply { @@ -187,19 +192,12 @@ struct Term { } struct Param { - union { - implicit @0 :Implicit; - explicit @1 :Explicit; - constraint @2 :TermId; - } - - struct Implicit { - name @0 :Text; - type @1 :TermId; - } + name @0 :Text; + type @1 :TermId; + sort @2 :ParamSort; +} - struct Explicit { - name @0 :Text; - type @1 :TermId; - } +enum ParamSort { + implicit @0; + explicit @1; } diff --git a/hugr-model/src/v0/binary/read.rs b/hugr-model/src/v0/binary/read.rs index 681bd4ea9..5381a7dc8 100644 --- a/hugr-model/src/v0/binary/read.rs +++ b/hugr-model/src/v0/binary/read.rs @@ -140,10 +140,12 @@ fn read_operation<'a>( let reader = reader?; let name = bump.alloc_str(reader.get_name()?.to_str()?); let params = read_list!(bump, reader, get_params, read_param); + let constraints = read_scalar_list!(bump, reader, get_constraints, model::TermId); let signature = model::TermId(reader.get_signature()); let decl = bump.alloc(model::FuncDecl { name, params, + constraints, signature, }); model::Operation::DefineFunc { decl } @@ -152,10 +154,12 @@ fn read_operation<'a>( let reader = reader?; let name = bump.alloc_str(reader.get_name()?.to_str()?); let params = read_list!(bump, reader, get_params, read_param); + let constraints = read_scalar_list!(bump, reader, get_constraints, model::TermId); let signature = model::TermId(reader.get_signature()); let decl = bump.alloc(model::FuncDecl { name, params, + constraints, signature, }); model::Operation::DeclareFunc { decl } @@ -189,10 +193,12 @@ fn read_operation<'a>( let reader = reader?; let name = bump.alloc_str(reader.get_name()?.to_str()?); let params = read_list!(bump, reader, get_params, read_param); + let constraints = read_scalar_list!(bump, reader, get_constraints, model::TermId); let r#type = model::TermId(reader.get_type()); let decl = bump.alloc(model::ConstructorDecl { name, params, + constraints, r#type, }); model::Operation::DeclareConstructor { decl } @@ -201,10 +207,12 @@ fn read_operation<'a>( let reader = reader?; let name = bump.alloc_str(reader.get_name()?.to_str()?); let params = read_list!(bump, reader, get_params, read_param); + let constraints = read_scalar_list!(bump, reader, get_constraints, model::TermId); let r#type = model::TermId(reader.get_type()); let decl = bump.alloc(model::OperationDecl { name, params, + constraints, r#type, }); model::Operation::DeclareOperation { decl } @@ -332,6 +340,10 @@ fn read_term<'a>(bump: &'a Bump, reader: hugr_capnp::term::Reader) -> ReadResult Which::Control(values) => model::Term::Control { values: model::TermId(values), }, + + Which::NonLinearConstraint(term) => model::Term::NonLinearConstraint { + term: model::TermId(term), + }, }) } @@ -348,23 +360,13 @@ fn read_param<'a>( bump: &'a Bump, reader: hugr_capnp::param::Reader, ) -> ReadResult> { - use hugr_capnp::param::Which; - Ok(match reader.which()? { - Which::Implicit(reader) => { - let reader = reader?; - let name = bump.alloc_str(reader.get_name()?.to_str()?); - let r#type = model::TermId(reader.get_type()); - model::Param::Implicit { name, r#type } - } - Which::Explicit(reader) => { - let reader = reader?; - let name = bump.alloc_str(reader.get_name()?.to_str()?); - let r#type = model::TermId(reader.get_type()); - model::Param::Explicit { name, r#type } - } - Which::Constraint(constraint) => { - let constraint = model::TermId(constraint); - model::Param::Constraint { constraint } - } - }) + let name = bump.alloc_str(reader.get_name()?.to_str()?); + let r#type = model::TermId(reader.get_type()); + + let sort = match reader.get_sort()? { + hugr_capnp::ParamSort::Implicit => model::ParamSort::Implicit, + hugr_capnp::ParamSort::Explicit => model::ParamSort::Explicit, + }; + + Ok(model::Param { name, r#type, sort }) } diff --git a/hugr-model/src/v0/binary/write.rs b/hugr-model/src/v0/binary/write.rs index a4b64d646..f3a0a14d2 100644 --- a/hugr-model/src/v0/binary/write.rs +++ b/hugr-model/src/v0/binary/write.rs @@ -60,12 +60,14 @@ fn write_operation(mut builder: hugr_capnp::operation::Builder, operation: &mode let mut builder = builder.init_func_defn(); builder.set_name(decl.name); write_list!(builder, init_params, write_param, decl.params); + let _ = builder.set_constraints(model::TermId::unwrap_slice(decl.constraints)); builder.set_signature(decl.signature.0); } model::Operation::DeclareFunc { decl } => { let mut builder = builder.init_func_decl(); builder.set_name(decl.name); write_list!(builder, init_params, write_param, decl.params); + let _ = builder.set_constraints(model::TermId::unwrap_slice(decl.constraints)); builder.set_signature(decl.signature.0); } @@ -87,12 +89,14 @@ fn write_operation(mut builder: hugr_capnp::operation::Builder, operation: &mode let mut builder = builder.init_constructor_decl(); builder.set_name(decl.name); write_list!(builder, init_params, write_param, decl.params); + let _ = builder.set_constraints(model::TermId::unwrap_slice(decl.constraints)); builder.set_type(decl.r#type.0); } model::Operation::DeclareOperation { decl } => { let mut builder = builder.init_operation_decl(); builder.set_name(decl.name); write_list!(builder, init_params, write_param, decl.params); + let _ = builder.set_constraints(model::TermId::unwrap_slice(decl.constraints)); builder.set_type(decl.r#type.0); } @@ -101,19 +105,12 @@ fn write_operation(mut builder: hugr_capnp::operation::Builder, operation: &mode } fn write_param(mut builder: hugr_capnp::param::Builder, param: &model::Param) { - match param { - model::Param::Implicit { name, r#type } => { - let mut builder = builder.init_implicit(); - builder.set_name(name); - builder.set_type(r#type.0); - } - model::Param::Explicit { name, r#type } => { - let mut builder = builder.init_explicit(); - builder.set_name(name); - builder.set_type(r#type.0); - } - model::Param::Constraint { constraint } => builder.set_constraint(constraint.0), - } + builder.set_name(param.name); + builder.set_type(param.r#type.0); + builder.set_sort(match param.sort { + model::ParamSort::Implicit => hugr_capnp::ParamSort::Implicit, + model::ParamSort::Explicit => hugr_capnp::ParamSort::Explicit, + }); } fn write_global_ref(mut builder: hugr_capnp::global_ref::Builder, global_ref: &model::GlobalRef) { @@ -212,5 +209,9 @@ fn write_term(mut builder: hugr_capnp::term::Builder, term: &model::Term) { builder.set_outputs(outputs.0); builder.set_extensions(extensions.0); } + + model::Term::NonLinearConstraint { term } => { + builder.set_non_linear_constraint(term.0); + } } } diff --git a/hugr-model/src/v0/mod.rs b/hugr-model/src/v0/mod.rs index cb8713b32..16c7cb6c6 100644 --- a/hugr-model/src/v0/mod.rs +++ b/hugr-model/src/v0/mod.rs @@ -397,6 +397,8 @@ pub struct FuncDecl<'a> { pub name: &'a str, /// The static parameters of the function. pub params: &'a [Param<'a>], + /// The constraints on the static parameters. + pub constraints: &'a [TermId], /// The signature of the function. pub signature: TermId, } @@ -419,6 +421,8 @@ pub struct ConstructorDecl<'a> { pub name: &'a str, /// The static parameters of the constructor. pub params: &'a [Param<'a>], + /// The constraints on the static parameters. + pub constraints: &'a [TermId], /// The type of the constructed term. pub r#type: TermId, } @@ -430,6 +434,8 @@ pub struct OperationDecl<'a> { pub name: &'a str, /// The static parameters of the operation. pub params: &'a [Param<'a>], + /// The constraints on the static parameters. + pub constraints: &'a [TermId], /// The type of the operation. This must be a function type. pub r#type: TermId, } @@ -662,6 +668,12 @@ pub enum Term<'a> { /// /// `ctrl : static` ControlType, + + /// Constraint that requires a runtime type to be copyable and discardable. + NonLinearConstraint { + /// The runtime type that must be copyable and discardable. + term: TermId, + }, } /// A parameter to a function or alias. @@ -669,33 +681,23 @@ pub enum Term<'a> { /// Parameter names must be unique within a parameter list. /// Implicit and explicit parameters share a namespace. #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub enum Param<'a> { - /// An implicit parameter that should be inferred, unless a full application form is used +pub struct Param<'a> { + /// The name of the parameter. + pub name: &'a str, + /// The type of the parameter. + pub r#type: TermId, + /// The sort of the parameter (implicit or explicit). + pub sort: ParamSort, +} + +/// The sort of a parameter (implicit or explicit). +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum ParamSort { + /// The parameter is implicit and should be inferred, unless a full application form is used /// (see [`Term::ApplyFull`] and [`Operation::CustomFull`]). - Implicit { - /// The name of the parameter. - name: &'a str, - /// The type of the parameter. - /// - /// This must be a term of type `static`. - r#type: TermId, - }, - /// An explicit parameter that should always be provided. - Explicit { - /// The name of the parameter. - name: &'a str, - /// The type of the parameter. - /// - /// This must be a term of type `static`. - r#type: TermId, - }, - /// A constraint that should be satisfied by other parameters in a parameter list. - Constraint { - /// The constraint to be satisfied. - /// - /// This must be a term of type `constraint`. - constraint: TermId, - }, + Implicit, + /// The parameter is explicit and should always be provided. + Explicit, } /// Errors that can occur when traversing and interpreting the model. diff --git a/hugr-model/src/v0/text/hugr.pest b/hugr-model/src/v0/text/hugr.pest index 132d78567..d05e3d774 100644 --- a/hugr-model/src/v0/text/hugr.pest +++ b/hugr-model/src/v0/text/hugr.pest @@ -56,16 +56,16 @@ node_tag = { "(" ~ "tag" ~ tag ~ port_lists? ~ signature? ~ meta* node_custom = { "(" ~ (term_apply | term_apply_full) ~ port_lists? ~ signature? ~ meta* ~ region* ~ ")" } signature = { "(" ~ "signature" ~ term ~ ")" } -func_header = { symbol ~ param* ~ term ~ term ~ term } +func_header = { symbol ~ param* ~ where_clause* ~ term ~ term ~ term } alias_header = { symbol ~ param* ~ term } -ctr_header = { symbol ~ param* ~ term } -operation_header = { symbol ~ param* ~ term } +ctr_header = { symbol ~ param* ~ where_clause* ~ term } +operation_header = { symbol ~ param* ~ where_clause* ~ term } -param = { param_implicit | param_explicit | param_constraint } +param = { param_implicit | param_explicit } -param_implicit = { "(" ~ "forall" ~ term_var ~ term ~ ")" } -param_explicit = { "(" ~ "param" ~ term_var ~ term ~ ")" } -param_constraint = { "(" ~ "where" ~ term ~ ")" } +param_implicit = { "(" ~ "forall" ~ term_var ~ term ~ ")" } +param_explicit = { "(" ~ "param" ~ term_var ~ term ~ ")" } +where_clause = { "(" ~ "where" ~ term ~ ")" } region = { region_dfg | region_cfg } region_dfg = { "(" ~ "dfg" ~ port_lists? ~ signature? ~ meta* ~ node* ~ ")" } @@ -92,6 +92,7 @@ term = { | term_ctrl_type | term_apply_full | term_apply + | term_non_linear } term_wildcard = { "_" } @@ -114,3 +115,4 @@ term_adt = { "(" ~ "adt" ~ term ~ ")" } term_func_type = { "(" ~ "fn" ~ term ~ term ~ term ~ ")" } term_ctrl = { "(" ~ "ctrl" ~ term ~ ")" } term_ctrl_type = { "ctrl" } +term_non_linear = { "(" ~ "nonlinear" ~ term ~ ")" } diff --git a/hugr-model/src/v0/text/parse.rs b/hugr-model/src/v0/text/parse.rs index fa486454b..370dbeac0 100644 --- a/hugr-model/src/v0/text/parse.rs +++ b/hugr-model/src/v0/text/parse.rs @@ -7,7 +7,7 @@ use thiserror::Error; use crate::v0::{ AliasDecl, ConstructorDecl, FuncDecl, GlobalRef, LinkRef, LocalRef, MetaItem, Module, Node, - NodeId, Operation, OperationDecl, Param, Region, RegionId, RegionKind, Term, TermId, + NodeId, Operation, OperationDecl, Param, ParamSort, Region, RegionId, RegionKind, Term, TermId, }; mod pest_parser { @@ -209,6 +209,11 @@ impl<'a> ParseContext<'a> { Term::Control { values } } + Rule::term_non_linear => { + let term = self.parse_term(inner.next().unwrap())?; + Term::NonLinearConstraint { term } + } + r => unreachable!("term: {:?}", r), }; @@ -544,6 +549,7 @@ impl<'a> ParseContext<'a> { let mut inner = pair.into_inner(); let name = self.parse_symbol(&mut inner)?; let params = self.parse_params(&mut inner)?; + let constraints = self.parse_constraints(&mut inner)?; let inputs = self.parse_term(inner.next().unwrap())?; let outputs = self.parse_term(inner.next().unwrap())?; @@ -559,6 +565,7 @@ impl<'a> ParseContext<'a> { Ok(self.bump.alloc(FuncDecl { name, params, + constraints, signature: func, })) } @@ -584,11 +591,13 @@ impl<'a> ParseContext<'a> { let mut inner = pair.into_inner(); let name = self.parse_symbol(&mut inner)?; let params = self.parse_params(&mut inner)?; + let constraints = self.parse_constraints(&mut inner)?; let r#type = self.parse_term(inner.next().unwrap())?; Ok(self.bump.alloc(ConstructorDecl { name, params, + constraints, r#type, })) } @@ -599,11 +608,13 @@ impl<'a> ParseContext<'a> { let mut inner = pair.into_inner(); let name = self.parse_symbol(&mut inner)?; let params = self.parse_params(&mut inner)?; + let constraints = self.parse_constraints(&mut inner)?; let r#type = self.parse_term(inner.next().unwrap())?; Ok(self.bump.alloc(OperationDecl { name, params, + constraints, r#type, })) } @@ -619,18 +630,21 @@ impl<'a> ParseContext<'a> { let mut inner = param.into_inner(); let name = &inner.next().unwrap().as_str()[1..]; let r#type = self.parse_term(inner.next().unwrap())?; - Param::Implicit { name, r#type } + Param { + name, + r#type, + sort: ParamSort::Implicit, + } } Rule::param_explicit => { let mut inner = param.into_inner(); let name = &inner.next().unwrap().as_str()[1..]; let r#type = self.parse_term(inner.next().unwrap())?; - Param::Explicit { name, r#type } - } - Rule::param_constraint => { - let mut inner = param.into_inner(); - let constraint = self.parse_term(inner.next().unwrap())?; - Param::Constraint { constraint } + Param { + name, + r#type, + sort: ParamSort::Explicit, + } } _ => unreachable!(), }; @@ -641,6 +655,17 @@ impl<'a> ParseContext<'a> { Ok(self.bump.alloc_slice_copy(¶ms)) } + fn parse_constraints(&mut self, pairs: &mut Pairs<'a, Rule>) -> ParseResult<&'a [TermId]> { + let mut constraints = Vec::new(); + + for pair in filter_rule(pairs, Rule::where_clause) { + let constraint = self.parse_term(pair.into_inner().next().unwrap())?; + constraints.push(constraint); + } + + Ok(self.bump.alloc_slice_copy(&constraints)) + } + fn parse_signature(&mut self, pairs: &mut Pairs<'a, Rule>) -> ParseResult> { let Some(Rule::signature) = pairs.peek().map(|p| p.as_rule()) else { return Ok(None); diff --git a/hugr-model/src/v0/text/print.rs b/hugr-model/src/v0/text/print.rs index 01b9d7195..512f6d1e4 100644 --- a/hugr-model/src/v0/text/print.rs +++ b/hugr-model/src/v0/text/print.rs @@ -2,8 +2,8 @@ use pretty::{Arena, DocAllocator, RefDoc}; use std::borrow::Cow; use crate::v0::{ - GlobalRef, LinkRef, LocalRef, MetaItem, ModelError, Module, NodeId, Operation, Param, RegionId, - RegionKind, Term, TermId, + GlobalRef, LinkRef, LocalRef, MetaItem, ModelError, Module, NodeId, Operation, Param, + ParamSort, RegionId, RegionKind, Term, TermId, }; type PrintError = ModelError; @@ -122,15 +122,7 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { f: impl FnOnce(&mut Self) -> PrintResult, ) -> PrintResult { let locals = std::mem::take(&mut self.locals); - - for param in params { - match param { - Param::Implicit { name, .. } => self.locals.push(name), - Param::Explicit { name, .. } => self.locals.push(name), - Param::Constraint { .. } => {} - } - } - + self.locals.extend(params.iter().map(|param| param.name)); let result = f(self); self.locals = locals; result @@ -178,9 +170,8 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_text(decl.name); }); - for param in decl.params { - this.print_param(*param)?; - } + this.print_params(decl.params)?; + this.print_constraints(decl.constraints)?; match self.module.get_term(decl.signature) { Some(Term::FuncType { @@ -208,9 +199,8 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_text(decl.name); }); - for param in decl.params { - this.print_param(*param)?; - } + this.print_params(decl.params)?; + this.print_constraints(decl.constraints)?; match self.module.get_term(decl.signature) { Some(Term::FuncType { @@ -303,9 +293,7 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_text(decl.name); }); - for param in decl.params { - this.print_param(*param)?; - } + this.print_params(decl.params)?; this.print_term(decl.r#type)?; this.print_term(*value)?; @@ -318,9 +306,7 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_text(decl.name); }); - for param in decl.params { - this.print_param(*param)?; - } + this.print_params(decl.params)?; this.print_term(decl.r#type)?; this.print_meta(node_data.meta)?; @@ -333,9 +319,8 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_text(decl.name); }); - for param in decl.params { - this.print_param(*param)?; - } + this.print_params(decl.params)?; + this.print_constraints(decl.constraints)?; this.print_term(decl.r#type)?; this.print_meta(node_data.meta)?; @@ -348,9 +333,8 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_text(decl.name); }); - for param in decl.params { - this.print_param(*param)?; - } + this.print_params(decl.params)?; + this.print_constraints(decl.constraints)?; this.print_term(decl.r#type)?; this.print_meta(node_data.meta)?; @@ -384,10 +368,9 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { } fn print_regions(&mut self, regions: &'a [RegionId]) -> PrintResult<()> { - for region in regions { - self.print_region(*region)?; - } - Ok(()) + regions + .iter() + .try_for_each(|region| self.print_region(*region)) } fn print_region(&mut self, region: RegionId) -> PrintResult<()> { @@ -422,11 +405,10 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { .get_region(region) .ok_or(PrintError::RegionNotFound(region))?; - for node_id in region_data.children { - self.print_node(*node_id)?; - } - - Ok(()) + region_data + .children + .iter() + .try_for_each(|node_id| self.print_node(*node_id)) } fn print_port_lists( @@ -460,25 +442,33 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { } } + fn print_params(&mut self, params: &'a [Param<'a>]) -> PrintResult<()> { + params.iter().try_for_each(|param| self.print_param(*param)) + } + fn print_param(&mut self, param: Param<'a>) -> PrintResult<()> { - self.print_parens(|this| match param { - Param::Implicit { name, r#type } => { - this.print_text("forall"); - this.print_text(format!("?{}", name)); - this.print_term(r#type) - } - Param::Explicit { name, r#type } => { - this.print_text("param"); - this.print_text(format!("?{}", name)); - this.print_term(r#type) - } - Param::Constraint { constraint } => { - this.print_text("where"); - this.print_term(constraint) - } + self.print_parens(|this| { + match param.sort { + ParamSort::Implicit => this.print_text("forall"), + ParamSort::Explicit => this.print_text("param"), + }; + + this.print_text(format!("?{}", param.name)); + this.print_term(param.r#type) }) } + fn print_constraints(&mut self, terms: &'a [TermId]) -> PrintResult<()> { + for term in terms { + self.print_parens(|this| { + this.print_text("where"); + this.print_term(*term) + })?; + } + + Ok(()) + } + fn print_term(&mut self, term_id: TermId) -> PrintResult<()> { let term_data = self .module @@ -598,6 +588,10 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { self.print_text("ctrl"); Ok(()) } + Term::NonLinearConstraint { term } => self.print_parens(|this| { + this.print_text("nonlinear"); + this.print_term(*term) + }), } } diff --git a/hugr-model/tests/binary.rs b/hugr-model/tests/binary.rs index 043061677..93955fe6e 100644 --- a/hugr-model/tests/binary.rs +++ b/hugr-model/tests/binary.rs @@ -51,3 +51,8 @@ pub fn test_params() { pub fn test_decl_exts() { binary_roundtrip(include_str!("fixtures/model-decl-exts.edn")); } + +#[test] +pub fn test_constraints() { + binary_roundtrip(include_str!("fixtures/model-constraints.edn")); +} diff --git a/hugr-model/tests/fixtures/model-constraints.edn b/hugr-model/tests/fixtures/model-constraints.edn new file mode 100644 index 000000000..5db6b9886 --- /dev/null +++ b/hugr-model/tests/fixtures/model-constraints.edn @@ -0,0 +1,13 @@ +(hugr 0) + +(declare-func array.replicate + (forall ?t type) + (forall ?n nat) + (where (nonlinear ?t)) + [?t] [(@ array.Array ?t ?n)] + (ext)) + +(declare-func array.copy + (forall ?t type) + (where (nonlinear ?t)) + [(@ array.Array ?t)] [(@ array.Array ?t) (@ array.Array ?t)] (ext))