diff --git a/Cargo.lock b/Cargo.lock index 03eb66c40..f9c86bcbb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -905,6 +905,7 @@ dependencies = [ "miden-parsing", "paste", "petgraph", + "pretty_assertions", "rustc-hash", "smallvec", "thiserror", diff --git a/hir/Cargo.toml b/hir/Cargo.toml index 3e11ff6c5..2ecaefa88 100644 --- a/hir/Cargo.toml +++ b/hir/Cargo.toml @@ -25,6 +25,7 @@ miden-hir-type.workspace = true miden-parsing = "0.1" petgraph = "0.6" paste.workspace = true +pretty_assertions = "1.0" rustc-hash.workspace = true smallvec.workspace = true thiserror.workspace = true diff --git a/hir/src/parser/ast/block.rs b/hir/src/parser/ast/block.rs index 9e33b7c68..e15ec1373 100644 --- a/hir/src/parser/ast/block.rs +++ b/hir/src/parser/ast/block.rs @@ -6,6 +6,7 @@ const INDENT: &str = " "; /// Represents the label at the start of a basic block. /// /// Labels must be unique within each function. +#[derive(PartialEq, Debug)] pub struct Label { pub name: Ident, } @@ -21,6 +22,7 @@ impl fmt::Display for Label { } /// Represents an argument for a basic block +#[derive(PartialEq, Debug)] pub struct BlockArgument { pub value: Value, pub ty: Type, @@ -37,6 +39,7 @@ impl fmt::Display for BlockArgument { } /// Represents the label and the arguments of a basic block +#[derive(PartialEq, Debug)] pub struct BlockHeader { pub label: Label, pub args: Vec, @@ -66,7 +69,7 @@ impl fmt::Display for BlockHeader { } /// Represents a basic block of instructions -#[derive(Spanned)] +#[derive(Spanned, Debug)] pub struct Block { #[span] pub span: SourceSpan, @@ -82,6 +85,12 @@ impl Block { } } } +impl PartialEq for Block { + fn eq(&self, other: &Self) -> bool { + self.header == other.header + && self.instructions == other.instructions + } +} impl fmt::Display for Block { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { writeln!(f, "{}", self.header)?; diff --git a/hir/src/parser/ast/functions.rs b/hir/src/parser/ast/functions.rs index 332c83df6..83a42101c 100644 --- a/hir/src/parser/ast/functions.rs +++ b/hir/src/parser/ast/functions.rs @@ -2,6 +2,7 @@ use super::*; use crate::{ArgumentExtension, ArgumentPurpose, CallConv, FunctionIdent, Type}; /// The possible visibilities of a function +#[derive(PartialEq, Debug)] pub enum Visibility { /// (Module) private visibility Private, @@ -11,6 +12,7 @@ pub enum Visibility { /// A single parameter to a function. /// Parameter names are defined in the entry block for the function. +#[derive(PartialEq, Debug)] pub struct FunctionParameter { /// The purpose of the parameter (default or struct return) pub purpose: ArgumentPurpose, @@ -44,6 +46,7 @@ impl fmt::Display for FunctionParameter { } /// A single return value from a function. +#[derive(PartialEq, Debug)] pub struct FunctionReturn { /// The bit extension for the parameter pub extension: ArgumentExtension, @@ -67,7 +70,7 @@ impl fmt::Display for FunctionReturn { } /// Represents the type signature of a function -#[derive(Spanned)] +#[derive(Spanned, Debug)] pub struct FunctionSignature { #[span] pub span: SourceSpan, @@ -96,6 +99,15 @@ impl FunctionSignature { } } } +impl PartialEq for FunctionSignature { + fn eq(&self, other: &Self) -> bool { + self.visibility == other.visibility + && self.call_convention == other.call_convention + && self.name == other.name + && self.params == other.params + && self.returns == other.returns + } +} impl fmt::Display for FunctionSignature { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self.visibility { @@ -128,7 +140,7 @@ impl fmt::Display for FunctionSignature { } /// Represents the declaration of a function -#[derive(Spanned)] +#[derive(Spanned, Debug)] pub struct FunctionDeclaration { #[span] pub span: SourceSpan, @@ -144,6 +156,12 @@ impl FunctionDeclaration { } } } +impl PartialEq for FunctionDeclaration { + fn eq(&self, other: &Self) -> bool { + self.signature == other.signature + && self.blocks == other.blocks + } +} impl fmt::Display for FunctionDeclaration { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{}", self.signature)?; diff --git a/hir/src/parser/ast/globals.rs b/hir/src/parser/ast/globals.rs index 42ec9f846..d048b30cf 100644 --- a/hir/src/parser/ast/globals.rs +++ b/hir/src/parser/ast/globals.rs @@ -34,7 +34,7 @@ impl fmt::Display for GlobalVarInitializer { } /// This represents the declaration of a Miden IR global variable -#[derive(Spanned)] +#[derive(Spanned, Debug)] pub struct GlobalVarDeclaration { #[span] pub span: SourceSpan, @@ -60,6 +60,14 @@ impl GlobalVarDeclaration { self.init = Some(init) } } +impl PartialEq for GlobalVarDeclaration { + fn eq(&self, other: &Self) -> bool { + self.name == other.name + && self.ty == other.ty + && self.linkage == other.linkage + && self.init == other.init + } +} impl fmt::Display for GlobalVarDeclaration { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{} {} {}", self.name, self.ty, self.linkage)?; diff --git a/hir/src/parser/ast/instruction.rs b/hir/src/parser/ast/instruction.rs index 034d2c8f1..b23ee8af4 100644 --- a/hir/src/parser/ast/instruction.rs +++ b/hir/src/parser/ast/instruction.rs @@ -5,6 +5,7 @@ use crate::{FunctionIdent, Ident, Overflow, Type}; /// /// All intermediate values are named, and have an associated [Value]. /// Value identifiers must be globally unique. +#[derive(PartialEq, Debug)] pub struct Value { pub name: Ident, } @@ -20,6 +21,7 @@ impl fmt::Display for Value { } /// Immediates are converted at a later stage +#[derive(PartialEq, Debug)] pub enum Immediate { Pos(u128), Neg(u128), @@ -39,7 +41,7 @@ impl fmt::Display for Immediate { /// An instruction consists of a single operation, and a number of values that /// represent the results of the operation. Additionally, the instruction contains /// the types of the produced results -#[derive(Spanned)] +#[derive(Spanned, Debug)] pub struct Instruction { #[span] pub span: SourceSpan, @@ -57,6 +59,13 @@ impl Instruction { } } } +impl PartialEq for Instruction { + fn eq(&self, other: &Self) -> bool { + self.values == other.values + && self.op == other.op + && self.types == other.types + } +} impl fmt::Display for Instruction { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { if self.values.is_empty() { @@ -83,6 +92,7 @@ impl fmt::Display for Instruction { } /// Represents a operation and its arguments +#[derive(PartialEq, Debug)] pub enum Operation { BinaryOp(BinaryOpCode, Value, Value), BinaryImmOp(BinaryImmOpCode, Value, Immediate), @@ -179,6 +189,7 @@ impl fmt::Display for Operation { } /// Used to distinguish between user calls and kernel calls +#[derive(PartialEq, Debug)] pub enum CallOp { Call, SysCall, @@ -193,6 +204,7 @@ impl fmt::Display for CallOp { } /// Used to distinguish between binary operations +#[derive(PartialEq, Debug)] pub enum BinaryOpCode { Add(Overflow), Sub(Overflow), @@ -287,6 +299,7 @@ impl fmt::Display for BinaryOpCode { } /// Used to distinguish between immediate binary operations +#[derive(PartialEq, Debug)] pub enum BinaryImmOpCode { AddImm(Overflow), SubImm(Overflow), @@ -367,6 +380,7 @@ impl fmt::Display for BinaryImmOpCode { } /// Used to distinguish between unary operations +#[derive(PartialEq, Debug)] pub enum UnaryOpCode { Inv, Incr, @@ -405,6 +419,7 @@ impl fmt::Display for UnaryOpCode { } /// Used to distinguish between immediate unary operations +#[derive(PartialEq, Debug)] pub enum UnaryImmOpCode { I1, I8, @@ -430,6 +445,7 @@ impl fmt::Display for UnaryImmOpCode { } /// Used to distinguish between primary operations +#[derive(PartialEq, Debug)] pub enum PrimOpCode { Select, Assert, @@ -453,6 +469,7 @@ impl fmt::Display for PrimOpCode { /// Memory offset for global variable reads. /// Conversion to i32 happens during transformation to hir. +#[derive(PartialEq, Debug)] pub enum Offset { Pos(u128), Neg(u128), @@ -483,6 +500,7 @@ impl fmt::Display for Offset { } /// Used to distinguish between nested global value operations +#[derive(PartialEq, Debug)] pub enum GlobalValueOpNested { Symbol(Ident, Offset), Load(Box, Offset), @@ -511,6 +529,7 @@ impl fmt::Display for GlobalValueOpNested { } /// Used to distinguish between top-level global value operations +#[derive(PartialEq, Debug)] pub enum GlobalValueOp { Symbol(Ident, Offset), Load(GlobalValueOpNested, Offset), @@ -544,6 +563,7 @@ impl fmt::Display for GlobalValueOp { } /// The destination of a branch/jump +#[derive(PartialEq, Debug)] pub struct Destination { pub label: Label, pub args: Vec, @@ -572,6 +592,7 @@ impl fmt::Display for Destination { } /// A branch of a switch operation +#[derive(PartialEq, Debug)] pub enum SwitchBranch { Test(u128, Label), Default(Label), diff --git a/hir/src/parser/ast/mod.rs b/hir/src/parser/ast/mod.rs index cdf6a0bc0..6aa1a6c51 100644 --- a/hir/src/parser/ast/mod.rs +++ b/hir/src/parser/ast/mod.rs @@ -17,7 +17,7 @@ use crate::Ident; /// This is a type alias used to clarify that an identifier refers to a module pub type ModuleId = Ident; -#[derive(Copy, Clone, PartialEq, Eq)] +#[derive(Copy, Clone, Debug, PartialEq, Eq)] pub enum ModuleType { /// Kernel context module Kernel, @@ -35,7 +35,7 @@ impl fmt::Display for ModuleType { /// This represents the parsed contents of a single Miden IR module /// -#[derive(Spanned)] +#[derive(Spanned, Debug)] pub struct Module { #[span] pub span: SourceSpan, @@ -66,6 +66,15 @@ impl Module { } } } +impl PartialEq for Module { + fn eq(&self, other: &Self) -> bool { + self.name == other.name + && self.ty == other.ty + && self.global_vars == other.global_vars + && self.functions == other.functions + && self.externals == other.externals + } +} impl fmt::Display for Module { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { writeln!(f, "{} {}", self.ty, self.name)?; diff --git a/hir/src/parser/parser/mod.rs b/hir/src/parser/parser/mod.rs index c645f33a3..24b5514f5 100644 --- a/hir/src/parser/parser/mod.rs +++ b/hir/src/parser/parser/mod.rs @@ -212,3 +212,6 @@ impl miden_parsing::Parse for ast::Module { } } } + +#[cfg(test)] +mod tests; diff --git a/hir/src/parser/parser/tests/input/system.mir b/hir/src/parser/parser/tests/input/system.mir new file mode 100644 index 000000000..797669a80 --- /dev/null +++ b/hir/src/parser/parser/tests/input/system.mir @@ -0,0 +1,12 @@ +module miden_ir_test + +global_1 u32 internal = 0xCAFEBABE + +pub cc(fast) fn miden_ir_test::test_func (zext u32, sext u32) -> u32 { + blk(v1 : u32, v2 : u32) : { + v3 = add.unchecked v1 v2 : u32 + ret (v1, v3) + } +} + +cc(kernel) fn exported::f1 (sret { u32, u32 }) -> [i8 ; 42]; diff --git a/hir/src/parser/parser/tests/utils.rs b/hir/src/parser/parser/tests/utils.rs new file mode 100644 index 000000000..e406fb73d --- /dev/null +++ b/hir/src/parser/parser/tests/utils.rs @@ -0,0 +1,156 @@ +use std::sync::Arc; + +use miden_diagnostics::{CodeMap, DiagnosticsConfig, DiagnosticsHandler, Emitter, Verbosity}; +use pretty_assertions::assert_eq; + +use crate::{ + parser::ast::Module, + parser::{ParseError, Parser}, +}; + +struct SplitEmitter { + capture: miden_diagnostics::CaptureEmitter, + default: miden_diagnostics::DefaultEmitter, +} +impl SplitEmitter { + #[inline] + pub fn new() -> Self { + use miden_diagnostics::term::termcolor::ColorChoice; + + Self { + capture: Default::default(), + default: miden_diagnostics::DefaultEmitter::new(ColorChoice::Auto), + } + } + + #[allow(unused)] + pub fn captured(&self) -> String { + self.capture.captured() + } +} +impl Emitter for SplitEmitter { + #[inline] + fn buffer(&self) -> miden_diagnostics::term::termcolor::Buffer { + self.capture.buffer() + } + + #[inline] + fn print(&self, buffer: miden_diagnostics::term::termcolor::Buffer) -> std::io::Result<()> { + use std::io::Write; + + let mut copy = self.capture.buffer(); + copy.write_all(buffer.as_slice())?; + self.capture.print(buffer)?; + self.default.print(copy) + } +} + +// TEST HANDLER +// ================================================================================================ + +/// [ParseTest] is a container for the data required to run parser tests. Used to build an AST from +/// the given source string and asserts that executing the test will result in the expected AST. +/// +/// # Errors: +/// - ScanError test: check that the source provided contains valid characters and keywords. +/// - ParseError test: check that the parsed values are valid. +/// * InvalidInt: This error is returned if the parsed number is not a valid u64. +pub struct ParseTest { + pub diagnostics: Arc, + #[allow(unused)] + emitter: Arc, + parser: Parser, +} + +impl ParseTest { + // CONSTRUCTOR + // -------------------------------------------------------------------------------------------- + + /// Creates a new test, from the source string. + pub fn new() -> Self { + let codemap = Arc::new(CodeMap::new()); + let emitter = Arc::new(SplitEmitter::new()); + let config = DiagnosticsConfig { + verbosity: Verbosity::Warning, + warnings_as_errors: true, + no_warn: false, + display: Default::default(), + }; + let diagnostics = Arc::new(DiagnosticsHandler::new( + config, + codemap.clone(), + emitter.clone(), + )); + let parser = Parser::new((), codemap); + Self { + diagnostics, + emitter, + parser, + } + } + + /// This adds a new in-memory file to the [CodeMap] for this test. + /// + /// This is used when we want to write a test with imports, without having to place files on disk + #[allow(unused)] + pub fn add_virtual_file>(&self, name: P, content: String) { + self.parser.codemap.add(name.as_ref(), content); + } + + pub fn parse_module_from_file(&self, path: &str) -> Result { + self.parser + .parse_file::(&self.diagnostics, path) + } + + #[allow(unused)] + pub fn parse_module(&self, source: &str) -> Result { + self.parser + .parse_string::(&self.diagnostics, source) + } + + // TEST METHODS + // -------------------------------------------------------------------------------------------- + + #[track_caller] + #[allow(unused)] + pub fn expect_module_diagnostic(&self, source: &str, expected: &str) { + if let Err(err) = self.parse_module(source) { + self.diagnostics.emit(err); + assert!( + self.emitter.captured().contains(expected), + "expected diagnostic output to contain the string: '{}'", + expected + ); + } else { + panic!("expected parsing to fail, but it succeeded"); + } + } + + /// Parses a [Module] from the given source string and asserts that executing the test will result + /// in the expected AST. + #[track_caller] + #[allow(unused)] + pub fn expect_module_ast(&self, source: &str, expected: Module) { + match self.parse_module(source) { + Err(err) => { + self.diagnostics.emit(err); + panic!("expected parsing to succeed, see diagnostics for details"); + } + Ok(ast) => assert_eq!(ast, expected), + } + } + + /// Parses a [Module] from the given source path and asserts that executing the test will result + /// in the expected AST. + #[allow(unused)] + #[track_caller] + pub fn expect_module_ast_from_file(&self, path: &str, expected: Module) { + match self.parse_module_from_file(path) { + Err(err) => { + self.diagnostics.emit(err); + panic!("expected parsing to succeed, see diagnostics for details"); + } + Ok(ast) => assert_eq!(ast, expected), + } + } +}