diff --git a/huff_parser/src/lib.rs b/huff_parser/src/lib.rs index 48e88c0e..1d2c59c4 100644 --- a/huff_parser/src/lib.rs +++ b/huff_parser/src/lib.rs @@ -521,6 +521,51 @@ impl Parser { let macro_statements: Vec = self.parse_body()?; + let (body_statements_take, body_statements_return) = + macro_statements.iter().fold((0i16, 0i16), |acc, st| { + let (statement_takes, statement_returns) = match st.ty { + StatementType::Literal(_) => (0i8, 1i8), + StatementType::Opcode(opcode) => { + let stack_changes = opcode.stack_changes(); + (stack_changes.0 as i8, stack_changes.1 as i8) + } + _ => (0i8, 0i8), + }; + + // acc.1 is always non negative + // acc.0 is always non positive + let stack_takes = acc.0 + acc.1 - statement_takes as i16; + let stack_returns = if statement_takes as i16 > acc.1 { + statement_returns as i16 + } else { + acc.1 - statement_takes as i16 + statement_returns as i16 + }; + (stack_takes, stack_returns) + }); + + if outlined { + if body_statements_take.abs() != macro_takes as i16 { + return Err(ParserError { + kind: ParserErrorKind::InvalidStackAnnotation(TokenKind::Takes), + hint: Some(format!( + "Fn {macro_name} specified to take {macro_takes} elements from the stack, but it takes {}", + body_statements_take.abs() + )), + spans: AstSpan(self.spans.clone()), + }); + } + if body_statements_return != macro_returns as i16 { + return Err(ParserError { + kind: ParserErrorKind::InvalidStackAnnotation(TokenKind::Returns), + hint: Some(format!( + "Fn {macro_name} specified to return {macro_returns} elements to the stack, but it returns {}", + body_statements_return + )), + spans: AstSpan(self.spans.clone()), + }); + } + } + Ok(MacroDefinition::new( macro_name, decorator, diff --git a/huff_utils/src/error.rs b/huff_utils/src/error.rs index 9ecb2631..a7e635fc 100644 --- a/huff_utils/src/error.rs +++ b/huff_utils/src/error.rs @@ -63,6 +63,8 @@ pub enum ParserErrorKind { InvalidDecoratorFlag(String), /// Invalid decorator flag argument InvalidDecoratorFlagArg(TokenKind), + /// Invalid stack annotation + InvalidStackAnnotation(TokenKind), } /// A Lexing Error @@ -488,6 +490,14 @@ impl fmt::Display for CompilerError { pe.spans.error(pe.hint.as_ref()) ) } + ParserErrorKind::InvalidStackAnnotation(rt) => { + write!( + f, + "\nError: Invalid stack {} annotation in function definition \n{}\n", + rt, + pe.spans.error(pe.hint.as_ref()) + ) + } }, CompilerError::PathBufRead(os_str) => { write!( diff --git a/huff_utils/src/evm.rs b/huff_utils/src/evm.rs index 8ad61a40..f485602c 100644 --- a/huff_utils/src/evm.rs +++ b/huff_utils/src/evm.rs @@ -836,6 +836,146 @@ impl Opcode { false } + + /// Returns stack changes by the given opcode + pub fn stack_changes(&self) -> (u8, u8) { + match self { + Opcode::Stop => (0, 0), + Opcode::Add | Opcode::Mul | Opcode::Sub | Opcode::Div => (2, 1), + Opcode::Sdiv | Opcode::Mod | Opcode::Smod => (2, 1), + Opcode::Addmod | Opcode::Mulmod => (3, 1), + Opcode::Exp => (2, 1), + Opcode::Signextend => (2, 1), + Opcode::Lt | Opcode::Gt => (2, 1), + Opcode::Slt | Opcode::Sgt => (2, 1), + Opcode::Eq => (2, 1), + Opcode::Iszero => (1, 1), + Opcode::And | Opcode::Or | Opcode::Xor => (2, 1), + Opcode::Not => (1, 1), + Opcode::Byte => (2, 1), + Opcode::Shl | Opcode::Shr | Opcode::Sar => (2, 1), + Opcode::Sha3 => (2, 1), + Opcode::Address => (0, 1), + Opcode::Balance => (1, 1), + Opcode::Origin => (0, 1), + Opcode::Caller => (0, 1), + Opcode::Callvalue => (0, 1), + Opcode::Calldataload => (1, 1), + Opcode::Calldatasize => (0, 1), + Opcode::Calldatacopy => (3, 0), + Opcode::Codesize => (0, 1), + Opcode::Codecopy => (3, 0), + Opcode::Gasprice => (0, 1), + Opcode::Extcodesize => (1, 1), + Opcode::Extcodecopy => (4, 0), + Opcode::Returndatasize => (0, 1), + Opcode::Returndatacopy => (3, 0), + Opcode::Extcodehash => (1, 1), + Opcode::Blockhash => (1, 1), + Opcode::Coinbase => (0, 1), + Opcode::Timestamp => (0, 1), + Opcode::Number => (0, 1), + Opcode::Difficulty => (0, 1), + Opcode::Prevrandao => (0, 1), + Opcode::Gaslimit => (0, 1), + Opcode::Chainid => (0, 1), + Opcode::Selfbalance => (0, 1), + Opcode::Basefee => (0, 1), + Opcode::Pop => (1, 0), + Opcode::Mload => (1, 1), + Opcode::Mstore | Opcode::Mstore8 => (2, 0), + Opcode::Sload => (1, 1), + Opcode::Sstore => (2, 0), + Opcode::Jump => (1, 0), + Opcode::Jumpi => (2, 0), + Opcode::Pc => (0, 1), + Opcode::Msize => (0, 1), + Opcode::Gas => (0, 1), + Opcode::Jumpdest => (0, 0), + Opcode::Push0 => (0, 1), + Opcode::Push1 => (0, 1), + Opcode::Push2 => (0, 1), + Opcode::Push3 => (0, 1), + Opcode::Push4 => (0, 1), + Opcode::Push5 => (0, 1), + Opcode::Push6 => (0, 1), + Opcode::Push7 => (0, 1), + Opcode::Push8 => (0, 1), + Opcode::Push9 => (0, 1), + Opcode::Push10 => (0, 1), + Opcode::Push11 => (0, 1), + Opcode::Push12 => (0, 1), + Opcode::Push13 => (0, 1), + Opcode::Push14 => (0, 1), + Opcode::Push15 => (0, 1), + Opcode::Push16 => (0, 1), + Opcode::Push17 => (0, 1), + Opcode::Push18 => (0, 1), + Opcode::Push19 => (0, 1), + Opcode::Push20 => (0, 1), + Opcode::Push21 => (0, 1), + Opcode::Push22 => (0, 1), + Opcode::Push23 => (0, 1), + Opcode::Push24 => (0, 1), + Opcode::Push25 => (0, 1), + Opcode::Push26 => (0, 1), + Opcode::Push27 => (0, 1), + Opcode::Push28 => (0, 1), + Opcode::Push29 => (0, 1), + Opcode::Push30 => (0, 1), + Opcode::Push31 => (0, 1), + Opcode::Push32 => (0, 1), + Opcode::Dup1 => (0, 1), + Opcode::Dup2 => (0, 1), + Opcode::Dup3 => (0, 1), + Opcode::Dup4 => (0, 1), + Opcode::Dup5 => (0, 1), + Opcode::Dup6 => (0, 1), + Opcode::Dup7 => (0, 1), + Opcode::Dup8 => (0, 1), + Opcode::Dup9 => (0, 1), + Opcode::Dup10 => (0, 1), + Opcode::Dup11 => (0, 1), + Opcode::Dup12 => (0, 1), + Opcode::Dup13 => (0, 1), + Opcode::Dup14 => (0, 1), + Opcode::Dup15 => (0, 1), + Opcode::Dup16 => (0, 1), + Opcode::Swap1 => (0, 0), + Opcode::Swap2 => (0, 0), + Opcode::Swap3 => (0, 0), + Opcode::Swap4 => (0, 0), + Opcode::Swap5 => (0, 0), + Opcode::Swap6 => (0, 0), + Opcode::Swap7 => (0, 0), + Opcode::Swap8 => (0, 0), + Opcode::Swap9 => (0, 0), + Opcode::Swap10 => (0, 0), + Opcode::Swap11 => (0, 0), + Opcode::Swap12 => (0, 0), + Opcode::Swap13 => (0, 0), + Opcode::Swap14 => (0, 0), + Opcode::Swap15 => (0, 0), + Opcode::Swap16 => (0, 0), + Opcode::Log0 => (2, 0), + Opcode::Log1 => (3, 0), + Opcode::Log2 => (4, 0), + Opcode::Log3 => (5, 0), + Opcode::Log4 => (6, 0), + Opcode::TLoad => (1, 1), + Opcode::TStore => (2, 0), + Opcode::Create => (3, 1), + Opcode::Call => (7, 1), + Opcode::Callcode => (7, 1), + Opcode::Return => (2, 0), + Opcode::Delegatecall => (6, 1), + Opcode::Create2 => (4, 1), + Opcode::Staticcall => (6, 1), + Opcode::Revert => (2, 0), + Opcode::Invalid => (0, 0), + Opcode::Selfdestruct => (1, 0), + } + } } impl fmt::Display for Opcode {