diff --git a/huff_parser/src/lib.rs b/huff_parser/src/lib.rs index 48e88c0e..2083d989 100644 --- a/huff_parser/src/lib.rs +++ b/huff_parser/src/lib.rs @@ -8,7 +8,7 @@ use huff_utils::{ ast::*, error::*, files, - prelude::{bytes32_to_string, hash_bytes, str_to_bytes32, Span}, + prelude::{bytes32_to_string, hash_bytes, str_to_bytes32, Opcode, Span}, token::{Token, TokenKind}, types::*, }; @@ -521,6 +521,51 @@ impl Parser { let macro_statements: Vec = self.parse_body()?; + let (body_statements_take, body_statements_return) = + macro_statements.iter().fold((0i16, 0i16), |acc, st| { + let (statement_takes, statement_returns) = match st.ty { + StatementType::Literal(_) => (0i8, 1i8), + StatementType::Opcode(opcode) => { + let stack_changes = opcode.stack_changes(); + (stack_changes.0 as i8, stack_changes.1 as i8) + } + _ => (0i8, 0i8), + }; + + // acc.1 is always non negative + // acc.0 is always non positive + let stack_takes = acc.0 + acc.1 - statement_takes as i16; + let stack_returns = if statement_takes as i16 > acc.1 { + statement_returns as i16 + } else { + acc.1 - statement_takes as i16 + statement_returns as i16 + }; + (stack_takes, stack_returns) + }); + + if outlined { + if body_statements_take.abs() != macro_takes as i16 { + return Err(ParserError { + kind: ParserErrorKind::InvalidStackAnnotation(TokenKind::Takes), + hint: Some(format!( + "Fn specified to take {macro_takes} elements from the stack, but it takes {}", + body_statements_take.abs() + )), + spans: AstSpan(self.spans.clone()), + }); + } + if body_statements_return != macro_returns as i16 { + return Err(ParserError { + kind: ParserErrorKind::InvalidStackAnnotation(TokenKind::Returns), + hint: Some(format!( + "Fn specified to return {macro_returns} elements to the stack, but it returns {}", + body_statements_return + )), + spans: AstSpan(self.spans.clone()), + }); + } + } + Ok(MacroDefinition::new( macro_name, decorator, diff --git a/huff_utils/src/error.rs b/huff_utils/src/error.rs index 9ecb2631..a7e635fc 100644 --- a/huff_utils/src/error.rs +++ b/huff_utils/src/error.rs @@ -63,6 +63,8 @@ pub enum ParserErrorKind { InvalidDecoratorFlag(String), /// Invalid decorator flag argument InvalidDecoratorFlagArg(TokenKind), + /// Invalid stack annotation + InvalidStackAnnotation(TokenKind), } /// A Lexing Error @@ -488,6 +490,14 @@ impl fmt::Display for CompilerError { pe.spans.error(pe.hint.as_ref()) ) } + ParserErrorKind::InvalidStackAnnotation(rt) => { + write!( + f, + "\nError: Invalid stack {} annotation in function definition \n{}\n", + rt, + pe.spans.error(pe.hint.as_ref()) + ) + } }, CompilerError::PathBufRead(os_str) => { write!( diff --git a/huff_utils/src/evm.rs b/huff_utils/src/evm.rs index 8ad61a40..96bdc276 100644 --- a/huff_utils/src/evm.rs +++ b/huff_utils/src/evm.rs @@ -836,6 +836,146 @@ impl Opcode { false } + + /// Returns stack changes by the given opcode + pub fn stack_changes(&self) -> (u8, u8) { + match self { + Opcode::Stop => (0, 0), + Opcode::Add | Opcode::Mul | Opcode::Sub | Opcode::Div => (2, 1), + Opcode::Sdiv | Opcode::Mod | Opcode::Smod => (2, 1), + Opcode::Addmod | Opcode::Mulmod => (3, 1), + Opcode::Exp => (2, 1), + Opcode::Signextend => (2, 1), + Opcode::Lt | Opcode::Gt => (2, 1), + Opcode::Slt | Opcode::Sgt => (2, 1), + Opcode::Eq => (2, 1), + Opcode::Iszero => (1, 1), + Opcode::And | Opcode::Or | Opcode::Xor => (2, 1), + Opcode::Not => (1, 1), + Opcode::Byte => (2, 1), + Opcode::Shl | Opcode::Shr | Opcode::Sar => (2, 1), + Opcode::Sha3 => (2, 1), + Opcode::Address => (0, 1), + Opcode::Balance => (1, 1), + Opcode::Origin => (0, 1), + Opcode::Caller => (0, 1), + Opcode::Callvalue => (0, 1), + Opcode::Calldataload => (1, 1), + Opcode::Calldatasize => (0, 1), + Opcode::Calldatacopy => (3, 0), + Opcode::Codesize => (0, 1), + Opcode::Codecopy => (3, 0), + Opcode::Gasprice => (0, 1), + Opcode::Extcodesize => (1, 1), + Opcode::Extcodecopy => (4, 0), + Opcode::Returndatasize => (0, 1), + Opcode::Returndatacopy => (3, 0), + Opcode::Extcodehash => (1, 1), + Opcode::Blockhash => (1, 1), + Opcode::Coinbase => (0, 1), + Opcode::Timestamp => (0, 1), + Opcode::Number => (0, 1), + Opcode::Difficulty => (0, 1), + Opcode::Prevrandao => (0, 1), + Opcode::Gaslimit => (0, 1), + Opcode::Chainid => (0, 1), + Opcode::Selfbalance => (0, 1), + Opcode::Basefee => (0, 1), + Opcode::Pop => (1, 0), + Opcode::Mload => (1, 1), + Opcode::Mstore | Opcode::Mstore8 => (2, 0), + Opcode::Sload => (1, 1), + Opcode::Sstore => (2, 0), + Opcode::Jump => (1, 0), + Opcode::Jumpi => (2, 0), + Opcode::Pc => (0, 1), + Opcode::Msize => (0, 1), + Opcode::Gas => (0, 1), + Opcode::Jumpdest => (0, 0), + Opcode::Push0 => (0, 1), + Opcode::Push1 => (0, 1), + Opcode::Push2 => (0, 1), + Opcode::Push3 => (0, 1), + Opcode::Push4 => (0, 1), + Opcode::Push5 => (0, 1), + Opcode::Push6 => (0, 1), + Opcode::Push7 => (0, 1), + Opcode::Push8 => (0, 1), + Opcode::Push9 => (0, 1), + Opcode::Push10 => (0, 1), + Opcode::Push11 => (0, 1), + Opcode::Push12 => (0, 1), + Opcode::Push13 => (0, 1), + Opcode::Push14 => (0, 1), + Opcode::Push15 => (0, 1), + Opcode::Push16 => (0, 1), + Opcode::Push17 => (0, 1), + Opcode::Push18 => (0, 1), + Opcode::Push19 => (0, 1), + Opcode::Push20 => (0, 1), + Opcode::Push21 => (0, 1), + Opcode::Push22 => (0, 1), + Opcode::Push23 => (0, 1), + Opcode::Push24 => (0, 1), + Opcode::Push25 => (0, 1), + Opcode::Push26 => (0, 1), + Opcode::Push27 => (0, 1), + Opcode::Push28 => (0, 1), + Opcode::Push29 => (0, 1), + Opcode::Push30 => (0, 1), + Opcode::Push31 => (0, 1), + Opcode::Push32 => (0, 1), + Opcode::Dup1 => (1, 2), + Opcode::Dup2 => (2, 3), + Opcode::Dup3 => (3, 4), + Opcode::Dup4 => (4, 5), + Opcode::Dup5 => (5, 6), + Opcode::Dup6 => (6, 7), + Opcode::Dup7 => (7, 8), + Opcode::Dup8 => (8, 9), + Opcode::Dup9 => (9, 10), + Opcode::Dup10 => (10, 11), + Opcode::Dup11 => (11, 12), + Opcode::Dup12 => (12, 13), + Opcode::Dup13 => (13, 14), + Opcode::Dup14 => (14, 15), + Opcode::Dup15 => (15, 16), + Opcode::Dup16 => (16, 17), + Opcode::Swap1 => (2, 2), + Opcode::Swap2 => (3, 3), + Opcode::Swap3 => (4, 4), + Opcode::Swap4 => (5, 5), + Opcode::Swap5 => (6, 6), + Opcode::Swap6 => (7, 7), + Opcode::Swap7 => (8, 8), + Opcode::Swap8 => (9, 9), + Opcode::Swap9 => (10, 10), + Opcode::Swap10 => (11, 11), + Opcode::Swap11 => (12, 12), + Opcode::Swap12 => (13, 13), + Opcode::Swap13 => (14, 14), + Opcode::Swap14 => (15, 15), + Opcode::Swap15 => (16, 16), + Opcode::Swap16 => (17, 17), + Opcode::Log0 => (2, 0), + Opcode::Log1 => (3, 0), + Opcode::Log2 => (4, 0), + Opcode::Log3 => (5, 0), + Opcode::Log4 => (6, 0), + Opcode::TLoad => (1, 1), + Opcode::TStore => (2, 0), + Opcode::Create => (3, 1), + Opcode::Call => (7, 1), + Opcode::Callcode => (7, 1), + Opcode::Return => (2, 0), + Opcode::Delegatecall => (6, 1), + Opcode::Create2 => (4, 1), + Opcode::Staticcall => (6, 1), + Opcode::Revert => (2, 0), + Opcode::Invalid => (0, 0), + Opcode::Selfdestruct => (1, 0), + } + } } impl fmt::Display for Opcode {