From e2ac9225fb146fa4b2da80fef228812822e4d7a2 Mon Sep 17 00:00:00 2001 From: Jacob Johannsen Date: Thu, 29 Jun 2023 16:58:34 +0200 Subject: [PATCH] Ast formatter --- assembly/src/ast/format.rs | 134 +++++++++++++++++++++++++++ assembly/src/ast/imports.rs | 12 ++- assembly/src/ast/mod.rs | 96 +++++++++++++++++++- assembly/src/ast/nodes/format.rs | 151 +++++++++++++++++++++++++++++++ assembly/src/ast/nodes/mod.rs | 29 ++---- assembly/src/library/path.rs | 8 +- assembly/src/procedures/mod.rs | 6 ++ 7 files changed, 406 insertions(+), 30 deletions(-) create mode 100644 assembly/src/ast/format.rs create mode 100644 assembly/src/ast/nodes/format.rs diff --git a/assembly/src/ast/format.rs b/assembly/src/ast/format.rs new file mode 100644 index 0000000000..a8634fe1dc --- /dev/null +++ b/assembly/src/ast/format.rs @@ -0,0 +1,134 @@ +use super::{ + CodeBody, FormattableNode, InvokedProcsMap, LibraryPath, ProcedureAst, ProcedureId, + ProcedureName, Vec, +}; +use core::fmt; + +const INDENT_STRING: &str = " "; + +/// Context for the Ast formatter +/// +/// The context keeps track of the current indentation level, as well as the declared and imported +/// procedures in the program/module being formatted. +pub struct AstFormatterContext<'a> { + indent_level: usize, + local_procs: &'a Vec, + imported_procs: &'a InvokedProcsMap, +} + +impl<'a> AstFormatterContext<'a> { + pub fn new( + local_procs: &'a Vec, + imported_procs: &'a InvokedProcsMap, + ) -> AstFormatterContext<'a> { + Self { + indent_level: 0, + local_procs, + imported_procs, + } + } + + /// Build a context for the inner scope, e.g., the body of a while loop + pub fn inner_scope_context(&self) -> Self { + Self { + indent_level: self.indent_level + 1, + local_procs: self.local_procs, + imported_procs: self.imported_procs, + } + } + + /// Add indentation to the current line in the formatter + pub fn indent(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + for _ in 0..self.indent_level { + write!(f, "{INDENT_STRING}")?; + } + Ok(()) + } + + /// Get the name of the local procedure with the given index. + /// + /// # Panics + /// Panics if the index is not associated with a procedure name + pub fn local_proc(&self, index: usize) -> &ProcedureName { + assert!(index < self.local_procs.len(), "Local procedure with index {index} not found"); + &self.local_procs[index].name + } + + /// Get the name of the imported procedure with the given id/hash. + /// + /// # Panics + /// Panics if the id/hash is not associated with an imported procedure + pub fn imported_proc(&self, id: &ProcedureId) -> &(ProcedureName, LibraryPath) { + self.imported_procs + .get(id) + .expect("Imported procedure with id/hash {id} not found") + } +} + +// FORMATTING OF PROCEDURES +// ================================================================================================ +pub struct FormattableProcedureAst<'a> { + proc: &'a ProcedureAst, + context: &'a AstFormatterContext<'a>, +} + +impl<'a> FormattableProcedureAst<'a> { + pub fn new( + proc: &'a ProcedureAst, + context: &'a AstFormatterContext<'a>, + ) -> FormattableProcedureAst<'a> { + Self { proc, context } + } +} + +impl fmt::Display for FormattableProcedureAst<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + // Docs + self.context.indent(f)?; + if let Some(ref doc) = self.proc.docs { + writeln!(f, "#! {doc}")?; + } + // Procedure header + self.context.indent(f)?; + if self.proc.is_export { + write!(f, "export.")?; + } else { + write!(f, "proc.")?; + } + writeln!(f, "{}.{}", self.proc.name, self.proc.num_locals)?; + // Body + write!( + f, + "{}", + FormattableCodeBody::new(&self.proc.body, &self.context.inner_scope_context()) + )?; + // Procedure footer + self.context.indent(f)?; + writeln!(f, "end") + } +} + +// FORMATTING OF CODE BODIES +// ================================================================================================ +pub struct FormattableCodeBody<'a> { + body: &'a CodeBody, + context: &'a AstFormatterContext<'a>, +} + +impl<'a> FormattableCodeBody<'a> { + pub fn new( + body: &'a CodeBody, + context: &'a AstFormatterContext<'a>, + ) -> FormattableCodeBody<'a> { + Self { body, context } + } +} + +impl fmt::Display for FormattableCodeBody<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + for node in self.body.nodes() { + write!(f, "{}", FormattableNode::new(node, self.context))?; + } + Ok(()) + } +} diff --git a/assembly/src/ast/imports.rs b/assembly/src/ast/imports.rs index a0e0f416e4..4895a9a759 100644 --- a/assembly/src/ast/imports.rs +++ b/assembly/src/ast/imports.rs @@ -1,14 +1,13 @@ use super::{ - BTreeMap, ByteReader, ByteWriter, Deserializable, DeserializationError, LibraryPath, - ParsingError, ProcedureId, ProcedureName, Serializable, String, ToString, Token, TokenStream, - Vec, MAX_IMPORTS, MAX_INVOKED_IMPORTED_PROCS, + BTreeMap, ByteReader, ByteWriter, Deserializable, DeserializationError, InvokedProcsMap, + LibraryPath, ParsingError, ProcedureId, ProcedureName, Serializable, String, ToString, Token, + TokenStream, Vec, MAX_IMPORTS, MAX_INVOKED_IMPORTED_PROCS, }; // TYPE ALIASES // ================================================================================================ type ImportedModulesMap = BTreeMap; -type InvokedProcsMap = BTreeMap; // MODULE IMPORTS // ================================================================================================ @@ -89,6 +88,11 @@ impl ModuleImports { self.imports.values().collect() } + /// Returns a reference to the invoked procedure map which maps procedure IDs to their names. + pub fn invoked_procs(&self) -> &InvokedProcsMap { + &self.invoked_procs + } + // STATE MUTATORS // -------------------------------------------------------------------------------------------- diff --git a/assembly/src/ast/mod.rs b/assembly/src/ast/mod.rs index 35d764f1ad..d30ae6f355 100644 --- a/assembly/src/ast/mod.rs +++ b/assembly/src/ast/mod.rs @@ -9,17 +9,21 @@ use super::{ Serializable, SliceReader, StarkField, String, ToString, Token, TokenStream, Vec, MAX_LABEL_LEN, }; -use core::{iter, str::from_utf8}; +use core::{fmt, iter, str::from_utf8}; use vm_core::utils::bound_into_included_u64; pub use super::tokens::SourceLocation; mod nodes; +use nodes::FormattableNode; pub use nodes::{AdviceInjectorNode, Instruction, Node}; mod code_body; pub use code_body::CodeBody; +mod format; +use format::*; + mod imports; pub use imports::ModuleImports; @@ -66,6 +70,7 @@ const MAX_STACK_WORD_OFFSET: u8 = 12; type LocalProcMap = BTreeMap; type LocalConstMap = BTreeMap; type ReExportedProcMap = BTreeMap; +type InvokedProcsMap = BTreeMap; // EXECUTABLE PROGRAM AST // ================================================================================================ @@ -211,7 +216,6 @@ impl ProgramAst { let local_procs = sort_procs_into_vec(context.local_procs); let (nodes, locations) = body.into_parts(); - Ok(Self::new(nodes, local_procs)? .with_source_locations(locations, start) .with_import_info(import_info)) @@ -325,6 +329,46 @@ impl ProgramAst { } } +impl fmt::Display for ProgramAst { + /// Writes this [ProgramAst] as formatted MASM code into the formatter. + /// + /// The formatted code puts each instruction on a separate line and preserves correct indentation + /// for instruction blocks. + /// + /// # Panics + /// Panics if import info is not associated with this program. + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + assert!(self.import_info.is_some(), "Program imports not instantiated"); + + // Imports + if let Some(ref info) = self.import_info { + let paths = info.import_paths(); + for path in paths.iter() { + writeln!(f, "use.{path}")?; + } + if !paths.is_empty() { + writeln!(f)?; + } + } + + let tmp_procs = InvokedProcsMap::new(); + let invoked_procs = + self.import_info.as_ref().map(|info| info.invoked_procs()).unwrap_or(&tmp_procs); + + let context = AstFormatterContext::new(&self.local_procs, invoked_procs); + + // Local procedures + for proc in self.local_procs.iter() { + writeln!(f, "{}", FormattableProcedureAst::new(proc, &context))?; + } + + // Main progrma + writeln!(f, "begin")?; + write!(f, "{}", FormattableCodeBody::new(&self.body, &context.inner_scope_context()))?; + writeln!(f, "end") + } +} + // MODULE AST // ================================================================================================ @@ -594,6 +638,54 @@ impl ModuleAst { } } +impl fmt::Display for ModuleAst { + /// Writes this [ModuleAst] as formatted MASM code into the formatter. + /// + /// The formatted code puts each instruction on a separate line and preserves correct indentation + /// for instruction blocks. + /// + /// # Panics + /// Panics if import info is not associated with this module. + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + assert!(self.import_info.is_some(), "Program imports not instantiated"); + + // Docs + if let Some(ref doc) = self.docs { + writeln!(f, "#! {doc}")?; + writeln!(f)?; + } + + // Imports + if let Some(ref info) = self.import_info { + let paths = info.import_paths(); + for path in paths.iter() { + writeln!(f, "use.{path}")?; + } + if !paths.is_empty() { + writeln!(f)?; + } + } + + // Re-exports + for proc in self.reexported_procs.iter() { + writeln!(f, "export.{}", proc.name)?; + writeln!(f)?; + } + + // Local procedures + let tmp_procs = InvokedProcsMap::new(); + let invoked_procs = + self.import_info.as_ref().map(|info| info.invoked_procs()).unwrap_or(&tmp_procs); + + let context = AstFormatterContext::new(&self.local_procs, invoked_procs); + + for proc in self.local_procs.iter() { + writeln!(f, "{}", FormattableProcedureAst::new(proc, &context))?; + } + Ok(()) + } +} + // PROCEDURE AST // ================================================================================================ diff --git a/assembly/src/ast/nodes/format.rs b/assembly/src/ast/nodes/format.rs new file mode 100644 index 0000000000..0136820fc6 --- /dev/null +++ b/assembly/src/ast/nodes/format.rs @@ -0,0 +1,151 @@ +use super::{AstFormatterContext, FormattableCodeBody, Instruction, Node}; +use core::fmt; + +// FORMATTING OF NODES +// ================================================================================================ +pub struct FormattableNode<'a> { + node: &'a Node, + context: &'a AstFormatterContext<'a>, +} + +impl<'a> FormattableNode<'a> { + pub fn new(node: &'a Node, context: &'a AstFormatterContext<'a>) -> Self { + Self { node, context } + } +} + +impl fmt::Display for FormattableNode<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.node { + Node::Instruction(i) => { + write!(f, "{}", FormattableInstruction::new(i, self.context)) + } + Node::IfElse { + true_case, + false_case, + } => { + self.context.indent(f)?; + writeln!(f, "if.true")?; + write!( + f, + "{}", + FormattableCodeBody::new(true_case, &self.context.inner_scope_context()) + )?; + if !false_case.nodes().is_empty() { + // No false branch - don't output else branch + self.context.indent(f)?; + writeln!(f, "else")?; + + write!( + f, + "{}", + FormattableCodeBody::new(false_case, &self.context.inner_scope_context()) + )?; + } + self.context.indent(f)?; + writeln!(f, "end") + } + Node::Repeat { times, body } => { + self.context.indent(f)?; + writeln!(f, "repeat.{times}")?; + + write!( + f, + "{}", + FormattableCodeBody::new(body, &self.context.inner_scope_context()) + )?; + + self.context.indent(f)?; + writeln!(f, "end") + } + Node::While { body } => { + self.context.indent(f)?; + writeln!(f, "while.true")?; + + write!( + f, + "{}", + FormattableCodeBody::new(body, &self.context.inner_scope_context()) + )?; + + self.context.indent(f)?; + writeln!(f, "end") + } + } + } +} + +// FORMATTING OF INSTRUCTIONS WITH INDENTATION +// ================================================================================================ +pub struct FormattableInstruction<'a> { + instruction: &'a Instruction, + context: &'a AstFormatterContext<'a>, +} + +impl<'a> FormattableInstruction<'a> { + pub fn new(instruction: &'a Instruction, context: &'a AstFormatterContext<'a>) -> Self { + Self { + instruction, + context, + } + } +} + +impl fmt::Display for FormattableInstruction<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.context.indent(f)?; + match self.instruction { + // procedure calls are represented by indices or hashes, so must be handled specially + Instruction::ExecLocal(index) => { + let proc_name = self.context.local_proc(*index as usize); + write!(f, "exec.{proc_name}")?; + } + Instruction::CallLocal(index) => { + let proc_name = self.context.local_proc(*index as usize); + write!(f, "call.{proc_name}")?; + } + Instruction::ExecImported(proc_id) => { + let (_, path) = self.context.imported_proc(proc_id); + write!(f, "exec.{path}")?; + } + Instruction::CallImported(proc_id) => { + let (_, path) = self.context.imported_proc(proc_id); + write!(f, "call.{path}")?; + } + Instruction::SysCall(proc_id) => { + let (_, path) = self.context.imported_proc(proc_id); + write!(f, "syscall.{path}")?; + } + Instruction::CallMastRoot(root) => { + write!(f, "call.")?; + display_hex_bytes(f, &root.as_bytes())?; + } + _ => { + // Not a procedure call. Use the normal formatting + write!(f, "{}", self.instruction)?; + } + } + writeln!(f) + } +} + +// HELPER FUNCTIONS +// ================================================================================================ + +/// Builds a hex string from a byte slice +pub fn display_hex_bytes(f: &mut fmt::Formatter<'_>, bytes: &[u8]) -> fmt::Result { + write!(f, "0x")?; + for byte in bytes { + write!(f, "{byte:02x}")?; + } + Ok(()) +} + +/// Builds a string from input vector to display push operation +pub fn display_push_vec(f: &mut fmt::Formatter<'_>, values: &[T]) -> fmt::Result { + write!(f, "push")?; + for elem in values { + write!(f, ".{elem}")?; + } + Ok(()) +} diff --git a/assembly/src/ast/nodes/mod.rs b/assembly/src/ast/nodes/mod.rs index a8ed52a0bc..2eed169b34 100644 --- a/assembly/src/ast/nodes/mod.rs +++ b/assembly/src/ast/nodes/mod.rs @@ -1,9 +1,14 @@ -use super::{CodeBody, Felt, ProcedureId, RpoDigest, ToString, Vec}; +use super::{ + AstFormatterContext, CodeBody, Felt, FormattableCodeBody, ProcedureId, RpoDigest, ToString, Vec, +}; use core::fmt; mod advice; pub use advice::AdviceInjectorNode; +mod format; +pub use format::*; + mod serde; // NODES @@ -556,7 +561,6 @@ impl fmt::Display for Instruction { Self::FriExt2Fold4 => write!(f, "fri_ext2fold4"), // ----- exec / call ------------------------------------------------------------------ - // TODO: print exec/call instructions with procedures names, not indexes or id's Self::ExecLocal(index) => write!(f, "exec.{index}"), Self::ExecImported(proc_id) => write!(f, "exec.{proc_id}"), Self::CallLocal(index) => write!(f, "call.{index}"), @@ -573,27 +577,6 @@ impl fmt::Display for Instruction { } } -// HELPER FUNCTIONS -// ================================================================================================ - -/// Builds a hex string from a byte slice -pub fn display_hex_bytes(f: &mut fmt::Formatter<'_>, bytes: &[u8]) -> fmt::Result { - write!(f, "0x")?; - for byte in bytes { - write!(f, "{byte:02x}")?; - } - Ok(()) -} - -/// Builds a string from input vector to display push operation -fn display_push_vec(f: &mut fmt::Formatter<'_>, values: &[T]) -> fmt::Result { - write!(f, "push")?; - for elem in values { - write!(f, ".{elem}")?; - } - Ok(()) -} - // TESTS // ================================================================================================ diff --git a/assembly/src/library/path.rs b/assembly/src/library/path.rs index 7c4ad3210a..393b854cb1 100644 --- a/assembly/src/library/path.rs +++ b/assembly/src/library/path.rs @@ -2,7 +2,7 @@ use super::{ ByteReader, ByteWriter, Deserializable, DeserializationError, PathError, Serializable, String, ToString, MAX_LABEL_LEN, }; -use core::{ops::Deref, str::from_utf8}; +use core::{fmt, ops::Deref, str::from_utf8}; // CONSTANTS // ================================================================================================ @@ -312,6 +312,12 @@ impl Deserializable for LibraryPath { } } +impl fmt::Display for LibraryPath { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.path) + } +} + // HELPER FUNCTIONS // ================================================================================================ diff --git a/assembly/src/procedures/mod.rs b/assembly/src/procedures/mod.rs index da61695f4c..ec5501a801 100644 --- a/assembly/src/procedures/mod.rs +++ b/assembly/src/procedures/mod.rs @@ -182,6 +182,12 @@ impl Deserializable for ProcedureName { } } +impl fmt::Display for ProcedureName { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.name) + } +} + // PROCEDURE ID // ================================================================================================