diff --git a/hir-analysis/src/lib.rs b/hir-analysis/src/lib.rs index 9a3a6e02a..2b0dcb827 100644 --- a/hir-analysis/src/lib.rs +++ b/hir-analysis/src/lib.rs @@ -2,11 +2,13 @@ mod control_flow; mod dominance; mod liveness; mod loops; +mod validation; pub use self::control_flow::{BlockPredecessor, ControlFlowGraph}; pub use self::dominance::{DominanceFrontier, DominatorTree, DominatorTreePreorder}; pub use self::liveness::LivenessAnalysis; pub use self::loops::{Loop, LoopAnalysis, LoopLevel}; +pub use self::validation::{ModuleValidator, Rule}; use anyhow::anyhow; @@ -164,11 +166,15 @@ impl FunctionAnalysis { pub fn cfg_changed(&mut self, function: &miden_hir::Function) { // If the dominator tree hasn't been computed, no other // analyses could possibly have been computed yet. - let Some(domtree) = self.domtree.as_mut() else { return; }; + let Some(domtree) = self.domtree.as_mut() else { + return; + }; domtree.compute(function, &self.cfg); // Likewise for loop analysis - we can't compute liveness without it - let Some(loops) = self.loops.as_mut() else { return; }; + let Some(loops) = self.loops.as_mut() else { + return; + }; loops.compute(function, &self.cfg, domtree); if let Some(liveness) = self.liveness.as_mut() { diff --git a/hir-analysis/src/validation/block.rs b/hir-analysis/src/validation/block.rs new file mode 100644 index 000000000..9edad90e4 --- /dev/null +++ b/hir-analysis/src/validation/block.rs @@ -0,0 +1,243 @@ +use miden_diagnostics::{DiagnosticsHandler, Severity, SourceSpan, Spanned}; +use miden_hir::*; +use rustc_hash::FxHashSet; +use smallvec::SmallVec; + +use super::{Rule, ValidationError}; +use crate::DominatorTree; + +/// This validation rule ensures that all values definitions dominate their uses. +/// +/// For example, it is not valid to use a value in a block when its definition only +/// occurs along a subset of control flow paths which may be taken to that block. +/// +/// This also catches uses of values which are orphaned (i.e. they are defined by +/// a block parameter or instruction which is not attached to the function). +pub struct DefsDominateUses<'a> { + dfg: &'a DataFlowGraph, + domtree: &'a DominatorTree, +} +impl<'a> DefsDominateUses<'a> { + pub fn new(dfg: &'a DataFlowGraph, domtree: &'a DominatorTree) -> Self { + Self { dfg, domtree } + } +} +impl<'a> Rule for DefsDominateUses<'a> { + fn validate( + &mut self, + block_data: &BlockData, + diagnostics: &DiagnosticsHandler, + ) -> Result<(), ValidationError> { + let current_block = block_data.id; + let mut uses = FxHashSet::::default(); + let mut defs = FxHashSet::::default(); + for node in block_data.insts.iter() { + let span = node.span(); + + uses.clear(); + defs.clear(); + + // Verify the integrity of the instruction results + for result in self.dfg.inst_results(node.key) { + // It should never be possible for a value id to be present in the result set twice + assert!(defs.insert(*result)); + } + + // Gather all value uses to check + uses.extend(node.arguments(&self.dfg.value_lists).iter().copied()); + match node.analyze_branch(&self.dfg.value_lists) { + BranchInfo::NotABranch => (), + BranchInfo::SingleDest(_, args) => { + uses.extend(args.iter().copied()); + } + BranchInfo::MultiDest(ref jts) => { + for jt in jts.iter() { + uses.extend(jt.args.iter().copied()); + } + } + } + + // Make sure there are no uses of the instructions own results + if !defs.is_disjoint(&uses) { + invalid_instruction!( + diagnostics, + node.key, + span, + "an instruction may not use its own results as arguments", + "This situation can only arise if one has manually modified the arguments of an instruction, \ + incorrectly inserting a value obtained from the set of instruction results." + ); + } + + // Next, ensure that all used values are dominated by their definition + for value in uses.iter().copied() { + match self.dfg.value_data(value) { + // If the value comes from the current block's parameter list, this use is trivially dominated + ValueData::Param { block, .. } if block == ¤t_block => continue, + // If the value comes from another block, then as long as all paths to the current + // block flow through that block, then this use is dominated by its definition + ValueData::Param { block, .. } => { + if self.domtree.dominates(*block, current_block, &self.dfg) { + continue; + } + } + // If the value is an instruction result, then as long as all paths to the current + // instruction flow through the defining instruction, then this use is dominated + // by its definition + ValueData::Inst { inst, .. } => { + if self.domtree.dominates(*inst, node.key, &self.dfg) { + continue; + } + } + } + + // If we reach here, the use of `value` is not dominated by its definition, + // so this use is invalid + invalid_instruction!( + diagnostics, + node.key, + span, + "an argument of this instruction, {value}, is not defined on all paths leading to this point", + "All uses of a value must be dominated by its definition, i.e. all control flow paths \ + from the function entry to the point of each use must flow through the point where \ + that value is defined." + ); + } + } + + Ok(()) + } +} + +/// This validation rule ensures that most block-local invariants are upheld: +/// +/// * A block may not be empty +/// * A block must end with a terminator instruction +/// * A block may not contain a terminator instruction in any position but the end +/// * A block which terminates with a branch instruction must reference a block +/// that is present in the function body (i.e. it is not valid to reference +/// detached blocks) +/// * A multi-way branch instruction must have at least one successor +/// * A multi-way branch instruction must not specify the same block as a successor multiple times. +/// +/// This rule does not perform type checking, or verify use/def dominance. +pub struct BlockValidator<'a> { + dfg: &'a DataFlowGraph, + span: SourceSpan, +} +impl<'a> BlockValidator<'a> { + pub fn new(dfg: &'a DataFlowGraph, span: SourceSpan) -> Self { + Self { dfg, span } + } +} +impl<'a> Rule for BlockValidator<'a> { + fn validate( + &mut self, + block_data: &BlockData, + diagnostics: &DiagnosticsHandler, + ) -> Result<(), ValidationError> { + // Ignore blocks which are not attached to the function body + if !block_data.link.is_linked() { + return Ok(()); + } + + // Ensure there is a terminator, and that it is valid + let id = block_data.id; + let terminator = block_data.insts.back().get(); + if terminator.is_none() { + // This block is empty + invalid_block!( + diagnostics, + id, + self.span, + "block cannot be empty", + "Empty blocks are only valid when detached from the function body" + ); + } + + let terminator = terminator.unwrap(); + let op = terminator.opcode(); + if !op.is_terminator() { + invalid_block!( + diagnostics, + id, + self.span, + "invalid block terminator", + format!("The last instruction in a block must be a terminator, but {id} ends with {op} which is not a valid terminator") + ); + } + match terminator.analyze_branch(&self.dfg.value_lists) { + BranchInfo::SingleDest(destination, _) => { + let dest = self.dfg.block(destination); + if !dest.link.is_linked() { + invalid_instruction!( + diagnostics, + terminator.key, + terminator.span(), + "invalid successor", + format!("A block reference is only valid if the referenced block is present in the function layout. \ + {id} references {destination}, but the latter is not in the layout") + ); + } + } + BranchInfo::MultiDest(ref jts) => { + if jts.is_empty() { + invalid_instruction!( + diagnostics, + terminator.key, + terminator.span(), + "incomplete {op} instruction", + "This instruction normally has 2 or more successors, but none were given." + ); + } + + let mut seen = SmallVec::<[Block; 4]>::default(); + for jt in jts.iter() { + let dest = self.dfg.block(jt.destination); + let destination = jt.destination; + if !dest.link.is_linked() { + invalid_instruction!( + diagnostics, + terminator.key, + terminator.span(), + "invalid successor", + format!("A block reference is only valid if the referenced block is present in the function layout. \ + {id} references {destination}, but the latter is not in the layout") + ); + } + + if seen.contains(&jt.destination) { + invalid_instruction!( + diagnostics, + terminator.key, + terminator.span(), + "invalid {op} instruction", + format!("A given block may only be a successor along a single control flow path, \ + but {id} uses {destination} as a successor for more than one path") + ); + } + + seen.push(jt.destination); + } + } + BranchInfo::NotABranch => (), + } + + // Verify that there are no terminator instructions in any other position than last + for node in block_data.insts.iter() { + let op = node.opcode(); + if op.is_terminator() && node.key != terminator.key { + invalid_block!( + diagnostics, + id, + self.span, + "terminator found in middle of block", + format!("A block may only have a terminator instruction as the last instruction in the block, \ + but {id} uses {op} before the end of the block") + ); + } + } + + Ok(()) + } +} diff --git a/hir-analysis/src/validation/function.rs b/hir-analysis/src/validation/function.rs new file mode 100644 index 000000000..0e2b87309 --- /dev/null +++ b/hir-analysis/src/validation/function.rs @@ -0,0 +1,318 @@ +use miden_diagnostics::{DiagnosticsHandler, Severity, Spanned}; +use miden_hir::*; + +use super::{ + BlockValidator, DefsDominateUses, NamingConventions, Rule, TypeCheck, ValidationError, +}; +use crate::{ControlFlowGraph, DominatorTree}; + +/// This validation rule ensures that function-local invariants are upheld: +/// +/// * A function may not be empty +/// * All blocks in the function body must be valid +/// * All uses of values must be dominated by their definitions +/// * All value uses must type check, i.e. branching to a block with a value +/// of a different type than declared by the block parameter is invalid. +pub struct FunctionValidator { + in_kernel_module: bool, +} +impl FunctionValidator { + pub fn new(in_kernel_module: bool) -> Self { + Self { in_kernel_module } + } +} +impl Rule for FunctionValidator { + fn validate( + &mut self, + function: &Function, + diagnostics: &DiagnosticsHandler, + ) -> Result<(), ValidationError> { + // Validate the function declaration + let mut rules = NamingConventions.chain(CoherentSignature::new(self.in_kernel_module)); + rules.validate(function, diagnostics)?; + + // Ensure basic integrity of the function body + let mut rules = BlockValidator::new(&function.dfg, function.id.span()); + for (_, block) in function.dfg.blocks() { + rules.validate(block, diagnostics)?; + } + + // Construct control flow and dominator tree analyses + let cfg = ControlFlowGraph::with_function(function); + let domtree = DominatorTree::with_function(function, &cfg); + + // Verify value usage + let mut rules = DefsDominateUses::new(&function.dfg, &domtree) + .chain(TypeCheck::new(&function.signature, &function.dfg)); + for (_, block) in function.dfg.blocks() { + rules.validate(block, diagnostics)?; + } + + Ok(()) + } +} + +/// This validation rule ensures that a [Signature] is coherent +/// +/// A signature is coherent if: +/// +/// 1. The linkage is valid for functions +/// 2. The calling convention is valid in the context the function is defined in +/// 3. The ABI of its parameters matches the calling convention +/// 4. The ABI of the parameters and results are coherent, e.g. +/// there are no signed integer parameters which are specified +/// as being zero-extended, there are no results if an sret +/// parameter is present, etc. +struct CoherentSignature { + in_kernel_module: bool, +} +impl CoherentSignature { + pub fn new(in_kernel_module: bool) -> Self { + Self { in_kernel_module } + } +} + +impl Rule for CoherentSignature { + fn validate( + &mut self, + function: &Function, + diagnostics: &DiagnosticsHandler, + ) -> Result<(), ValidationError> { + let span = function.id.span(); + + // 1 + let linkage = function.signature.linkage; + if !matches!(linkage, Linkage::External | Linkage::Internal) { + invalid_function!( + diagnostics, + function.id, + "the signature of this function specifies '{linkage}' linkage, \ + but only 'external' or 'internal' are valid" + ); + } + + // 2 + let cc = function.signature.cc; + let is_kernel_function = matches!(cc, CallConv::Kernel); + if self.in_kernel_module { + let is_public = function.signature.is_public(); + if is_public && !is_kernel_function { + invalid_function!( + diagnostics, + function.id, + function.id.span(), + "the '{cc}' calling convention may only be used with \ + 'internal' linkage in kernel modules", + "This function is declared with 'external' linkage in a kernel module, so \ + it must use the 'kernel' calling convention" + ); + } else if !is_public && is_kernel_function { + invalid_function!( + diagnostics, + function.id, + function.id.span(), + "the 'kernel' calling convention may only be used with 'external' linkage", + "This function has 'internal' linkage, so it must either be made 'external', \ + or a different calling convention must be used" + ); + } + } else if is_kernel_function { + invalid_function!( + diagnostics, + function.id, + function.id.span(), + "the 'kernel' calling convention may only be used in kernel modules", + "Kernel functions may only be declared in kernel modules, so you must either \ + change the module type, or change the calling convention of this function" + ); + } + + // 3 + // * sret parameters may not be used with kernel calling convention + // * pointer-typed parameters/results may not be used with kernel calling convention + // * parameters larger than 8 bytes must be passed by reference with fast/C calling conventions + // * results larger than 8 bytes require the use of an sret parameter with fast/C calling conventions + // * total size of all parameters when laid out on the operand stack may not exceed 64 bytes (16 field elements) + // + // 4 + // * paramter count and types must be consistent between the signature and the entry block + // * only sret parameter is permitted, and it must be the first parameter when present + // * the sret attribute may not be applied to results + // * sret parameters imply no results + // * signed integer values may not be combined with zero-extension + // * non-integer values may not be combined with argument extension + let mut sret_count = 0; + let mut effective_stack_usage = 0; + let params = function.dfg.block_args(function.dfg.entry_block()); + if params.len() != function.signature.arity() { + invalid_function!( + diagnostics, + function.id, + function.id.span(), + "function signature and entry block have different arities", + "This happens if the signature or entry block are modified without updating the other, \ + make sure the number and types of all parameters are the same in both the signature and \ + the entry block" + ); + } + for (i, param) in function.signature.params.iter().enumerate() { + let is_first = i == 0; + let value = params[i]; + let span = function.dfg.value_span(value); + let param_ty = ¶m.ty; + let value_ty = function.dfg.value_type(value); + + if param_ty != value_ty { + invalid_function!( + diagnostics, + function.id, + span, + "parameter type mismatch between signature and entry block", + format!( + "The function declares this parameter as having type {param_ty}, \ + but the actual type is {value_ty}" + ) + ); + } + + let is_integer = param_ty.is_integer(); + let is_signed_integer = param_ty.is_signed_integer(); + match param.extension { + ArgumentExtension::Zext if is_signed_integer => { + invalid_function!( + diagnostics, + function.id, + span, + "signed integer parameters may not be combined with zero-extension", + "Zero-extending a signed-integer loses the signedness, you should use signed-extension instead" + ); + } + ArgumentExtension::Sext | ArgumentExtension::Zext if !is_integer => { + invalid_function!( + diagnostics, + function.id, + span, + "non-integer parameters may not be combined with argument extension attributes", + "Argument extension has no meaning for types other than integers" + ); + } + _ => (), + } + + let is_pointer = param_ty.is_pointer(); + let is_sret = param.purpose == ArgumentPurpose::StructReturn; + if is_sret { + sret_count += 1; + } + + if is_kernel_function && (is_sret || is_pointer) { + invalid_function!( + diagnostics, + function.id, + span, + "functions using the 'kernel' calling convention may not use sret or pointer-typed parameters", + "Kernel functions are invoked in a different memory context, so they may not pass or return values by reference" + ); + } + + if !is_kernel_function { + if is_sret { + if sret_count > 1 || !is_first { + invalid_function!( + diagnostics, + function.id, + span, + "a function may only have a single sret parameter, and it must be the first parameter", + "The sret parameter type is used to return a large value from a function, \ + but it may only be used for functions with a single return value" + ); + } + if !is_pointer { + invalid_function!( + diagnostics, + function.id, + span, + "sret parameters must be pointer-typed, but got {param_ty}", + format!( + "Did you mean to define this parameter with type {}?", + &Type::Ptr(Box::new(param_ty.clone())) + ) + ); + } + + if !function.signature.results.is_empty() { + invalid_function!( + diagnostics, + function.id, + span, + "functions with an sret parameter must have no results", + "An sret parameter is used in place of normal return values, but this function uses both, \ + which is not valid. You should remove the results from the function signature." + ); + } + } + + let size_in_bytes = param_ty.size_in_bytes(); + if !is_pointer && size_in_bytes > 8 { + invalid_function!( + diagnostics, + function.id, + span, + "this parameter type is too large to pass by value", + format!("This parameter has type {param_ty}, you must refactor this function to pass it by reference instead") + ); + } + } + + effective_stack_usage += param_ty + .clone() + .to_raw_parts() + .map(|parts| parts.len()) + .unwrap_or(0); + } + + if effective_stack_usage > 16 { + invalid_function!( + diagnostics, + function.id, + span, + "this function has a signature with too many parameters", + "Due to the constraints of the Miden VM, all function parameters must fit on the operand stack, \ + which is 16 elements (each of which is effectively 4 bytes, a maximum of 64 bytes). \ + The layout of the parameter list of this function requires more than this limit. \ + You should either remove parameters, or combine some of them into a struct which is then passed by reference." + ); + } + + for (i, result) in function.signature.results.iter().enumerate() { + if result.purpose == ArgumentPurpose::StructReturn { + invalid_function!( + diagnostics, + function.id, + "the sret attribute is only permitted on function parameters" + ); + } + + if result.extension != ArgumentExtension::None { + invalid_function!( + diagnostics, + function.id, + "the argument extension attributes are only permitted on function parameters" + ); + } + + let size_in_bytes = result.ty.size_in_bytes(); + if !result.ty.is_pointer() && size_in_bytes > 8 { + invalid_function!( + diagnostics, + function.id, + function.id.span(), + "This function specifies a result type which is too large to pass by value", + format!("The parameter at index {} has type {}, you must refactor this function to pass it by reference instead", i, &result.ty) + ); + } + } + + Ok(()) + } +} diff --git a/hir-analysis/src/validation/mod.rs b/hir-analysis/src/validation/mod.rs new file mode 100644 index 000000000..a7de1f929 --- /dev/null +++ b/hir-analysis/src/validation/mod.rs @@ -0,0 +1,469 @@ +macro_rules! bug { + ($diagnostics:ident, $msg:literal) => {{ + diagnostic!($diagnostics, Severity::Bug, $msg); + }}; + + ($diagnostics:ident, $msg:literal, $span:expr, $label:expr) => {{ + diagnostic!($diagnostics, Severity::Bug, $msg, $span, $label); + }}; + + ($diagnostics:ident, $msg:literal, $span:expr, $label:expr, $note:expr) => {{ + diagnostic!($diagnostics, Severity::Bug, $msg, $span, $label, $note); + }}; + + ($diagnostics:ident, $msg:literal, $span:expr, $label:expr, $span2:expr, $label2:expr) => {{ + diagnostic!( + $diagnostics, + Severity::Bug, + $msg, + $span, + $label, + $span2, + $label2 + ); + }}; +} + +macro_rules! error { + ($diagnostics:ident, $msg:literal) => {{ + diagnostic!($diagnostics, Severity::Error, $msg); + }}; + + ($diagnostics:ident, $msg:literal, $span:expr, $label:expr) => {{ + diagnostic!($diagnostics, Severity::Error, $msg, $span, $label); + }}; + + ($diagnostics:ident, $msg:literal, $span:expr, $label:expr, $note:expr) => {{ + diagnostic!($diagnostics, Severity::Error, $msg, $span, $label, $note); + }}; + + ($diagnostics:ident, $msg:literal, $span:expr, $label:expr, $span2:expr, $label2:expr) => {{ + diagnostic!( + $diagnostics, + Severity::Error, + $msg, + $span, + $label, + $span2, + $label2 + ); + }}; +} + +macro_rules! invalid_instruction { + ($diagnostics:ident, $inst:expr, $span:expr, $label:expr) => {{ + let span = $span; + let reason = format!($label); + bug!($diagnostics, "invalid instruction", span, reason.as_str()); + return Err(crate::validation::ValidationError::InvalidInstruction { + span, + inst: $inst, + reason, + }); + }}; + + ($diagnostics:ident, $inst:expr, $span:expr, $label:expr, $note:expr) => {{ + let span = $span; + let reason = format!($label); + bug!( + $diagnostics, + "invalid instruction", + span, + reason.as_str(), + $note + ); + return Err(crate::validation::ValidationError::InvalidInstruction { + span, + inst: $inst, + reason, + }); + }}; +} + +macro_rules! invalid_block { + ($diagnostics:ident, $block:expr, $span:expr, $label:expr) => {{ + let reason = format!($label); + bug!($diagnostics, "invalid block", $span, reason.as_str()); + return Err(crate::validation::ValidationError::InvalidBlock { + block: $block, + reason, + }); + }}; + + ($diagnostics:ident, $block:expr, $span:expr, $label:expr, $note:expr) => {{ + let reason = format!($label); + bug!($diagnostics, "invalid block", $span, reason.as_str(), $note); + return Err(crate::validation::ValidationError::InvalidBlock { + block: $block, + reason, + }); + }}; +} + +macro_rules! invalid_module { + ($diagnostics:ident, $module:expr, $label:expr) => {{ + invalid_module!($diagnostics, $module, $module.span(), $label); + }}; + + ($diagnostics:ident, $module:expr, $span:expr, $label:expr) => {{ + let span = $span; + let reason = format!($label); + error!($diagnostics, "invalid module", span, reason.as_str()); + return Err(crate::validation::ValidationError::InvalidModule { + module: $module, + reason, + }); + }}; + + ($diagnostics:ident, $module:expr, $span:expr, $label:expr, $note:expr) => {{ + let span = $span; + let reason = format!($label); + error!($diagnostics, "invalid module", span, reason.as_str(), $note); + return Err(crate::validation::ValidationError::InvalidModule { + module: $module, + reason, + }); + }}; +} + +macro_rules! invalid_function { + ($diagnostics:ident, $function:expr, $label:expr) => {{ + invalid_function!($diagnostics, $function, $function.span(), $label); + }}; + + ($diagnostics:ident, $function:expr, $span:expr, $label:expr) => {{ + let span = $span; + let reason = format!($label); + error!($diagnostics, "invalid function", span, reason.as_str()); + return Err(crate::validation::ValidationError::InvalidFunction { + function: $function, + reason, + }); + }}; + + ($diagnostics:ident, $function:expr, $span:expr, $label:expr, $note:expr) => {{ + let span = $span; + let reason = format!($label); + error!( + $diagnostics, + "invalid function", + span, + reason.as_str(), + $note + ); + return Err(crate::validation::ValidationError::InvalidFunction { + function: $function, + reason, + }); + }}; + + ($diagnostics:ident, $function:expr, $span:expr, $label:expr, $span2:expr, $label2:expr) => {{ + let span = $span; + let reason = format!($label); + error!($diagnostics, "invalid function", span, reason.as_str()); + $diagnostics + .diagnostic(miden_diagnostics::Severity::Error) + .with_message("invalid function") + .with_primary_label(span, reason.as_str()) + .with_secondary_label($span2, $label2) + .emit(); + return Err(crate::validation::ValidationError::InvalidFunction { + function: $function, + reason, + }); + }}; +} + +macro_rules! invalid_global { + ($diagnostics:ident, $name:expr, $label:expr) => {{ + invalid_global!($diagnostics, $name, $name.span(), $label); + }}; + + ($diagnostics:ident, $name:expr, $span:expr, $label:expr) => {{ + let span = $span; + let reason = format!($label); + error!( + $diagnostics, + "invalid global variable", + span, + reason.as_str() + ); + return Err(crate::validation::ValidationError::InvalidGlobalVariable { + name: $name, + reason, + }); + }}; +} + +mod block; +mod function; +mod naming; +mod typecheck; + +pub use self::typecheck::TypeError; + +use miden_diagnostics::{DiagnosticsHandler, SourceSpan}; +use miden_hir::*; +use miden_hir_pass::Pass; + +use self::block::{BlockValidator, DefsDominateUses}; +use self::function::FunctionValidator; +use self::naming::NamingConventions; +use self::typecheck::TypeCheck; + +/// This error is produced by validation rules run against the IR +#[derive(Debug, thiserror::Error)] +pub enum ValidationError { + /// A validation rule indicates a module is invalid + #[error("invalid module '{module}': {reason}")] + InvalidModule { module: Ident, reason: String }, + /// A validation rule indicates a global variable is invalid + #[error("invalid global variable '{name}': {reason}")] + InvalidGlobalVariable { name: Ident, reason: String }, + /// A validation rule indicates a function is invalid + #[error("invalid function '{function}': {reason}")] + InvalidFunction { + function: FunctionIdent, + reason: String, + }, + /// A validation rule indicates a block is invalid + #[error("invalid block '{block}': {reason}")] + InvalidBlock { block: Block, reason: String }, + /// A validation rule indicates an instruction is invalid + #[error("invalid instruction '{inst}': {reason}")] + InvalidInstruction { + span: SourceSpan, + inst: Inst, + reason: String, + }, + /// A type error was found + #[error("type error: {0}")] + TypeError(#[from] TypeError), + /// An unknown validation error occurred + #[error(transparent)] + Misc(#[from] anyhow::Error), +} +#[cfg(test)] +impl PartialEq for ValidationError { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + ( + Self::InvalidModule { + module: am, + reason: ar, + }, + Self::InvalidModule { + module: bm, + reason: br, + }, + ) => am == bm && ar == br, + ( + Self::InvalidGlobalVariable { + name: an, + reason: ar, + }, + Self::InvalidGlobalVariable { + name: bn, + reason: br, + }, + ) => an == bn && ar == br, + ( + Self::InvalidFunction { + function: af, + reason: ar, + }, + Self::InvalidFunction { + function: bf, + reason: br, + }, + ) => af == bf && ar == br, + ( + Self::InvalidBlock { + block: ab, + reason: ar, + }, + Self::InvalidBlock { + block: bb, + reason: br, + }, + ) => ab == bb && ar == br, + ( + Self::InvalidInstruction { + inst: ai, + reason: ar, + .. + }, + Self::InvalidInstruction { + inst: bi, + reason: br, + .. + }, + ) => ai == bi && ar == br, + (Self::TypeError(a), Self::TypeError(b)) => a == b, + (Self::Misc(a), Self::Misc(b)) => a.to_string() == b.to_string(), + (_, _) => false, + } + } +} + +/// A [Rule] validates some specific type of behavior on an item of type `T` +pub trait Rule { + /// Validate `item`, using `diagnostics` to emit relevant diagnostics. + fn validate( + &mut self, + item: &T, + diagnostics: &DiagnosticsHandler, + ) -> Result<(), ValidationError>; + + /// Combine two rules into one rule + fn chain(self, rule: R) -> RuleSet + where + Self: Sized, + R: Rule, + { + RuleSet::new(self, rule) + } +} +impl Rule for &mut R +where + R: Rule, +{ + fn validate( + &mut self, + item: &T, + diagnostics: &DiagnosticsHandler, + ) -> Result<(), ValidationError> { + (*self).validate(item, diagnostics) + } +} +impl Rule for Box +where + R: Rule, +{ + fn validate( + &mut self, + item: &T, + diagnostics: &DiagnosticsHandler, + ) -> Result<(), ValidationError> { + (**self).validate(item, diagnostics) + } +} +impl Rule for dyn FnMut(&T, &DiagnosticsHandler) -> Result<(), ValidationError> { + #[inline] + fn validate( + &mut self, + item: &T, + diagnostics: &DiagnosticsHandler, + ) -> Result<(), ValidationError> { + self(item, diagnostics) + } +} + +/// A [RuleSet] is a combination of multiple rules into a single [Rule] +pub struct RuleSet { + a: A, + b: B, +} +impl RuleSet { + fn new(a: A, b: B) -> Self { + Self { a, b } + } +} +impl Copy for RuleSet +where + A: Copy, + B: Copy, +{ +} +impl Clone for RuleSet +where + A: Clone, + B: Clone, +{ + #[inline] + fn clone(&self) -> Self { + Self::new(self.a.clone(), self.b.clone()) + } +} +impl Rule for RuleSet +where + A: Rule, + B: Rule, +{ + fn validate( + &mut self, + item: &T, + diagnostics: &DiagnosticsHandler, + ) -> Result<(), ValidationError> { + self.a + .validate(item, diagnostics) + .and_then(|_| self.b.validate(item, diagnostics)) + } +} + +/// The [ModuleValidator] can be used to validate and emit diagnostics for a [Module]. +/// +/// It implements [miden_hir_pass::Pass], so can be used as part of a pass pipeline. +/// +/// This validates all rules which apply to items at/within module scope. +pub struct ModuleValidator<'a> { + diagnostics: &'a DiagnosticsHandler, +} +impl<'a> ModuleValidator<'a> { + pub fn new(diagnostics: &'a DiagnosticsHandler) -> Self { + Self { diagnostics } + } + + pub fn validate(&mut self, module: &Module) -> Result<(), ValidationError> { + self.run(module) + } +} +impl<'p> Pass for ModuleValidator<'p> { + type Input<'a> = &'a Module; + type Output<'a> = (); + type Error = ValidationError; + + fn run<'a>(&mut self, input: Self::Input<'a>) -> Result, Self::Error> { + // Apply module-scoped rules + let mut rules = NamingConventions; + rules.validate(input, self.diagnostics)?; + + // Apply global-scoped rules + let mut rules = NamingConventions; + for global in input.globals() { + rules.validate(global, self.diagnostics)?; + } + + // Apply function-scoped rules + let mut rules = FunctionValidator::new(input.is_kernel()); + for function in input.functions() { + rules.validate(function, self.diagnostics)?; + } + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use miden_hir::{ + testing::{self, TestContext}, + ModuleBuilder, + }; + + use super::*; + + #[test] + fn module_validator_test() { + let context = TestContext::default(); + + // Define the 'test' module + let mut builder = ModuleBuilder::new("test"); + builder.with_span(context.current_span()); + testing::sum_matrix(&mut builder, &context); + let module = builder.build(); + + let mut validator = ModuleValidator::new(&context.diagnostics); + assert_eq!(validator.validate(&module), Ok(())); + } +} diff --git a/hir-analysis/src/validation/naming.rs b/hir-analysis/src/validation/naming.rs new file mode 100644 index 000000000..477a5520f --- /dev/null +++ b/hir-analysis/src/validation/naming.rs @@ -0,0 +1,258 @@ +use miden_diagnostics::{DiagnosticsHandler, Severity, SourceSpan, Spanned}; +use miden_hir::*; + +use super::{Rule, ValidationError}; + +/// This validation rule ensures that all identifiers adhere to the rules of their respective items. +pub struct NamingConventions; +impl Rule for NamingConventions { + fn validate( + &mut self, + module: &Module, + diagnostics: &DiagnosticsHandler, + ) -> Result<(), ValidationError> { + // Make sure all functions in this module have the same module name in their id + for function in module.functions() { + let id = function.id; + if id.module != module.name { + let expected_name = FunctionIdent { + module: module.name, + function: id.function, + }; + invalid_function!( + diagnostics, + function.id, + function.id.span(), + "the fully-qualified name of this function is '{id}'", + module.name.span(), + format!("but we expected '{expected_name}' because it belongs to this module") + ); + } + } + + // 1. Must not be empty + let name = module.name.as_str(); + if name.is_empty() { + invalid_module!(diagnostics, module.name, "module name cannot be empty"); + } + + // 2. Must begin with a lowercase ASCII alphabetic character + if !name.starts_with(is_lower_ascii_alphabetic) { + invalid_module!( + diagnostics, + module.name, + "module name must start with a lowercase, ascii-alphabetic character" + ); + } + + // 3. May otherwise consist of any number of characters of the following classes: + // * `A-Z` + // * `a-z` + // * `0-9` + // * `_-+$@` + // 4. May only contain `:` when used via the namespacing operator, e.g. `std::math` + let mut char_indices = name.char_indices().peekable(); + let mut is_namespaced = false; + while let Some((offset, c)) = char_indices.next() { + match c { + c if c.is_ascii_alphanumeric() => continue, + '_' | '-' | '+' | '$' | '@' => continue, + ':' => match char_indices.peek() { + Some((_, ':')) => { + char_indices.next(); + is_namespaced = true; + continue; + } + _ => { + let pos = module.name.span().start() + offset; + let span = SourceSpan::new(pos, pos); + invalid_module!( + diagnostics, + module.name, + span, + "module name contains invalid character ':'", + "Did you mean to use the namespacing operator '::'?" + ); + } + }, + c if c.is_whitespace() => { + invalid_module!( + diagnostics, + module.name, + "module names may not contain whitespace" + ); + } + c => { + let pos = module.name.span().start() + offset; + let span = SourceSpan::new(pos, pos); + invalid_module!( + diagnostics, + module.name, + span, + "{c} is not valid in module names" + ); + } + } + } + + // 5. The namespacing operator may only appear between two valid module identifiers + // 6. Namespaced module names must adhere to the above rules in each submodule identifier + if is_namespaced { + let mut offset = 0; + for component in name.split("::") { + let len = component.as_bytes().len(); + let start = module.name.span().start() + offset; + let span = SourceSpan::new(start, start + len); + if component.is_empty() { + invalid_module!( + diagnostics, + module.name, + span, + "submodule names cannot be empty" + ); + } + + if !name.starts_with(is_lower_ascii_alphabetic) { + invalid_module!( + diagnostics, + module.name, + span, + "submodule name must start with a lowercase, ascii-alphabetic character" + ); + } + + offset += len + 2; + } + } + + Ok(()) + } +} +impl Rule for NamingConventions { + fn validate( + &mut self, + function: &Function, + diagnostics: &DiagnosticsHandler, + ) -> Result<(), ValidationError> { + let name = function.id.function.as_str(); + let span = function.id.function.span(); + + // 1. Must not be empty + if name.is_empty() { + invalid_function!(diagnostics, function.id, "function names cannot be empty"); + } + + // 2. Must start with an ASCII-alphabetic character, underscore, `$` or `@` + fn name_starts_with(c: char) -> bool { + c.is_ascii_alphabetic() || matches!(c, '_' | '$' | '@') + } + + // 3. Otherwise, no restrictions, but may not contain whitespace + if let Err((offset, c)) = is_valid_identifier(name, name_starts_with, char::is_whitespace) { + if c.is_whitespace() { + let pos = span.start() + offset; + let span = SourceSpan::new(pos, pos); + invalid_function!( + diagnostics, + function.id, + span, + "function names may not contain whitespace" + ); + } else { + debug_assert_eq!(offset, 0); + let span = SourceSpan::new(span.start(), span.start()); + invalid_function!(diagnostics, function.id, span, "function names must start with an ascii-alphabetic character, '_', '$', or '@'"); + } + } + + Ok(()) + } +} +impl Rule for NamingConventions { + fn validate( + &mut self, + global: &GlobalVariableData, + diagnostics: &DiagnosticsHandler, + ) -> Result<(), ValidationError> { + let span = global.name.span(); + let name = global.name.as_str(); + + // 1. Must not be empty + if name.is_empty() { + invalid_global!( + diagnostics, + global.name, + "global variable names cannot be empty" + ); + } + + // 2. Must start with an ASCII-alphabetic character, underscore, `.`, `$` or `@` + fn name_starts_with(c: char) -> bool { + c.is_ascii_alphabetic() || matches!(c, '_' | '.' | '$' | '@') + } + + // 3. Otherwise, no restrictions, but may not contain whitespace + if let Err((offset, c)) = is_valid_identifier(name, name_starts_with, char::is_whitespace) { + if c.is_whitespace() { + let pos = span.start() + offset; + let span = SourceSpan::new(pos, pos); + invalid_global!( + diagnostics, + global.name, + span, + "global variable names may not contain whitespace" + ); + } else { + debug_assert_eq!(offset, 0); + let span = SourceSpan::new(span.start(), span.start()); + invalid_global!(diagnostics, global.name, span, "global variable names must start with an ascii-alphabetic character, '_', '.', '$', or '@'"); + } + } + + Ok(()) + } +} + +#[inline(always)] +fn is_lower_ascii_alphabetic(c: char) -> bool { + c.is_ascii_alphabetic() && c.is_ascii_lowercase() +} + +/// This is necessary until [std::str::Pattern] is stabilized +trait Pattern { + fn matches(&self, c: char) -> bool; +} +impl Pattern for char { + #[inline(always)] + fn matches(&self, c: char) -> bool { + *self == c + } +} +impl Pattern for F +where + F: Fn(char) -> bool, +{ + #[inline(always)] + fn matches(&self, c: char) -> bool { + self(c) + } +} + +#[inline] +fn is_valid_identifier(id: &str, start_with: P1, forbidden: P2) -> Result<(), (usize, char)> +where + P1: Pattern, + P2: Pattern, +{ + for (offset, c) in id.char_indices() { + if offset == 0 && !start_with.matches(c) { + return Err((offset, c)); + } + + if forbidden.matches(c) { + return Err((offset, c)); + } + } + + Ok(()) +} diff --git a/hir-analysis/src/validation/typecheck.rs b/hir-analysis/src/validation/typecheck.rs new file mode 100644 index 000000000..a78e94a6c --- /dev/null +++ b/hir-analysis/src/validation/typecheck.rs @@ -0,0 +1,1133 @@ +use core::fmt; + +use miden_diagnostics::{DiagnosticsHandler, Severity, SourceSpan, Spanned}; +use miden_hir::*; + +use rustc_hash::FxHashMap; + +use super::{Rule, ValidationError}; + +/// This error is produced when type checking the IR for function or module +#[derive(Debug, thiserror::Error, PartialEq, Eq)] +pub enum TypeError { + /// The number of arguments given does not match what is expected by the instruction + #[error("expected {expected} arguments, but {actual} are given")] + IncorrectArgumentCount { expected: usize, actual: usize }, + /// The number of results produced does not match what is expected from the instruction + #[error("expected {expected} results, but {actual} are produced")] + IncorrectResultCount { expected: usize, actual: usize }, + /// One of the arguments is not of the correct type + #[error("expected argument of {expected} type at index {index}, got {actual}")] + IncorrectArgumentType { + expected: TypePattern, + actual: Type, + index: usize, + }, + /// One of the results is not of the correct type + #[error("expected result of {expected} type at index {index}, got {actual}")] + InvalidResultType { + expected: TypePattern, + actual: Type, + index: usize, + }, + /// The number of arguments given to a successor block does not match what is expected by the block + #[error("{successor} expected {expected} arguments, but {actual} are given")] + IncorrectSuccessorArgumentCount { + successor: Block, + expected: usize, + actual: usize, + }, + /// One of the arguments to a successor block is not of the correct type + #[error("{successor} expected argument of {expected} type at index {index}, got {actual}")] + IncorrectSuccessorArgumentType { + successor: Block, + expected: Type, + actual: Type, + index: usize, + }, + /// An attempt was made to cast from a larger integer type to a smaller one via widening cast, e.g. `zext` + #[error("expected result to be an integral type larger than {expected}, but got {actual}")] + InvalidWideningCast { expected: Type, actual: Type }, + /// An attempt was made to cast from a smaller integer type to a larger one via narrowing cast, e.g. `trunc` + #[error("expected result to be an integral type smaller than {expected}, but got {actual}")] + InvalidNarrowingCast { expected: Type, actual: Type }, + /// The arguments of an instruction were supposed to be the same type, but at least one differs from the controlling type + #[error("expected arguments to be the same type ({expected}), but argument at index {index} is {actual}")] + MatchingArgumentTypeViolation { + expected: Type, + actual: Type, + index: usize, + }, + /// The result type of an instruction was supposed to be the same as the arguments, but it wasn't + #[error("expected result to be the same type ({expected}) as the arguments, but got {actual}")] + MatchingResultTypeViolation { expected: Type, actual: Type }, +} + +/// This validation rule type checks a block to catch any type violations by instructions in that block +pub struct TypeCheck<'a> { + signature: &'a Signature, + dfg: &'a DataFlowGraph, +} +impl<'a> TypeCheck<'a> { + pub fn new(signature: &'a Signature, dfg: &'a DataFlowGraph) -> Self { + Self { signature, dfg } + } +} +impl<'a> Rule for TypeCheck<'a> { + fn validate( + &mut self, + block_data: &BlockData, + diagnostics: &DiagnosticsHandler, + ) -> Result<(), ValidationError> { + // Traverse the block, checking each instruction in turn + for node in block_data.insts.iter() { + let span = node.span(); + let opcode = node.opcode(); + let results = self.dfg.inst_results(node.key); + let typechecker = InstTypeChecker::new(diagnostics, self.dfg, node)?; + + match node.as_ref() { + Instruction::UnaryOp(UnaryOp { arg, .. }) => match opcode { + Opcode::ImmI1 + | Opcode::ImmU8 + | Opcode::ImmI8 + | Opcode::ImmU16 + | Opcode::ImmI16 + | Opcode::ImmU32 + | Opcode::ImmI32 + | Opcode::ImmU64 + | Opcode::ImmI64 + | Opcode::ImmFelt + | Opcode::ImmF64 => invalid_instruction!( + diagnostics, + node.key, + span, + "immediate opcode '{opcode}' cannot be used with non-immediate argument" + ), + _ => { + typechecker.check(&[*arg], results)?; + } + }, + Instruction::UnaryOpImm(UnaryOpImm { imm, .. }) => match opcode { + Opcode::PtrToInt => invalid_instruction!( + diagnostics, + node.key, + span, + "'{opcode}' cannot be used with an immediate value" + ), + _ => { + typechecker.check_immediate(&[], imm, results)?; + } + }, + Instruction::Load(LoadOp { ref ty, addr, .. }) => { + if ty.size_in_felts() > 4 { + invalid_instruction!(diagnostics, node.key, span, "cannot load a value of type {ty} on the stack, as it is larger than 16 bytes"); + } + typechecker.check(&[*addr], results)?; + } + Instruction::BinaryOpImm(BinaryOpImm { imm, arg, .. }) => { + typechecker.check_immediate(&[*arg], imm, results)?; + } + Instruction::PrimOpImm(PrimOpImm { imm, args, .. }) => { + let args = args.as_slice(&self.dfg.value_lists); + typechecker.check_immediate(args, imm, results)?; + } + Instruction::GlobalValue(_) + | Instruction::BinaryOp(_) + | Instruction::PrimOp(_) + | Instruction::Test(_) + | Instruction::InlineAsm(_) + | Instruction::Call(_) => { + let args = node.arguments(&self.dfg.value_lists); + typechecker.check(args, results)?; + } + Instruction::Ret(Ret { ref args, .. }) => { + let args = args.as_slice(&self.dfg.value_lists); + if args.len() != self.signature.results.len() { + return Err(ValidationError::TypeError( + TypeError::IncorrectArgumentCount { + expected: self.signature.results.len(), + actual: args.len(), + }, + )); + } + for (index, (expected, arg)) in self + .signature + .results + .iter() + .zip(args.iter().copied()) + .enumerate() + { + let actual = self.dfg.value_type(arg); + if actual != &expected.ty { + return Err(ValidationError::TypeError( + TypeError::IncorrectArgumentType { + expected: expected.ty.clone().into(), + actual: actual.clone(), + index, + }, + )); + } + } + } + Instruction::RetImm(RetImm { ref arg, .. }) => { + if self.signature.results.len() != 1 { + return Err(ValidationError::TypeError( + TypeError::IncorrectArgumentCount { + expected: self.signature.results.len(), + actual: 1, + }, + )); + } + let expected = &self.signature.results[0].ty; + let actual = arg.ty(); + if &actual != expected { + return Err(ValidationError::TypeError( + TypeError::IncorrectArgumentType { + expected: expected.clone().into(), + actual, + index: 0, + }, + )); + } + } + Instruction::Br(Br { + ref args, + destination, + .. + }) => { + let successor = *destination; + let expected = self.dfg.block_args(successor); + let args = args.as_slice(&self.dfg.value_lists); + if args.len() != expected.len() { + return Err(ValidationError::TypeError( + TypeError::IncorrectSuccessorArgumentCount { + successor, + expected: expected.len(), + actual: args.len(), + }, + )); + } + for (index, (param, arg)) in expected + .iter() + .copied() + .zip(args.iter().copied()) + .enumerate() + { + let expected = self.dfg.value_type(param); + let actual = self.dfg.value_type(arg); + if actual != expected { + return Err(ValidationError::TypeError( + TypeError::IncorrectSuccessorArgumentType { + successor, + expected: expected.clone(), + actual: actual.clone(), + index, + }, + )); + } + } + } + Instruction::CondBr(CondBr { + cond, + then_dest: (then_dest, then_args), + else_dest: (else_dest, else_args), + .. + }) => { + typechecker.check(&[*cond], results)?; + + let then_dest = *then_dest; + let else_dest = *else_dest; + for (successor, dest_args) in + [(then_dest, then_args), (else_dest, else_args)].into_iter() + { + let expected = self.dfg.block_args(successor); + let args = dest_args.as_slice(&self.dfg.value_lists); + if args.len() != expected.len() { + return Err(ValidationError::TypeError( + TypeError::IncorrectSuccessorArgumentCount { + successor, + expected: expected.len(), + actual: args.len(), + }, + )); + } + for (index, (param, arg)) in expected + .iter() + .copied() + .zip(args.iter().copied()) + .enumerate() + { + let expected = self.dfg.value_type(param); + let actual = self.dfg.value_type(arg); + if actual != expected { + return Err(ValidationError::TypeError( + TypeError::IncorrectSuccessorArgumentType { + successor, + expected: expected.clone(), + actual: actual.clone(), + index, + }, + )); + } + } + } + } + Instruction::Switch(Switch { + arg, + arms, + default: fallback, + .. + }) => { + typechecker.check(&[*arg], results)?; + + let mut seen = FxHashMap::::default(); + for (i, (key, successor)) in arms.iter().enumerate() { + if let Some(prev) = seen.insert(*key, i) { + return Err(ValidationError::InvalidInstruction { span, inst: node.key, reason: format!("all arms of a 'switch' must have a unique discriminant, but the arm at index {i} has the same discriminant as the arm at {prev}") }); + } + + let expected = self.dfg.block_args(*successor); + if !expected.is_empty() { + return Err(ValidationError::InvalidInstruction { span, inst: node.key, reason: format!("all successors of a 'switch' must not have block parameters, but {successor}, the successor for discriminant {key}, has {} arguments", expected.len()) }); + } + } + let expected = self.dfg.block_args(*fallback); + if !expected.is_empty() { + return Err(ValidationError::InvalidInstruction { span, inst: node.key, reason: format!("all successors of a 'switch' must not have block parameters, but {fallback}, the default successor, has {} arguments", expected.len()) }); + } + } + } + } + + Ok(()) + } +} + +/// This type represents a match pattern over kinds of types. +/// +/// This is quite useful in the type checker, as otherwise we would have to handle many +/// type combinations for each instruction. +#[derive(Debug, PartialEq, Eq)] +pub enum TypePattern { + /// Matches any type + Any, + /// Matches any integer type + Int, + /// Matches any unsigned integer type + Uint, + /// Matches any signed integer type + Sint, + /// Matches any pointer type + Pointer, + /// Matches any primitive numeric or pointer type + Primitive, + /// Matches a specific type + Exact(Type), +} +impl TypePattern { + /// Returns true if this pattern matches `ty` + pub fn matches(&self, ty: &Type) -> bool { + match self { + Self::Any => true, + Self::Int => ty.is_integer(), + Self::Uint => ty.is_unsigned_integer(), + Self::Sint => ty.is_signed_integer(), + Self::Pointer => ty.is_pointer(), + Self::Primitive => ty.is_numeric() || ty.is_pointer(), + Self::Exact(expected) => expected.eq(ty), + } + } +} +impl From for TypePattern { + #[inline(always)] + fn from(ty: Type) -> Self { + Self::Exact(ty) + } +} +impl fmt::Display for TypePattern { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::Any => f.write_str("any"), + Self::Int => f.write_str("integer"), + Self::Uint => f.write_str("unsigned integer"), + Self::Sint => f.write_str("signed integer"), + Self::Pointer => f.write_str("pointer"), + Self::Primitive => f.write_str("primitive"), + Self::Exact(ty) => write!(f, "{ty}"), + } + } +} + +/// This type represents kinds of instructions in terms of their argument and result types. +/// +/// Each instruction kind represents a category of instructions with similar semantics. +pub enum InstPattern { + /// The instruction matches if it has no arguments or results + Empty, + /// The instruction matches if it has one argument and one result, both of the given type + Unary(TypePattern), + /// The instruction matches if it has one argument of the given type and no results + UnaryNoResult(TypePattern), + /// The instruction matches if it has one argument of the first type and one result of the second type + /// + /// This is used to represent things like `inttoptr` or `ptrtoint` which map one type to another + UnaryMap(TypePattern, TypePattern), + /// The instruction matches if it has one argument of integral type, and one result of a larger integral type + UnaryWideningCast(TypePattern, TypePattern), + /// The instruction matches if it has one argument of integral type, and one result of a smaller integral type + UnaryNarrowingCast(TypePattern, TypePattern), + /// The instruction matches if it has two arguments of the given type, and one result which is the same type as the first argument + Binary(TypePattern, TypePattern), + /// The instruction matches if it has two arguments and one result, all of the same type + BinaryMatching(TypePattern), + /// The instruction matches if it has two arguments of the same type, and no results + BinaryMatchingNoResult(TypePattern), + /// The instruction matches if it has two arguments of the same type, and returns a boolean + BinaryPredicate(TypePattern), + /// The instruction matches if its first argument matches the first type, with two more arguments and one result matching the second type + /// + /// This is used to model instructions like `select` + TernaryMatching(TypePattern, TypePattern), + /// The instruciton matches if it has the exact number of arguments and results given, each corresponding to the given type + Exact(Vec, Vec), + /// The instruction matches any number of arguments and results, of any type + Any, +} +impl InstPattern { + /// Evaluate this pattern against the given arguments and results + pub fn into_match( + self, + dfg: &DataFlowGraph, + args: &[Value], + results: &[Value], + ) -> Result<(), TypeError> { + match self { + Self::Empty => { + if !args.is_empty() { + return Err(TypeError::IncorrectArgumentCount { + expected: 0, + actual: args.len(), + }); + } + if !results.is_empty() { + return Err(TypeError::IncorrectResultCount { + expected: 0, + actual: args.len(), + }); + } + Ok(()) + } + Self::Unary(_) + | Self::UnaryMap(_, _) + | Self::UnaryWideningCast(_, _) + | Self::UnaryNarrowingCast(_, _) => { + if args.len() != 1 { + return Err(TypeError::IncorrectArgumentCount { + expected: 1, + actual: args.len(), + }); + } + if results.len() != 1 { + return Err(TypeError::IncorrectResultCount { + expected: 1, + actual: results.len(), + }); + } + let actual_in = dfg.value_type(args[0]); + let actual_out = dfg.value_type(results[0]); + self.into_unary_match(actual_in, Some(actual_out)) + } + Self::UnaryNoResult(_) => { + if args.len() != 1 { + return Err(TypeError::IncorrectArgumentCount { + expected: 1, + actual: args.len(), + }); + } + if !results.is_empty() { + return Err(TypeError::IncorrectResultCount { + expected: 0, + actual: results.len(), + }); + } + let actual = dfg.value_type(args[0]); + self.into_unary_match(actual, None) + } + Self::Binary(_, _) | Self::BinaryMatching(_) | Self::BinaryPredicate(_) => { + if args.len() != 2 { + return Err(TypeError::IncorrectArgumentCount { + expected: 2, + actual: args.len(), + }); + } + if results.len() != 1 { + return Err(TypeError::IncorrectResultCount { + expected: 1, + actual: results.len(), + }); + } + let lhs = dfg.value_type(args[0]); + let rhs = dfg.value_type(args[1]); + let result = dfg.value_type(results[0]); + self.into_binary_match(lhs, rhs, Some(result)) + } + Self::BinaryMatchingNoResult(_) => { + if args.len() != 2 { + return Err(TypeError::IncorrectArgumentCount { + expected: 2, + actual: args.len(), + }); + } + if !results.is_empty() { + return Err(TypeError::IncorrectResultCount { + expected: 0, + actual: results.len(), + }); + } + let lhs = dfg.value_type(args[0]); + let rhs = dfg.value_type(args[1]); + self.into_binary_match(lhs, rhs, None) + } + Self::TernaryMatching(_, _) => { + if args.len() != 3 { + return Err(TypeError::IncorrectArgumentCount { + expected: 3, + actual: args.len(), + }); + } + if results.len() != 1 { + return Err(TypeError::IncorrectResultCount { + expected: 1, + actual: results.len(), + }); + } + let cond = dfg.value_type(args[0]); + let lhs = dfg.value_type(args[1]); + let rhs = dfg.value_type(args[2]); + let result = dfg.value_type(results[0]); + self.into_ternary_match(cond, lhs, rhs, result) + } + Self::Exact(expected_args, expected_results) => { + if args.len() != expected_args.len() { + return Err(TypeError::IncorrectArgumentCount { + expected: expected_args.len(), + actual: args.len(), + }); + } + if results.len() != expected_results.len() { + return Err(TypeError::IncorrectResultCount { + expected: expected_results.len(), + actual: results.len(), + }); + } + for (index, (expected, arg)) in expected_args + .into_iter() + .zip(args.iter().copied()) + .enumerate() + { + let actual = dfg.value_type(arg); + if !expected.matches(actual) { + return Err(TypeError::IncorrectArgumentType { + expected, + actual: actual.clone(), + index, + }); + } + } + for (index, (expected, result)) in expected_results + .into_iter() + .zip(results.iter().copied()) + .enumerate() + { + let actual = dfg.value_type(result); + if !expected.matches(actual) { + return Err(TypeError::InvalidResultType { + expected, + actual: actual.clone(), + index, + }); + } + } + + Ok(()) + } + Self::Any => Ok(()), + } + } + + /// Evaluate this pattern against the given arguments (including an immediate argument) and results + pub fn into_match_with_immediate( + self, + dfg: &DataFlowGraph, + args: &[Value], + imm: &Immediate, + results: &[Value], + ) -> Result<(), TypeError> { + match self { + Self::Empty => panic!("invalid empty pattern for instruction with immediate argument"), + Self::Unary(_) + | Self::UnaryMap(_, _) + | Self::UnaryWideningCast(_, _) + | Self::UnaryNarrowingCast(_, _) => { + if !args.is_empty() { + return Err(TypeError::IncorrectArgumentCount { + expected: 1, + actual: args.len() + 1, + }); + } + if results.len() != 1 { + return Err(TypeError::IncorrectResultCount { + expected: 1, + actual: results.len(), + }); + } + let actual_in = imm.ty(); + let actual_out = dfg.value_type(results[0]); + self.into_unary_match(&actual_in, Some(&actual_out)) + } + Self::UnaryNoResult(_) => { + if !args.is_empty() { + return Err(TypeError::IncorrectArgumentCount { + expected: 1, + actual: args.len() + 1, + }); + } + if !results.is_empty() { + return Err(TypeError::IncorrectResultCount { + expected: 0, + actual: results.len(), + }); + } + let actual = imm.ty(); + self.into_unary_match(&actual, None) + } + Self::Binary(_, _) | Self::BinaryMatching(_) | Self::BinaryPredicate(_) => { + if args.len() != 1 { + return Err(TypeError::IncorrectArgumentCount { + expected: 2, + actual: args.len() + 1, + }); + } + if results.len() != 1 { + return Err(TypeError::IncorrectResultCount { + expected: 1, + actual: results.len(), + }); + } + let lhs = dfg.value_type(args[0]); + let rhs = imm.ty(); + let result = dfg.value_type(results[0]); + self.into_binary_match(lhs, &rhs, Some(result)) + } + Self::BinaryMatchingNoResult(_) => { + if args.len() != 1 { + return Err(TypeError::IncorrectArgumentCount { + expected: 2, + actual: args.len() + 1, + }); + } + if !results.is_empty() { + return Err(TypeError::IncorrectResultCount { + expected: 0, + actual: results.len(), + }); + } + let lhs = dfg.value_type(args[0]); + let rhs = imm.ty(); + self.into_binary_match(lhs, &rhs, None) + } + Self::TernaryMatching(_, _) => { + if args.len() != 2 { + return Err(TypeError::IncorrectArgumentCount { + expected: 3, + actual: args.len() + 1, + }); + } + if results.len() != 1 { + return Err(TypeError::IncorrectResultCount { + expected: 1, + actual: results.len(), + }); + } + let cond = dfg.value_type(args[0]); + let lhs = dfg.value_type(args[1]); + let rhs = imm.ty(); + let result = dfg.value_type(results[0]); + self.into_ternary_match(cond, lhs, &rhs, result) + } + Self::Exact(expected_args, expected_results) => { + if args.len() != expected_args.len() { + return Err(TypeError::IncorrectArgumentCount { + expected: expected_args.len(), + actual: args.len(), + }); + } + if results.len() != expected_results.len() { + return Err(TypeError::IncorrectResultCount { + expected: expected_results.len(), + actual: results.len(), + }); + } + for (index, (expected, arg)) in expected_args + .into_iter() + .zip(args.iter().copied()) + .enumerate() + { + let actual = dfg.value_type(arg); + if !expected.matches(actual) { + return Err(TypeError::IncorrectArgumentType { + expected, + actual: actual.clone(), + index, + }); + } + } + for (index, (expected, result)) in expected_results + .into_iter() + .zip(results.iter().copied()) + .enumerate() + { + let actual = dfg.value_type(result); + if !expected.matches(actual) { + return Err(TypeError::InvalidResultType { + expected, + actual: actual.clone(), + index, + }); + } + } + + Ok(()) + } + Self::Any => Ok(()), + } + } + + fn into_unary_match( + self, + actual_in: &Type, + actual_out: Option<&Type>, + ) -> Result<(), TypeError> { + match self { + Self::Unary(expected) | Self::UnaryNoResult(expected) => { + if !expected.matches(actual_in) { + return Err(TypeError::IncorrectArgumentType { + expected, + actual: actual_in.clone(), + index: 0, + }); + } + if let Some(actual_out) = actual_out { + if actual_in != actual_out { + return Err(TypeError::MatchingResultTypeViolation { + expected: actual_in.clone(), + actual: actual_out.clone(), + }); + } + } + } + Self::UnaryMap(expected_in, expected_out) => { + if !expected_in.matches(actual_in) { + return Err(TypeError::IncorrectArgumentType { + expected: expected_in, + actual: actual_in.clone(), + index: 0, + }); + } + let actual_out = actual_out.expect("expected result type"); + if !expected_out.matches(actual_out) { + return Err(TypeError::InvalidResultType { + expected: expected_out, + actual: actual_out.clone(), + index: 0, + }); + } + } + Self::UnaryWideningCast(expected_in, expected_out) => { + if !expected_in.matches(actual_in) { + return Err(TypeError::IncorrectArgumentType { + expected: expected_in, + actual: actual_in.clone(), + index: 0, + }); + } + let actual_out = actual_out.expect("expected result type"); + if !expected_out.matches(actual_out) { + return Err(TypeError::InvalidResultType { + expected: expected_out, + actual: actual_out.clone(), + index: 0, + }); + } + if actual_in.size_in_bits() > actual_out.size_in_bits() { + return Err(TypeError::InvalidWideningCast { + expected: actual_in.clone(), + actual: actual_out.clone(), + }); + } + } + Self::UnaryNarrowingCast(expected_in, expected_out) => { + if !expected_in.matches(actual_in) { + return Err(TypeError::IncorrectArgumentType { + expected: expected_in, + actual: actual_in.clone(), + index: 0, + }); + } + let actual_out = actual_out.expect("expected result type"); + if !expected_out.matches(actual_out) { + return Err(TypeError::InvalidResultType { + expected: expected_out, + actual: actual_out.clone(), + index: 0, + }); + } + if actual_in.size_in_bits() < actual_out.size_in_bits() { + return Err(TypeError::InvalidNarrowingCast { + expected: actual_in.clone(), + actual: actual_out.clone(), + }); + } + } + Self::Empty + | Self::Binary(_, _) + | Self::BinaryMatching(_) + | Self::BinaryMatchingNoResult(_) + | Self::BinaryPredicate(_) + | Self::TernaryMatching(_, _) + | Self::Exact(_, _) + | Self::Any => unreachable!(), + } + + Ok(()) + } + + fn into_binary_match( + self, + lhs: &Type, + rhs: &Type, + result: Option<&Type>, + ) -> Result<(), TypeError> { + match self { + Self::Binary(expected_lhs, expected_rhs) => { + if !expected_lhs.matches(lhs) { + return Err(TypeError::IncorrectArgumentType { + expected: expected_lhs, + actual: lhs.clone(), + index: 0, + }); + } + if !expected_rhs.matches(rhs) { + return Err(TypeError::IncorrectArgumentType { + expected: expected_rhs, + actual: rhs.clone(), + index: 1, + }); + } + let result = result.expect("expected result type"); + if lhs != result { + return Err(TypeError::MatchingResultTypeViolation { + expected: lhs.clone(), + actual: result.clone(), + }); + } + } + Self::BinaryMatching(expected) | Self::BinaryMatchingNoResult(expected) => { + if !expected.matches(lhs) { + return Err(TypeError::IncorrectArgumentType { + expected, + actual: lhs.clone(), + index: 0, + }); + } + if lhs != rhs { + return Err(TypeError::MatchingArgumentTypeViolation { + expected: lhs.clone(), + actual: rhs.clone(), + index: 1, + }); + } + if let Some(result) = result { + if lhs != result { + return Err(TypeError::MatchingResultTypeViolation { + expected: lhs.clone(), + actual: result.clone(), + }); + } + } + } + Self::BinaryPredicate(expected) => { + if !expected.matches(lhs) { + return Err(TypeError::IncorrectArgumentType { + expected, + actual: lhs.clone(), + index: 0, + }); + } + if lhs != rhs { + return Err(TypeError::MatchingArgumentTypeViolation { + expected: lhs.clone(), + actual: rhs.clone(), + index: 1, + }); + } + let result = result.expect("expected result type"); + let expected = Type::I1; + if result != &expected { + return Err(TypeError::MatchingResultTypeViolation { + expected, + actual: result.clone(), + }); + } + } + Self::Empty + | Self::Unary(_) + | Self::UnaryNoResult(_) + | Self::UnaryMap(_, _) + | Self::UnaryWideningCast(_, _) + | Self::UnaryNarrowingCast(_, _) + | Self::TernaryMatching(_, _) + | Self::Exact(_, _) + | Self::Any => unreachable!(), + } + + Ok(()) + } + + fn into_ternary_match( + self, + cond: &Type, + lhs: &Type, + rhs: &Type, + result: &Type, + ) -> Result<(), TypeError> { + match self { + Self::TernaryMatching(expected_cond, expected_inout) => { + if !expected_cond.matches(cond) { + return Err(TypeError::IncorrectArgumentType { + expected: expected_cond, + actual: cond.clone(), + index: 0, + }); + } + if !expected_inout.matches(lhs) { + return Err(TypeError::IncorrectArgumentType { + expected: expected_inout, + actual: lhs.clone(), + index: 1, + }); + } + if lhs != rhs { + return Err(TypeError::IncorrectArgumentType { + expected: lhs.clone().into(), + actual: rhs.clone(), + index: 2, + }); + } + if lhs != result { + return Err(TypeError::MatchingResultTypeViolation { + expected: lhs.clone(), + actual: result.clone(), + }); + } + } + Self::Empty + | Self::Unary(_) + | Self::UnaryNoResult(_) + | Self::UnaryMap(_, _) + | Self::UnaryWideningCast(_, _) + | Self::UnaryNarrowingCast(_, _) + | Self::Binary(_, _) + | Self::BinaryMatching(_) + | Self::BinaryMatchingNoResult(_) + | Self::BinaryPredicate(_) + | Self::Exact(_, _) + | Self::Any => unreachable!(), + } + + Ok(()) + } +} + +/// This type plays the role of type checking instructions. +/// +/// It is separate from the [TypeCheck] rule itself to factor out +/// all the instruction-related boilerplate. +struct InstTypeChecker<'a> { + diagnostics: &'a DiagnosticsHandler, + dfg: &'a DataFlowGraph, + span: SourceSpan, + opcode: Opcode, + pattern: InstPattern, +} +impl<'a> InstTypeChecker<'a> { + /// Create a new instance of the type checker for the instruction represented by `node`. + pub fn new( + diagnostics: &'a DiagnosticsHandler, + dfg: &'a DataFlowGraph, + node: &InstNode, + ) -> Result { + let span = node.span(); + let opcode = node.opcode(); + let pattern = match opcode { + Opcode::Assert | Opcode::Assertz => InstPattern::UnaryNoResult(Type::I1.into()), + Opcode::AssertEq => InstPattern::BinaryMatchingNoResult(Type::I1.into()), + Opcode::ImmI1 => InstPattern::Unary(Type::I1.into()), + Opcode::ImmU8 => InstPattern::Unary(Type::U8.into()), + Opcode::ImmI8 => InstPattern::Unary(Type::I8.into()), + Opcode::ImmU16 => InstPattern::Unary(Type::U16.into()), + Opcode::ImmI16 => InstPattern::Unary(Type::I16.into()), + Opcode::ImmU32 => InstPattern::Unary(Type::U32.into()), + Opcode::ImmI32 => InstPattern::Unary(Type::I32.into()), + Opcode::ImmU64 => InstPattern::Unary(Type::U64.into()), + Opcode::ImmI64 => InstPattern::Unary(Type::I64.into()), + Opcode::ImmFelt => InstPattern::Unary(Type::Felt.into()), + Opcode::ImmF64 => InstPattern::Unary(Type::F64.into()), + Opcode::Alloca => InstPattern::Exact(vec![], vec![TypePattern::Pointer]), + Opcode::MemGrow => InstPattern::Unary(Type::U32.into()), + opcode @ Opcode::GlobalValue => match node.as_ref() { + Instruction::GlobalValue(GlobalValueOp { global, .. }) => { + match dfg.global_value(*global) { + GlobalValueData::Symbol { .. } | GlobalValueData::IAddImm { .. } => { + InstPattern::Exact(vec![], vec![TypePattern::Pointer]) + } + GlobalValueData::Load { ref ty, .. } => { + InstPattern::Exact(vec![], vec![ty.clone().into()]) + } + } + } + inst => panic!("invalid opcode '{opcode}' for {inst:#?}"), + }, + Opcode::Load => InstPattern::UnaryMap(TypePattern::Pointer, TypePattern::Any), + Opcode::Store => { + InstPattern::Exact(vec![TypePattern::Pointer, TypePattern::Any], vec![]) + } + Opcode::MemCpy => InstPattern::Exact( + vec![TypePattern::Pointer, TypePattern::Pointer, Type::U32.into()], + vec![], + ), + Opcode::PtrToInt => InstPattern::UnaryMap(TypePattern::Pointer, TypePattern::Int), + Opcode::IntToPtr => InstPattern::UnaryMap(TypePattern::Uint, TypePattern::Pointer), + Opcode::Cast => InstPattern::UnaryMap(TypePattern::Int, TypePattern::Int), + Opcode::Trunc => InstPattern::UnaryNarrowingCast(TypePattern::Int, TypePattern::Int), + Opcode::Zext => InstPattern::UnaryWideningCast(TypePattern::Int, TypePattern::Uint), + Opcode::Sext => InstPattern::UnaryWideningCast(TypePattern::Int, TypePattern::Int), + Opcode::Test => InstPattern::UnaryMap(TypePattern::Int, Type::I1.into()), + Opcode::Select => InstPattern::TernaryMatching(Type::I1.into(), TypePattern::Primitive), + Opcode::Add + | Opcode::Sub + | Opcode::Mul + | Opcode::Div + | Opcode::Mod + | Opcode::DivMod + | Opcode::Band + | Opcode::Bor + | Opcode::Bxor => InstPattern::BinaryMatching(TypePattern::Int), + Opcode::Exp | Opcode::Shl | Opcode::Shr | Opcode::Rotl | Opcode::Rotr => { + InstPattern::Binary(TypePattern::Int, TypePattern::Uint) + } + Opcode::Neg + | Opcode::Inv + | Opcode::Incr + | Opcode::Pow2 + | Opcode::Bnot + | Opcode::Popcnt => InstPattern::Unary(TypePattern::Int), + Opcode::Not => InstPattern::Unary(Type::I1.into()), + Opcode::And | Opcode::Or | Opcode::Xor => InstPattern::BinaryMatching(Type::I1.into()), + Opcode::Eq | Opcode::Neq => InstPattern::BinaryPredicate(TypePattern::Primitive), + Opcode::Gt | Opcode::Gte | Opcode::Lt | Opcode::Lte => { + InstPattern::BinaryPredicate(TypePattern::Int) + } + Opcode::IsOdd => InstPattern::Exact(vec![TypePattern::Int], vec![Type::I1.into()]), + Opcode::Min | Opcode::Max => InstPattern::BinaryMatching(TypePattern::Int), + Opcode::Call | Opcode::Syscall => match node.as_ref() { + Instruction::Call(Call { ref callee, .. }) => { + if let Some(import) = dfg.get_import(callee) { + let args = import + .signature + .params + .iter() + .map(|p| TypePattern::Exact(p.ty.clone())) + .collect(); + let results = import + .signature + .results + .iter() + .map(|p| TypePattern::Exact(p.ty.clone())) + .collect(); + InstPattern::Exact(args, results) + } else { + invalid_instruction!( + diagnostics, + node.key, + span, + "no signature is available for {callee}", + "Make sure you import functions before building calls to them." + ); + } + } + inst => panic!("invalid opcode '{opcode}' for {inst:#?}"), + }, + Opcode::Br => InstPattern::Any, + Opcode::CondBr => InstPattern::Exact(vec![Type::I1.into()], vec![]), + Opcode::Switch => InstPattern::Exact(vec![Type::U32.into()], vec![]), + Opcode::Ret => InstPattern::Any, + Opcode::Unreachable => InstPattern::Empty, + Opcode::InlineAsm => InstPattern::Any, + }; + Ok(Self { + diagnostics, + dfg, + span: node.span(), + opcode, + pattern, + }) + } + + /// Checks that the given `operands` and `results` match the types represented by this [InstTypeChecker] + pub fn check(self, operands: &[Value], results: &[Value]) -> Result<(), ValidationError> { + let diagnostics = self.diagnostics; + let dfg = self.dfg; + match self.pattern.into_match(dfg, operands, results) { + Ok(_) => Ok(()), + Err(err) => { + let opcode = self.opcode; + let message = format!("validation failed for {opcode} instruction"); + diagnostics + .diagnostic(Severity::Error) + .with_message(message.as_str()) + .with_primary_label(self.span, format!("{err}")) + .emit(); + Err(ValidationError::TypeError(err)) + } + } + } + + /// Checks that the given `operands` (with immediate) and `results` match the types represented by this [InstTypeChecker] + pub fn check_immediate( + self, + operands: &[Value], + imm: &Immediate, + results: &[Value], + ) -> Result<(), ValidationError> { + let diagnostics = self.diagnostics; + let dfg = self.dfg; + match self + .pattern + .into_match_with_immediate(dfg, operands, imm, results) + { + Ok(_) => Ok(()), + Err(err) => { + let opcode = self.opcode; + let message = format!("validation failed for {opcode} instruction"); + diagnostics + .diagnostic(Severity::Error) + .with_message(message.as_str()) + .with_primary_label(self.span, format!("{err}")) + .emit(); + Err(ValidationError::TypeError(err)) + } + } + } +}