diff --git a/hugr-core/src/export.rs b/hugr-core/src/export.rs index 32460ce5d..093368b60 100644 --- a/hugr-core/src/export.rs +++ b/hugr-core/src/export.rs @@ -7,7 +7,7 @@ use crate::{ type_param::{TypeArgVariable, TypeParam}, type_row::TypeRowBase, CustomType, FuncTypeBase, MaybeRV, PolyFuncTypeBase, RowVariable, SumType, TypeArg, - TypeBase, TypeEnum, + TypeBase, TypeBound, TypeEnum, }, Direction, Hugr, HugrView, IncomingPort, Node, Port, }; @@ -44,8 +44,21 @@ struct Context<'a> { bump: &'a Bump, /// Stores the terms that we have already seen to avoid duplicates. term_map: FxHashMap, model::TermId>, + /// The current scope for local variables. + /// + /// This is set to the id of the smallest enclosing node that defines a polymorphic type. + /// We use this when exporting local variables in terms. local_scope: Option, + + /// Constraints to be added to the local scope. + /// + /// When exporting a node that defines a polymorphic type, we use this field + /// to collect the constraints that need to be added to that polymorphic + /// type. Currently this is used to record `nonlinear` constraints on uses + /// of `TypeParam::Type` with a `TypeBound::Copyable` bound. + local_constraints: Vec, + /// Mapping from extension operations to their declarations. decl_operations: FxHashMap<(ExtensionId, OpName), model::NodeId>, } @@ -63,6 +76,7 @@ impl<'a> Context<'a> { term_map: FxHashMap::default(), local_scope: None, decl_operations: FxHashMap::default(), + local_constraints: Vec::new(), } } @@ -173,9 +187,11 @@ impl<'a> Context<'a> { } fn with_local_scope(&mut self, node: model::NodeId, f: impl FnOnce(&mut Self) -> T) -> T { - let old_scope = self.local_scope.replace(node); + let prev_local_scope = self.local_scope.replace(node); + let prev_local_constraints = std::mem::take(&mut self.local_constraints); let result = f(self); - self.local_scope = old_scope; + self.local_scope = prev_local_scope; + self.local_constraints = prev_local_constraints; result } @@ -232,10 +248,11 @@ impl<'a> Context<'a> { OpType::FuncDefn(func) => self.with_local_scope(node_id, |this| { let name = this.get_func_name(node).unwrap(); - let (params, signature) = this.export_poly_func_type(&func.signature); + let (params, constraints, signature) = this.export_poly_func_type(&func.signature); let decl = this.bump.alloc(model::FuncDecl { name, params, + constraints, signature, }); let extensions = this.export_ext_set(&func.signature.body().extension_reqs); @@ -247,10 +264,11 @@ impl<'a> Context<'a> { OpType::FuncDecl(func) => self.with_local_scope(node_id, |this| { let name = this.get_func_name(node).unwrap(); - let (params, func) = this.export_poly_func_type(&func.signature); + let (params, constraints, func) = this.export_poly_func_type(&func.signature); let decl = this.bump.alloc(model::FuncDecl { name, params, + constraints, signature: func, }); model::Operation::DeclareFunc { decl } @@ -450,10 +468,11 @@ impl<'a> Context<'a> { let decl = self.with_local_scope(node, |this| { let name = this.make_qualified_name(opdef.extension(), opdef.name()); - let (params, r#type) = this.export_poly_func_type(poly_func_type); + let (params, constraints, r#type) = this.export_poly_func_type(poly_func_type); let decl = this.bump.alloc(model::OperationDecl { name, params, + constraints, r#type, }); decl @@ -671,22 +690,36 @@ impl<'a> Context<'a> { regions.into_bump_slice() } + /// Exports a polymorphic function type. + /// + /// The returned triple consists of: + /// - The static parameters of the polymorphic function type. + /// - The constraints of the polymorphic function type. + /// - The function type itself. pub fn export_poly_func_type( &mut self, t: &PolyFuncTypeBase, - ) -> (&'a [model::Param<'a>], model::TermId) { + ) -> (&'a [model::Param<'a>], &'a [model::TermId], model::TermId) { let mut params = BumpVec::with_capacity_in(t.params().len(), self.bump); + let scope = self + .local_scope + .expect("exporting poly func type outside of local scope"); for (i, param) in t.params().iter().enumerate() { let name = self.bump.alloc_str(&i.to_string()); - let r#type = self.export_type_param(param); - let param = model::Param::Implicit { name, r#type }; + let r#type = self.export_type_param(param, Some(model::LocalRef::Index(scope, i as _))); + let param = model::Param { + name, + r#type, + sort: model::ParamSort::Implicit, + }; params.push(param) } + let constraints = self.bump.alloc_slice_copy(&self.local_constraints); let body = self.export_func_type(t.body()); - (params.into_bump_slice(), body) + (params.into_bump_slice(), constraints, body) } pub fn export_type(&mut self, t: &TypeBase) -> model::TermId { @@ -703,7 +736,6 @@ impl<'a> Context<'a> { } TypeEnum::Function(func) => self.export_func_type(func), TypeEnum::Variable(index, _) => { - // This ignores the type bound for now let node = self.local_scope.expect("local variable out of scope"); self.make_term(model::Term::Var(model::LocalRef::Index(node, *index as _))) } @@ -794,20 +826,39 @@ impl<'a> Context<'a> { self.make_term(model::Term::List { items, tail: None }) } - pub fn export_type_param(&mut self, t: &TypeParam) -> model::TermId { + /// Exports a `TypeParam` to a term. + /// + /// The `var` argument is set when the type parameter being exported is the + /// type of a parameter to a polymorphic definition. In that case we can + /// generate a `nonlinear` constraint for the type of runtime types marked as + /// `TypeBound::Copyable`. + pub fn export_type_param( + &mut self, + t: &TypeParam, + var: Option>, + ) -> model::TermId { match t { - // This ignores the type bound for now. - TypeParam::Type { .. } => self.make_term(model::Term::Type), - // This ignores the type bound for now. + TypeParam::Type { b } => { + if let (Some(var), TypeBound::Copyable) = (var, b) { + let term = self.make_term(model::Term::Var(var)); + let non_linear = self.make_term(model::Term::NonLinearConstraint { term }); + self.local_constraints.push(non_linear); + } + + self.make_term(model::Term::Type) + } + // This ignores the bound on the natural for now. TypeParam::BoundedNat { .. } => self.make_term(model::Term::NatType), TypeParam::String => self.make_term(model::Term::StrType), TypeParam::List { param } => { - let item_type = self.export_type_param(param); + let item_type = self.export_type_param(param, None); self.make_term(model::Term::ListType { item_type }) } TypeParam::Tuple { params } => { let items = self.bump.alloc_slice_fill_iter( - params.iter().map(|param| self.export_type_param(param)), + params + .iter() + .map(|param| self.export_type_param(param, None)), ); let types = self.make_term(model::Term::List { items, tail: None }); self.make_term(model::Term::ApplyFull { diff --git a/hugr-core/src/import.rs b/hugr-core/src/import.rs index e6a53cb2f..7619ad44a 100644 --- a/hugr-core/src/import.rs +++ b/hugr-core/src/import.rs @@ -115,8 +115,8 @@ struct Context<'a> { /// A map from `NodeId` to the imported `Node`. nodes: FxHashMap, - /// The types of the local variables that are currently in scope. - local_variables: FxIndexMap<&'a str, model::TermId>, + /// The local variables that are currently in scope. + local_variables: FxIndexMap<&'a str, LocalVar>, custom_name_cache: FxHashMap<&'a str, (ExtensionId, SmolStr)>, } @@ -155,20 +155,20 @@ impl<'a> Context<'a> { .ok_or_else(|| model::ModelError::RegionNotFound(region_id).into()) } - /// Looks up a [`LocalRef`] within the current scope and returns its index and type. + /// Looks up a [`LocalRef`] within the current scope. fn resolve_local_ref( &self, local_ref: &model::LocalRef, - ) -> Result<(usize, model::TermId), ImportError> { + ) -> Result<(usize, LocalVar), ImportError> { let term = match local_ref { model::LocalRef::Index(_, index) => self .local_variables .get_index(*index as usize) - .map(|(_, term)| (*index as usize, *term)), + .map(|(_, v)| (*index as usize, *v)), model::LocalRef::Named(name) => self .local_variables .get_full(name) - .map(|(index, _, term)| (index, *term)), + .map(|(index, _, v)| (index, *v)), }; term.ok_or_else(|| model::ModelError::InvalidLocal(local_ref.to_string()).into()) @@ -898,41 +898,49 @@ impl<'a> Context<'a> { self.with_local_socpe(|ctx| { let mut imported_params = Vec::with_capacity(decl.params.len()); - for param in decl.params { - // TODO: `PolyFuncType` should be able to handle constraints - // and distinguish between implicit and explicit parameters. - match param { - model::Param::Implicit { name, r#type } => { - imported_params.push(ctx.import_type_param(*r#type)?); - ctx.local_variables.insert(name, *r#type); - } - model::Param::Explicit { name, r#type } => { - imported_params.push(ctx.import_type_param(*r#type)?); - ctx.local_variables.insert(name, *r#type); - } - model::Param::Constraint { constraint: _ } => { - return Err(error_unsupported!("constraints")); + ctx.local_variables.extend( + decl.params + .iter() + .map(|param| (param.name, LocalVar::new(param.r#type))), + ); + + for constraint in decl.constraints { + match ctx.get_term(*constraint)? { + model::Term::NonLinearConstraint { term } => { + let model::Term::Var(var) = ctx.get_term(*term)? else { + return Err(error_unsupported!( + "constraint on term that is not a variable" + )); + }; + + let var = ctx.resolve_local_ref(var)?.0; + ctx.local_variables[var].bound = TypeBound::Copyable; } + _ => return Err(error_unsupported!("constraint other than copy or discard")), } } + for (index, param) in decl.params.iter().enumerate() { + // NOTE: `PolyFuncType` only has explicit type parameters at present. + let bound = ctx.local_variables[index].bound; + imported_params.push(ctx.import_type_param(param.r#type, bound)?); + } + let body = ctx.import_func_type::(decl.signature)?; in_scope(ctx, PolyFuncTypeBase::new(imported_params, body)) }) } /// Import a [`TypeParam`] from a term that represents a static type. - fn import_type_param(&mut self, term_id: model::TermId) -> Result { + fn import_type_param( + &mut self, + term_id: model::TermId, + bound: TypeBound, + ) -> Result { match self.get_term(term_id)? { model::Term::Wildcard => Err(error_uninferred!("wildcard")), - model::Term::Type => { - // As part of the migration from `TypeBound`s to constraints, we pretend that all - // `TypeBound`s are copyable. - Ok(TypeParam::Type { - b: TypeBound::Copyable, - }) - } + model::Term::Type => Ok(TypeParam::Type { b: bound }), model::Term::StaticType => Err(error_unsupported!("`type` as `TypeParam`")), model::Term::Constraint => Err(error_unsupported!("`constraint` as `TypeParam`")), @@ -944,7 +952,9 @@ impl<'a> Context<'a> { model::Term::FuncType { .. } => Err(error_unsupported!("`(fn ...)` as `TypeParam`")), model::Term::ListType { item_type } => { - let param = Box::new(self.import_type_param(*item_type)?); + // At present `hugr-model` has no way to express that the item + // type of a list must be copyable. Therefore we import it as `Any`. + let param = Box::new(self.import_type_param(*item_type, TypeBound::Any)?); Ok(TypeParam::List { param }) } @@ -958,7 +968,10 @@ impl<'a> Context<'a> { | model::Term::List { .. } | model::Term::ExtSet { .. } | model::Term::Adt { .. } - | model::Term::Control { .. } => Err(model::ModelError::TypeError(term_id).into()), + | model::Term::Control { .. } + | model::Term::NonLinearConstraint { .. } => { + Err(model::ModelError::TypeError(term_id).into()) + } model::Term::ControlType => { Err(error_unsupported!("type of control types as `TypeParam`")) @@ -966,7 +979,7 @@ impl<'a> Context<'a> { } } - /// Import a `TypeArg` froma term that represents a static type or value. + /// Import a `TypeArg` from a term that represents a static type or value. fn import_type_arg(&mut self, term_id: model::TermId) -> Result { match self.get_term(term_id)? { model::Term::Wildcard => Err(error_uninferred!("wildcard")), @@ -975,8 +988,8 @@ impl<'a> Context<'a> { } model::Term::Var(var) => { - let (index, var_type) = self.resolve_local_ref(var)?; - let decl = self.import_type_param(var_type)?; + let (index, var) = self.resolve_local_ref(var)?; + let decl = self.import_type_param(var.r#type, var.bound)?; Ok(TypeArg::new_var_use(index, decl)) } @@ -1014,7 +1027,10 @@ impl<'a> Context<'a> { model::Term::FuncType { .. } | model::Term::Adt { .. } - | model::Term::Control { .. } => Err(model::ModelError::TypeError(term_id).into()), + | model::Term::Control { .. } + | model::Term::NonLinearConstraint { .. } => { + Err(model::ModelError::TypeError(term_id).into()) + } } } @@ -1115,7 +1131,10 @@ impl<'a> Context<'a> { | model::Term::List { .. } | model::Term::Control { .. } | model::Term::ControlType - | model::Term::Nat(_) => Err(model::ModelError::TypeError(term_id).into()), + | model::Term::Nat(_) + | model::Term::NonLinearConstraint { .. } => { + Err(model::ModelError::TypeError(term_id).into()) + } } } @@ -1291,3 +1310,21 @@ impl<'a> Names<'a> { Ok(Self { items }) } } + +/// Information about a local variable. +#[derive(Debug, Clone, Copy)] +struct LocalVar { + /// The type of the variable. + r#type: model::TermId, + /// The type bound of the variable. + bound: TypeBound, +} + +impl LocalVar { + pub fn new(r#type: model::TermId) -> Self { + Self { + r#type, + bound: TypeBound::Any, + } + } +} diff --git a/hugr-core/tests/model.rs b/hugr-core/tests/model.rs index 611eda660..d9ef0d2c9 100644 --- a/hugr-core/tests/model.rs +++ b/hugr-core/tests/model.rs @@ -58,3 +58,10 @@ pub fn test_roundtrip_params() { "../../hugr-model/tests/fixtures/model-params.edn" ))); } + +#[test] +pub fn test_roundtrip_constraints() { + insta::assert_snapshot!(roundtrip(include_str!( + "../../hugr-model/tests/fixtures/model-constraints.edn" + ))); +} diff --git a/hugr-core/tests/snapshots/model__roundtrip_constraints.snap b/hugr-core/tests/snapshots/model__roundtrip_constraints.snap new file mode 100644 index 000000000..f085c4785 --- /dev/null +++ b/hugr-core/tests/snapshots/model__roundtrip_constraints.snap @@ -0,0 +1,16 @@ +--- +source: hugr-core/tests/model.rs +expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-constraints.edn\"))" +--- +(hugr 0) + +(declare-func array.replicate + (forall ?0 type) + (forall ?1 nat) + (where (nonlinear ?0)) + [?0] [(@ array.Array ?0 ?1)] (ext)) + +(declare-func array.copy + (forall ?0 type) + (where (nonlinear ?0)) + [(@ array.Array ?0)] [(@ array.Array ?0) (@ array.Array ?0)] (ext)) diff --git a/hugr-model/capnp/hugr-v0.capnp b/hugr-model/capnp/hugr-v0.capnp index 95db81205..94341beba 100644 --- a/hugr-model/capnp/hugr-v0.capnp +++ b/hugr-model/capnp/hugr-v0.capnp @@ -56,13 +56,15 @@ struct Operation { struct FuncDefn { name @0 :Text; params @1 :List(Param); - signature @2 :TermId; + constraints @2 :List(TermId); + signature @3 :TermId; } struct FuncDecl { name @0 :Text; params @1 :List(Param); - signature @2 :TermId; + constraints @2 :List(TermId); + signature @3 :TermId; } struct AliasDefn { @@ -81,13 +83,15 @@ struct Operation { struct ConstructorDecl { name @0 :Text; params @1 :List(Param); - type @2 :TermId; + constraints @2 :List(TermId); + type @3 :TermId; } struct OperationDecl { name @0 :Text; params @1 :List(Param); - type @2 :TermId; + constraints @2 :List(TermId); + type @3 :TermId; } } @@ -157,6 +161,7 @@ struct Term { funcType @17 :FuncType; control @18 :TermId; controlType @19 :Void; + nonLinearConstraint @20 :TermId; } struct Apply { @@ -187,19 +192,12 @@ struct Term { } struct Param { - union { - implicit @0 :Implicit; - explicit @1 :Explicit; - constraint @2 :TermId; - } - - struct Implicit { - name @0 :Text; - type @1 :TermId; - } + name @0 :Text; + type @1 :TermId; + sort @2 :ParamSort; +} - struct Explicit { - name @0 :Text; - type @1 :TermId; - } +enum ParamSort { + implicit @0; + explicit @1; } diff --git a/hugr-model/src/v0/binary/read.rs b/hugr-model/src/v0/binary/read.rs index 681bd4ea9..5381a7dc8 100644 --- a/hugr-model/src/v0/binary/read.rs +++ b/hugr-model/src/v0/binary/read.rs @@ -140,10 +140,12 @@ fn read_operation<'a>( let reader = reader?; let name = bump.alloc_str(reader.get_name()?.to_str()?); let params = read_list!(bump, reader, get_params, read_param); + let constraints = read_scalar_list!(bump, reader, get_constraints, model::TermId); let signature = model::TermId(reader.get_signature()); let decl = bump.alloc(model::FuncDecl { name, params, + constraints, signature, }); model::Operation::DefineFunc { decl } @@ -152,10 +154,12 @@ fn read_operation<'a>( let reader = reader?; let name = bump.alloc_str(reader.get_name()?.to_str()?); let params = read_list!(bump, reader, get_params, read_param); + let constraints = read_scalar_list!(bump, reader, get_constraints, model::TermId); let signature = model::TermId(reader.get_signature()); let decl = bump.alloc(model::FuncDecl { name, params, + constraints, signature, }); model::Operation::DeclareFunc { decl } @@ -189,10 +193,12 @@ fn read_operation<'a>( let reader = reader?; let name = bump.alloc_str(reader.get_name()?.to_str()?); let params = read_list!(bump, reader, get_params, read_param); + let constraints = read_scalar_list!(bump, reader, get_constraints, model::TermId); let r#type = model::TermId(reader.get_type()); let decl = bump.alloc(model::ConstructorDecl { name, params, + constraints, r#type, }); model::Operation::DeclareConstructor { decl } @@ -201,10 +207,12 @@ fn read_operation<'a>( let reader = reader?; let name = bump.alloc_str(reader.get_name()?.to_str()?); let params = read_list!(bump, reader, get_params, read_param); + let constraints = read_scalar_list!(bump, reader, get_constraints, model::TermId); let r#type = model::TermId(reader.get_type()); let decl = bump.alloc(model::OperationDecl { name, params, + constraints, r#type, }); model::Operation::DeclareOperation { decl } @@ -332,6 +340,10 @@ fn read_term<'a>(bump: &'a Bump, reader: hugr_capnp::term::Reader) -> ReadResult Which::Control(values) => model::Term::Control { values: model::TermId(values), }, + + Which::NonLinearConstraint(term) => model::Term::NonLinearConstraint { + term: model::TermId(term), + }, }) } @@ -348,23 +360,13 @@ fn read_param<'a>( bump: &'a Bump, reader: hugr_capnp::param::Reader, ) -> ReadResult> { - use hugr_capnp::param::Which; - Ok(match reader.which()? { - Which::Implicit(reader) => { - let reader = reader?; - let name = bump.alloc_str(reader.get_name()?.to_str()?); - let r#type = model::TermId(reader.get_type()); - model::Param::Implicit { name, r#type } - } - Which::Explicit(reader) => { - let reader = reader?; - let name = bump.alloc_str(reader.get_name()?.to_str()?); - let r#type = model::TermId(reader.get_type()); - model::Param::Explicit { name, r#type } - } - Which::Constraint(constraint) => { - let constraint = model::TermId(constraint); - model::Param::Constraint { constraint } - } - }) + let name = bump.alloc_str(reader.get_name()?.to_str()?); + let r#type = model::TermId(reader.get_type()); + + let sort = match reader.get_sort()? { + hugr_capnp::ParamSort::Implicit => model::ParamSort::Implicit, + hugr_capnp::ParamSort::Explicit => model::ParamSort::Explicit, + }; + + Ok(model::Param { name, r#type, sort }) } diff --git a/hugr-model/src/v0/binary/write.rs b/hugr-model/src/v0/binary/write.rs index a4b64d646..f3a0a14d2 100644 --- a/hugr-model/src/v0/binary/write.rs +++ b/hugr-model/src/v0/binary/write.rs @@ -60,12 +60,14 @@ fn write_operation(mut builder: hugr_capnp::operation::Builder, operation: &mode let mut builder = builder.init_func_defn(); builder.set_name(decl.name); write_list!(builder, init_params, write_param, decl.params); + let _ = builder.set_constraints(model::TermId::unwrap_slice(decl.constraints)); builder.set_signature(decl.signature.0); } model::Operation::DeclareFunc { decl } => { let mut builder = builder.init_func_decl(); builder.set_name(decl.name); write_list!(builder, init_params, write_param, decl.params); + let _ = builder.set_constraints(model::TermId::unwrap_slice(decl.constraints)); builder.set_signature(decl.signature.0); } @@ -87,12 +89,14 @@ fn write_operation(mut builder: hugr_capnp::operation::Builder, operation: &mode let mut builder = builder.init_constructor_decl(); builder.set_name(decl.name); write_list!(builder, init_params, write_param, decl.params); + let _ = builder.set_constraints(model::TermId::unwrap_slice(decl.constraints)); builder.set_type(decl.r#type.0); } model::Operation::DeclareOperation { decl } => { let mut builder = builder.init_operation_decl(); builder.set_name(decl.name); write_list!(builder, init_params, write_param, decl.params); + let _ = builder.set_constraints(model::TermId::unwrap_slice(decl.constraints)); builder.set_type(decl.r#type.0); } @@ -101,19 +105,12 @@ fn write_operation(mut builder: hugr_capnp::operation::Builder, operation: &mode } fn write_param(mut builder: hugr_capnp::param::Builder, param: &model::Param) { - match param { - model::Param::Implicit { name, r#type } => { - let mut builder = builder.init_implicit(); - builder.set_name(name); - builder.set_type(r#type.0); - } - model::Param::Explicit { name, r#type } => { - let mut builder = builder.init_explicit(); - builder.set_name(name); - builder.set_type(r#type.0); - } - model::Param::Constraint { constraint } => builder.set_constraint(constraint.0), - } + builder.set_name(param.name); + builder.set_type(param.r#type.0); + builder.set_sort(match param.sort { + model::ParamSort::Implicit => hugr_capnp::ParamSort::Implicit, + model::ParamSort::Explicit => hugr_capnp::ParamSort::Explicit, + }); } fn write_global_ref(mut builder: hugr_capnp::global_ref::Builder, global_ref: &model::GlobalRef) { @@ -212,5 +209,9 @@ fn write_term(mut builder: hugr_capnp::term::Builder, term: &model::Term) { builder.set_outputs(outputs.0); builder.set_extensions(extensions.0); } + + model::Term::NonLinearConstraint { term } => { + builder.set_non_linear_constraint(term.0); + } } } diff --git a/hugr-model/src/v0/mod.rs b/hugr-model/src/v0/mod.rs index cb8713b32..16c7cb6c6 100644 --- a/hugr-model/src/v0/mod.rs +++ b/hugr-model/src/v0/mod.rs @@ -397,6 +397,8 @@ pub struct FuncDecl<'a> { pub name: &'a str, /// The static parameters of the function. pub params: &'a [Param<'a>], + /// The constraints on the static parameters. + pub constraints: &'a [TermId], /// The signature of the function. pub signature: TermId, } @@ -419,6 +421,8 @@ pub struct ConstructorDecl<'a> { pub name: &'a str, /// The static parameters of the constructor. pub params: &'a [Param<'a>], + /// The constraints on the static parameters. + pub constraints: &'a [TermId], /// The type of the constructed term. pub r#type: TermId, } @@ -430,6 +434,8 @@ pub struct OperationDecl<'a> { pub name: &'a str, /// The static parameters of the operation. pub params: &'a [Param<'a>], + /// The constraints on the static parameters. + pub constraints: &'a [TermId], /// The type of the operation. This must be a function type. pub r#type: TermId, } @@ -662,6 +668,12 @@ pub enum Term<'a> { /// /// `ctrl : static` ControlType, + + /// Constraint that requires a runtime type to be copyable and discardable. + NonLinearConstraint { + /// The runtime type that must be copyable and discardable. + term: TermId, + }, } /// A parameter to a function or alias. @@ -669,33 +681,23 @@ pub enum Term<'a> { /// Parameter names must be unique within a parameter list. /// Implicit and explicit parameters share a namespace. #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub enum Param<'a> { - /// An implicit parameter that should be inferred, unless a full application form is used +pub struct Param<'a> { + /// The name of the parameter. + pub name: &'a str, + /// The type of the parameter. + pub r#type: TermId, + /// The sort of the parameter (implicit or explicit). + pub sort: ParamSort, +} + +/// The sort of a parameter (implicit or explicit). +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum ParamSort { + /// The parameter is implicit and should be inferred, unless a full application form is used /// (see [`Term::ApplyFull`] and [`Operation::CustomFull`]). - Implicit { - /// The name of the parameter. - name: &'a str, - /// The type of the parameter. - /// - /// This must be a term of type `static`. - r#type: TermId, - }, - /// An explicit parameter that should always be provided. - Explicit { - /// The name of the parameter. - name: &'a str, - /// The type of the parameter. - /// - /// This must be a term of type `static`. - r#type: TermId, - }, - /// A constraint that should be satisfied by other parameters in a parameter list. - Constraint { - /// The constraint to be satisfied. - /// - /// This must be a term of type `constraint`. - constraint: TermId, - }, + Implicit, + /// The parameter is explicit and should always be provided. + Explicit, } /// Errors that can occur when traversing and interpreting the model. diff --git a/hugr-model/src/v0/text/hugr.pest b/hugr-model/src/v0/text/hugr.pest index 132d78567..d05e3d774 100644 --- a/hugr-model/src/v0/text/hugr.pest +++ b/hugr-model/src/v0/text/hugr.pest @@ -56,16 +56,16 @@ node_tag = { "(" ~ "tag" ~ tag ~ port_lists? ~ signature? ~ meta* node_custom = { "(" ~ (term_apply | term_apply_full) ~ port_lists? ~ signature? ~ meta* ~ region* ~ ")" } signature = { "(" ~ "signature" ~ term ~ ")" } -func_header = { symbol ~ param* ~ term ~ term ~ term } +func_header = { symbol ~ param* ~ where_clause* ~ term ~ term ~ term } alias_header = { symbol ~ param* ~ term } -ctr_header = { symbol ~ param* ~ term } -operation_header = { symbol ~ param* ~ term } +ctr_header = { symbol ~ param* ~ where_clause* ~ term } +operation_header = { symbol ~ param* ~ where_clause* ~ term } -param = { param_implicit | param_explicit | param_constraint } +param = { param_implicit | param_explicit } -param_implicit = { "(" ~ "forall" ~ term_var ~ term ~ ")" } -param_explicit = { "(" ~ "param" ~ term_var ~ term ~ ")" } -param_constraint = { "(" ~ "where" ~ term ~ ")" } +param_implicit = { "(" ~ "forall" ~ term_var ~ term ~ ")" } +param_explicit = { "(" ~ "param" ~ term_var ~ term ~ ")" } +where_clause = { "(" ~ "where" ~ term ~ ")" } region = { region_dfg | region_cfg } region_dfg = { "(" ~ "dfg" ~ port_lists? ~ signature? ~ meta* ~ node* ~ ")" } @@ -92,6 +92,7 @@ term = { | term_ctrl_type | term_apply_full | term_apply + | term_non_linear } term_wildcard = { "_" } @@ -114,3 +115,4 @@ term_adt = { "(" ~ "adt" ~ term ~ ")" } term_func_type = { "(" ~ "fn" ~ term ~ term ~ term ~ ")" } term_ctrl = { "(" ~ "ctrl" ~ term ~ ")" } term_ctrl_type = { "ctrl" } +term_non_linear = { "(" ~ "nonlinear" ~ term ~ ")" } diff --git a/hugr-model/src/v0/text/parse.rs b/hugr-model/src/v0/text/parse.rs index fa486454b..370dbeac0 100644 --- a/hugr-model/src/v0/text/parse.rs +++ b/hugr-model/src/v0/text/parse.rs @@ -7,7 +7,7 @@ use thiserror::Error; use crate::v0::{ AliasDecl, ConstructorDecl, FuncDecl, GlobalRef, LinkRef, LocalRef, MetaItem, Module, Node, - NodeId, Operation, OperationDecl, Param, Region, RegionId, RegionKind, Term, TermId, + NodeId, Operation, OperationDecl, Param, ParamSort, Region, RegionId, RegionKind, Term, TermId, }; mod pest_parser { @@ -209,6 +209,11 @@ impl<'a> ParseContext<'a> { Term::Control { values } } + Rule::term_non_linear => { + let term = self.parse_term(inner.next().unwrap())?; + Term::NonLinearConstraint { term } + } + r => unreachable!("term: {:?}", r), }; @@ -544,6 +549,7 @@ impl<'a> ParseContext<'a> { let mut inner = pair.into_inner(); let name = self.parse_symbol(&mut inner)?; let params = self.parse_params(&mut inner)?; + let constraints = self.parse_constraints(&mut inner)?; let inputs = self.parse_term(inner.next().unwrap())?; let outputs = self.parse_term(inner.next().unwrap())?; @@ -559,6 +565,7 @@ impl<'a> ParseContext<'a> { Ok(self.bump.alloc(FuncDecl { name, params, + constraints, signature: func, })) } @@ -584,11 +591,13 @@ impl<'a> ParseContext<'a> { let mut inner = pair.into_inner(); let name = self.parse_symbol(&mut inner)?; let params = self.parse_params(&mut inner)?; + let constraints = self.parse_constraints(&mut inner)?; let r#type = self.parse_term(inner.next().unwrap())?; Ok(self.bump.alloc(ConstructorDecl { name, params, + constraints, r#type, })) } @@ -599,11 +608,13 @@ impl<'a> ParseContext<'a> { let mut inner = pair.into_inner(); let name = self.parse_symbol(&mut inner)?; let params = self.parse_params(&mut inner)?; + let constraints = self.parse_constraints(&mut inner)?; let r#type = self.parse_term(inner.next().unwrap())?; Ok(self.bump.alloc(OperationDecl { name, params, + constraints, r#type, })) } @@ -619,18 +630,21 @@ impl<'a> ParseContext<'a> { let mut inner = param.into_inner(); let name = &inner.next().unwrap().as_str()[1..]; let r#type = self.parse_term(inner.next().unwrap())?; - Param::Implicit { name, r#type } + Param { + name, + r#type, + sort: ParamSort::Implicit, + } } Rule::param_explicit => { let mut inner = param.into_inner(); let name = &inner.next().unwrap().as_str()[1..]; let r#type = self.parse_term(inner.next().unwrap())?; - Param::Explicit { name, r#type } - } - Rule::param_constraint => { - let mut inner = param.into_inner(); - let constraint = self.parse_term(inner.next().unwrap())?; - Param::Constraint { constraint } + Param { + name, + r#type, + sort: ParamSort::Explicit, + } } _ => unreachable!(), }; @@ -641,6 +655,17 @@ impl<'a> ParseContext<'a> { Ok(self.bump.alloc_slice_copy(¶ms)) } + fn parse_constraints(&mut self, pairs: &mut Pairs<'a, Rule>) -> ParseResult<&'a [TermId]> { + let mut constraints = Vec::new(); + + for pair in filter_rule(pairs, Rule::where_clause) { + let constraint = self.parse_term(pair.into_inner().next().unwrap())?; + constraints.push(constraint); + } + + Ok(self.bump.alloc_slice_copy(&constraints)) + } + fn parse_signature(&mut self, pairs: &mut Pairs<'a, Rule>) -> ParseResult> { let Some(Rule::signature) = pairs.peek().map(|p| p.as_rule()) else { return Ok(None); diff --git a/hugr-model/src/v0/text/print.rs b/hugr-model/src/v0/text/print.rs index 01b9d7195..512f6d1e4 100644 --- a/hugr-model/src/v0/text/print.rs +++ b/hugr-model/src/v0/text/print.rs @@ -2,8 +2,8 @@ use pretty::{Arena, DocAllocator, RefDoc}; use std::borrow::Cow; use crate::v0::{ - GlobalRef, LinkRef, LocalRef, MetaItem, ModelError, Module, NodeId, Operation, Param, RegionId, - RegionKind, Term, TermId, + GlobalRef, LinkRef, LocalRef, MetaItem, ModelError, Module, NodeId, Operation, Param, + ParamSort, RegionId, RegionKind, Term, TermId, }; type PrintError = ModelError; @@ -122,15 +122,7 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { f: impl FnOnce(&mut Self) -> PrintResult, ) -> PrintResult { let locals = std::mem::take(&mut self.locals); - - for param in params { - match param { - Param::Implicit { name, .. } => self.locals.push(name), - Param::Explicit { name, .. } => self.locals.push(name), - Param::Constraint { .. } => {} - } - } - + self.locals.extend(params.iter().map(|param| param.name)); let result = f(self); self.locals = locals; result @@ -178,9 +170,8 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_text(decl.name); }); - for param in decl.params { - this.print_param(*param)?; - } + this.print_params(decl.params)?; + this.print_constraints(decl.constraints)?; match self.module.get_term(decl.signature) { Some(Term::FuncType { @@ -208,9 +199,8 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_text(decl.name); }); - for param in decl.params { - this.print_param(*param)?; - } + this.print_params(decl.params)?; + this.print_constraints(decl.constraints)?; match self.module.get_term(decl.signature) { Some(Term::FuncType { @@ -303,9 +293,7 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_text(decl.name); }); - for param in decl.params { - this.print_param(*param)?; - } + this.print_params(decl.params)?; this.print_term(decl.r#type)?; this.print_term(*value)?; @@ -318,9 +306,7 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_text(decl.name); }); - for param in decl.params { - this.print_param(*param)?; - } + this.print_params(decl.params)?; this.print_term(decl.r#type)?; this.print_meta(node_data.meta)?; @@ -333,9 +319,8 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_text(decl.name); }); - for param in decl.params { - this.print_param(*param)?; - } + this.print_params(decl.params)?; + this.print_constraints(decl.constraints)?; this.print_term(decl.r#type)?; this.print_meta(node_data.meta)?; @@ -348,9 +333,8 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_text(decl.name); }); - for param in decl.params { - this.print_param(*param)?; - } + this.print_params(decl.params)?; + this.print_constraints(decl.constraints)?; this.print_term(decl.r#type)?; this.print_meta(node_data.meta)?; @@ -384,10 +368,9 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { } fn print_regions(&mut self, regions: &'a [RegionId]) -> PrintResult<()> { - for region in regions { - self.print_region(*region)?; - } - Ok(()) + regions + .iter() + .try_for_each(|region| self.print_region(*region)) } fn print_region(&mut self, region: RegionId) -> PrintResult<()> { @@ -422,11 +405,10 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { .get_region(region) .ok_or(PrintError::RegionNotFound(region))?; - for node_id in region_data.children { - self.print_node(*node_id)?; - } - - Ok(()) + region_data + .children + .iter() + .try_for_each(|node_id| self.print_node(*node_id)) } fn print_port_lists( @@ -460,25 +442,33 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { } } + fn print_params(&mut self, params: &'a [Param<'a>]) -> PrintResult<()> { + params.iter().try_for_each(|param| self.print_param(*param)) + } + fn print_param(&mut self, param: Param<'a>) -> PrintResult<()> { - self.print_parens(|this| match param { - Param::Implicit { name, r#type } => { - this.print_text("forall"); - this.print_text(format!("?{}", name)); - this.print_term(r#type) - } - Param::Explicit { name, r#type } => { - this.print_text("param"); - this.print_text(format!("?{}", name)); - this.print_term(r#type) - } - Param::Constraint { constraint } => { - this.print_text("where"); - this.print_term(constraint) - } + self.print_parens(|this| { + match param.sort { + ParamSort::Implicit => this.print_text("forall"), + ParamSort::Explicit => this.print_text("param"), + }; + + this.print_text(format!("?{}", param.name)); + this.print_term(param.r#type) }) } + fn print_constraints(&mut self, terms: &'a [TermId]) -> PrintResult<()> { + for term in terms { + self.print_parens(|this| { + this.print_text("where"); + this.print_term(*term) + })?; + } + + Ok(()) + } + fn print_term(&mut self, term_id: TermId) -> PrintResult<()> { let term_data = self .module @@ -598,6 +588,10 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { self.print_text("ctrl"); Ok(()) } + Term::NonLinearConstraint { term } => self.print_parens(|this| { + this.print_text("nonlinear"); + this.print_term(*term) + }), } } diff --git a/hugr-model/tests/binary.rs b/hugr-model/tests/binary.rs index 043061677..93955fe6e 100644 --- a/hugr-model/tests/binary.rs +++ b/hugr-model/tests/binary.rs @@ -51,3 +51,8 @@ pub fn test_params() { pub fn test_decl_exts() { binary_roundtrip(include_str!("fixtures/model-decl-exts.edn")); } + +#[test] +pub fn test_constraints() { + binary_roundtrip(include_str!("fixtures/model-constraints.edn")); +} diff --git a/hugr-model/tests/fixtures/model-constraints.edn b/hugr-model/tests/fixtures/model-constraints.edn new file mode 100644 index 000000000..5db6b9886 --- /dev/null +++ b/hugr-model/tests/fixtures/model-constraints.edn @@ -0,0 +1,13 @@ +(hugr 0) + +(declare-func array.replicate + (forall ?t type) + (forall ?n nat) + (where (nonlinear ?t)) + [?t] [(@ array.Array ?t ?n)] + (ext)) + +(declare-func array.copy + (forall ?t type) + (where (nonlinear ?t)) + [(@ array.Array ?t)] [(@ array.Array ?t) (@ array.Array ?t)] (ext))