diff --git a/huff_parser/src/lib.rs b/huff_parser/src/lib.rs index 48e88c0e..f0b7fc80 100644 --- a/huff_parser/src/lib.rs +++ b/huff_parser/src/lib.rs @@ -4,6 +4,8 @@ #![forbid(unsafe_code)] #![forbid(where_clauses_object_safety)] +use std::collections::HashMap; + use huff_utils::{ ast::*, error::*, @@ -136,6 +138,8 @@ impl Parser { } } + validate_macros(&contract)?; + Ok(contract) } @@ -1279,3 +1283,93 @@ impl Parser { } } } + +/// Function used to evaluate macro statements. Returns number of elements taken from the stack and +/// returned to the stack +pub fn evaluate_macro( + macro_name: &str, + macros: &[MacroDefinition], + evaluated_macros: &mut HashMap, +) -> Result<(i16, i16), ParserError> { + if let Some(macro_takes_returns) = evaluated_macros.get(macro_name) { + return Ok(*macro_takes_returns) + } + + let contract_macro = macros.iter().find(|m| m.name.as_str() == macro_name).unwrap(); + let (body_statements_take, body_statements_return) = + contract_macro.statements.iter().fold((0i16, 0i16), |acc, st| { + let (statement_takes, statement_returns) = match &st.ty { + StatementType::Literal(_) | + StatementType::Constant(_) | + StatementType::BuiltinFunctionCall(_) | + StatementType::ArgCall(_) => (0i8, 1i8), + StatementType::LabelCall(_) => (0i8, 1i8), + StatementType::Opcode(opcode) => { + if opcode.is_value_push() { + (0i8, 0i8) + } else { + let stack_changes = opcode.stack_changes(); + (stack_changes.0 as i8, stack_changes.1 as i8) + } + } + StatementType::Label(_) => (0i8, 0i8), + StatementType::MacroInvocation(MacroInvocation { + macro_name, + args: _, + span: _, + }) => { + let (takes, returns) = + evaluate_macro(macro_name, macros, evaluated_macros).unwrap(); + (takes.abs() as i8, returns as i8) + } + StatementType::Code(_) => { + todo!("should throw error") + } + }; + + // acc.1 is always non negative + // acc.0 is always non positive + let (stack_takes, stack_returns) = if statement_takes as i16 > acc.1 { + (acc.0 + acc.1 - statement_takes as i16, statement_returns as i16) + } else { + (acc.0, acc.1 - statement_takes as i16 + statement_returns as i16) + }; + (stack_takes, stack_returns) + }); + + evaluated_macros + .insert(contract_macro.name.clone(), (body_statements_take, body_statements_return)); + Ok((body_statements_take, body_statements_return)) +} + +/// Function used to validate takes and returns of outlined macros in the contract +pub fn validate_macros(contract: &Contract) -> Result<(), ParserError> { + let mut evaluated_macros = HashMap::with_capacity(contract.macros.len()); + for contract_macro in contract.macros.iter().filter(|m| m.outlined) { + let (body_statements_take, body_statements_return) = + evaluate_macro(&contract_macro.name, &contract.macros, &mut evaluated_macros)?; + if body_statements_take.abs() != contract_macro.takes as i16 { + return Err(ParserError { + kind: ParserErrorKind::InvalidStackAnnotation(TokenKind::Takes), + hint: Some(format!( + "Fn {} specified to take {} elements from the stack, but it takes {}", + contract_macro.name, + contract_macro.takes, + body_statements_take.abs() + )), + spans: contract_macro.span.clone(), + }) + } + if body_statements_return != contract_macro.returns as i16 { + return Err(ParserError { + kind: ParserErrorKind::InvalidStackAnnotation(TokenKind::Returns), + hint: Some(format!( + "Fn {} specified to return {} elements to the stack, but it returns {}", + contract_macro.name, contract_macro.returns, body_statements_return + )), + spans: contract_macro.span.clone(), + }) + } + } + Ok(()) +} diff --git a/huff_parser/tests/macro.rs b/huff_parser/tests/macro.rs index bc34e903..a1539eff 100644 --- a/huff_parser/tests/macro.rs +++ b/huff_parser/tests/macro.rs @@ -874,7 +874,7 @@ fn macro_with_builtin_fn_call() { // difference besides the spans as well as the outlined flag. #[test] fn empty_outlined_macro() { - let source = "#define fn HELLO_WORLD() = takes(0) returns(4) {}"; + let source = "#define fn HELLO_WORLD() = takes(0) returns(0) {}"; let flattened_source = FullFileSource { source, file: None, spans: vec![] }; let lexer = Lexer::new(flattened_source.source); @@ -889,7 +889,7 @@ fn empty_outlined_macro() { parameters: vec![], statements: vec![], takes: 0, - returns: 4, + returns: 0, span: AstSpan(vec![ Span { start: 0, end: 6, file: None }, Span { start: 8, end: 9, file: None }, @@ -917,7 +917,7 @@ fn empty_outlined_macro() { #[test] fn outlined_macro_with_simple_body() { - let source = "#define fn HELLO_WORLD() = takes(3) returns(0) {\n0x00 mstore\n 0x01 0x02 add\n}"; + let source = "#define fn HELLO_WORLD() = takes(1) returns(1) {\n0x00 mstore\n 0x01 0x02 add\n}"; let flattened_source = FullFileSource { source, file: None, spans: vec![] }; let lexer = Lexer::new(flattened_source.source); let tokens = lexer.into_iter().map(|x| x.unwrap()).collect::>(); @@ -951,8 +951,8 @@ fn outlined_macro_with_simple_body() { span: AstSpan(vec![Span { start: 72, end: 74, file: None }]), }, ], - takes: 3, - returns: 0, + takes: 1, + returns: 1, span: AstSpan(vec![ Span { start: 0, end: 6, file: None }, Span { start: 8, end: 9, file: None }, @@ -983,6 +983,176 @@ fn outlined_macro_with_simple_body() { assert_eq!(parser.current_token.kind, TokenKind::Eof); } +#[test] +fn outlined_macro_revert_on_more_to_take() { + let source = "#define fn HELLO_WORLD() = takes(1) returns(0) {}"; + let flattened_source = FullFileSource { source, file: None, spans: vec![] }; + let lexer = Lexer::new(flattened_source.source); + let tokens = lexer.into_iter().map(|x| x.unwrap()).collect::>(); + let mut parser = Parser::new(tokens, None); + + // Grab the first macro + let expected_error = parser.parse().unwrap_err(); + + assert_eq!( + expected_error, + ParserError { + kind: ParserErrorKind::InvalidStackAnnotation(TokenKind::Takes), + hint: Some( + "Fn HELLO_WORLD specified to take 1 elements from the stack, but it takes 0" + .to_string() + ), + spans: AstSpan(vec![ + Span { start: 0, end: 6, file: None }, + Span { start: 8, end: 9, file: None }, + Span { start: 11, end: 21, file: None }, + Span { start: 22, end: 22, file: None }, + Span { start: 23, end: 23, file: None }, + Span { start: 25, end: 25, file: None }, + Span { start: 27, end: 31, file: None }, + Span { start: 32, end: 32, file: None }, + Span { start: 33, end: 33, file: None }, + Span { start: 34, end: 34, file: None }, + Span { start: 36, end: 42, file: None }, + Span { start: 43, end: 43, file: None }, + Span { start: 44, end: 44, file: None }, + Span { start: 45, end: 45, file: None }, + Span { start: 47, end: 47, file: None }, + Span { start: 48, end: 48, file: None } + ]) + } + ) +} + +#[test] +fn outlined_macro_revert_on_more_to_return() { + let source = "#define fn HELLO_WORLD() = takes(0) returns(1) {}"; + let flattened_source = FullFileSource { source, file: None, spans: vec![] }; + let lexer = Lexer::new(flattened_source.source); + let tokens = lexer.into_iter().map(|x| x.unwrap()).collect::>(); + let mut parser = Parser::new(tokens, None); + + // Grab the first macro + let expected_error = parser.parse().unwrap_err(); + + assert_eq!( + expected_error, + ParserError { + kind: ParserErrorKind::InvalidStackAnnotation(TokenKind::Returns), + hint: Some( + "Fn HELLO_WORLD specified to return 1 elements to the stack, but it returns 0" + .to_string() + ), + spans: AstSpan(vec![ + Span { start: 0, end: 6, file: None }, + Span { start: 8, end: 9, file: None }, + Span { start: 11, end: 21, file: None }, + Span { start: 22, end: 22, file: None }, + Span { start: 23, end: 23, file: None }, + Span { start: 25, end: 25, file: None }, + Span { start: 27, end: 31, file: None }, + Span { start: 32, end: 32, file: None }, + Span { start: 33, end: 33, file: None }, + Span { start: 34, end: 34, file: None }, + Span { start: 36, end: 42, file: None }, + Span { start: 43, end: 43, file: None }, + Span { start: 44, end: 44, file: None }, + Span { start: 45, end: 45, file: None }, + Span { start: 47, end: 47, file: None }, + Span { start: 48, end: 48, file: None } + ]) + } + ) +} + +#[test] +fn outlined_macro_revert_on_less_to_take() { + let source = "#define fn HELLO_WORLD() = takes(1) returns(0) { 0x01 add call }"; + let flattened_source = FullFileSource { source, file: None, spans: vec![] }; + let lexer = Lexer::new(flattened_source.source); + let tokens = lexer.into_iter().map(|x| x.unwrap()).collect::>(); + let mut parser = Parser::new(tokens, None); + + // Grab the first macro + let expected_error = parser.parse().unwrap_err(); + + assert_eq!( + expected_error, + ParserError { + kind: ParserErrorKind::InvalidStackAnnotation(TokenKind::Takes), + hint: Some( + "Fn HELLO_WORLD specified to take 1 elements from the stack, but it takes 7" + .to_string() + ), + spans: AstSpan(vec![ + Span { start: 0, end: 6, file: None }, + Span { start: 8, end: 9, file: None }, + Span { start: 11, end: 21, file: None }, + Span { start: 22, end: 22, file: None }, + Span { start: 23, end: 23, file: None }, + Span { start: 25, end: 25, file: None }, + Span { start: 27, end: 31, file: None }, + Span { start: 32, end: 32, file: None }, + Span { start: 33, end: 33, file: None }, + Span { start: 34, end: 34, file: None }, + Span { start: 36, end: 42, file: None }, + Span { start: 43, end: 43, file: None }, + Span { start: 44, end: 44, file: None }, + Span { start: 45, end: 45, file: None }, + Span { start: 47, end: 47, file: None }, + Span { start: 51, end: 52, file: None }, + Span { start: 54, end: 56, file: None }, + Span { start: 58, end: 61, file: None }, + Span { start: 63, end: 63, file: None } + ]) + } + ) +} + +#[test] +fn outlined_macro_revert_on_less_to_return() { + let source = "#define fn HELLO_WORLD() = takes(0) returns(1) { 0x01 0x01 dup1 }"; + let flattened_source = FullFileSource { source, file: None, spans: vec![] }; + let lexer = Lexer::new(flattened_source.source); + let tokens = lexer.into_iter().map(|x| x.unwrap()).collect::>(); + let mut parser = Parser::new(tokens, None); + + // Grab the first macro + let expected_error = parser.parse().unwrap_err(); + + assert_eq!( + expected_error, + ParserError { + kind: ParserErrorKind::InvalidStackAnnotation(TokenKind::Returns), + hint: Some( + "Fn HELLO_WORLD specified to return 1 elements to the stack, but it returns 3" + .to_string() + ), + spans: AstSpan(vec![ + Span { start: 0, end: 6, file: None }, + Span { start: 8, end: 9, file: None }, + Span { start: 11, end: 21, file: None }, + Span { start: 22, end: 22, file: None }, + Span { start: 23, end: 23, file: None }, + Span { start: 25, end: 25, file: None }, + Span { start: 27, end: 31, file: None }, + Span { start: 32, end: 32, file: None }, + Span { start: 33, end: 33, file: None }, + Span { start: 34, end: 34, file: None }, + Span { start: 36, end: 42, file: None }, + Span { start: 43, end: 43, file: None }, + Span { start: 44, end: 44, file: None }, + Span { start: 45, end: 45, file: None }, + Span { start: 47, end: 47, file: None }, + Span { start: 51, end: 52, file: None }, + Span { start: 56, end: 57, file: None }, + Span { start: 59, end: 62, file: None }, + Span { start: 64, end: 64, file: None } + ]) + } + ) +} + #[test] fn empty_test() { let source = "#define test HELLO_WORLD() = takes(0) returns(4) {}"; @@ -1028,7 +1198,7 @@ fn empty_test() { #[test] fn test_with_simple_body() { let source = - "#define test HELLO_WORLD() = takes(3) returns(0) {\n0x00 0x00 mstore\n 0x01 0x02 add\n}"; + "#define test HELLO_WORLD() = takes(0) returns(1) {\n0x00 0x00 mstore\n 0x01 0x02 add\n}"; let flattened_source = FullFileSource { source, file: None, spans: vec![] }; let lexer = Lexer::new(flattened_source.source); let tokens = lexer.into_iter().map(|x| x.unwrap()).collect::>(); @@ -1078,8 +1248,8 @@ fn test_with_simple_body() { span: AstSpan(vec![Span { start: 79, end: 81, file: None }]), }, ], - takes: 3, - returns: 0, + takes: 0, + returns: 1, span: AstSpan(vec![ Span { start: 0, end: 6, file: None }, Span { start: 8, end: 11, file: None }, diff --git a/huff_utils/src/error.rs b/huff_utils/src/error.rs index 9ecb2631..a7e635fc 100644 --- a/huff_utils/src/error.rs +++ b/huff_utils/src/error.rs @@ -63,6 +63,8 @@ pub enum ParserErrorKind { InvalidDecoratorFlag(String), /// Invalid decorator flag argument InvalidDecoratorFlagArg(TokenKind), + /// Invalid stack annotation + InvalidStackAnnotation(TokenKind), } /// A Lexing Error @@ -488,6 +490,14 @@ impl fmt::Display for CompilerError { pe.spans.error(pe.hint.as_ref()) ) } + ParserErrorKind::InvalidStackAnnotation(rt) => { + write!( + f, + "\nError: Invalid stack {} annotation in function definition \n{}\n", + rt, + pe.spans.error(pe.hint.as_ref()) + ) + } }, CompilerError::PathBufRead(os_str) => { write!( diff --git a/huff_utils/src/evm.rs b/huff_utils/src/evm.rs index 8ad61a40..96bdc276 100644 --- a/huff_utils/src/evm.rs +++ b/huff_utils/src/evm.rs @@ -836,6 +836,146 @@ impl Opcode { false } + + /// Returns stack changes by the given opcode + pub fn stack_changes(&self) -> (u8, u8) { + match self { + Opcode::Stop => (0, 0), + Opcode::Add | Opcode::Mul | Opcode::Sub | Opcode::Div => (2, 1), + Opcode::Sdiv | Opcode::Mod | Opcode::Smod => (2, 1), + Opcode::Addmod | Opcode::Mulmod => (3, 1), + Opcode::Exp => (2, 1), + Opcode::Signextend => (2, 1), + Opcode::Lt | Opcode::Gt => (2, 1), + Opcode::Slt | Opcode::Sgt => (2, 1), + Opcode::Eq => (2, 1), + Opcode::Iszero => (1, 1), + Opcode::And | Opcode::Or | Opcode::Xor => (2, 1), + Opcode::Not => (1, 1), + Opcode::Byte => (2, 1), + Opcode::Shl | Opcode::Shr | Opcode::Sar => (2, 1), + Opcode::Sha3 => (2, 1), + Opcode::Address => (0, 1), + Opcode::Balance => (1, 1), + Opcode::Origin => (0, 1), + Opcode::Caller => (0, 1), + Opcode::Callvalue => (0, 1), + Opcode::Calldataload => (1, 1), + Opcode::Calldatasize => (0, 1), + Opcode::Calldatacopy => (3, 0), + Opcode::Codesize => (0, 1), + Opcode::Codecopy => (3, 0), + Opcode::Gasprice => (0, 1), + Opcode::Extcodesize => (1, 1), + Opcode::Extcodecopy => (4, 0), + Opcode::Returndatasize => (0, 1), + Opcode::Returndatacopy => (3, 0), + Opcode::Extcodehash => (1, 1), + Opcode::Blockhash => (1, 1), + Opcode::Coinbase => (0, 1), + Opcode::Timestamp => (0, 1), + Opcode::Number => (0, 1), + Opcode::Difficulty => (0, 1), + Opcode::Prevrandao => (0, 1), + Opcode::Gaslimit => (0, 1), + Opcode::Chainid => (0, 1), + Opcode::Selfbalance => (0, 1), + Opcode::Basefee => (0, 1), + Opcode::Pop => (1, 0), + Opcode::Mload => (1, 1), + Opcode::Mstore | Opcode::Mstore8 => (2, 0), + Opcode::Sload => (1, 1), + Opcode::Sstore => (2, 0), + Opcode::Jump => (1, 0), + Opcode::Jumpi => (2, 0), + Opcode::Pc => (0, 1), + Opcode::Msize => (0, 1), + Opcode::Gas => (0, 1), + Opcode::Jumpdest => (0, 0), + Opcode::Push0 => (0, 1), + Opcode::Push1 => (0, 1), + Opcode::Push2 => (0, 1), + Opcode::Push3 => (0, 1), + Opcode::Push4 => (0, 1), + Opcode::Push5 => (0, 1), + Opcode::Push6 => (0, 1), + Opcode::Push7 => (0, 1), + Opcode::Push8 => (0, 1), + Opcode::Push9 => (0, 1), + Opcode::Push10 => (0, 1), + Opcode::Push11 => (0, 1), + Opcode::Push12 => (0, 1), + Opcode::Push13 => (0, 1), + Opcode::Push14 => (0, 1), + Opcode::Push15 => (0, 1), + Opcode::Push16 => (0, 1), + Opcode::Push17 => (0, 1), + Opcode::Push18 => (0, 1), + Opcode::Push19 => (0, 1), + Opcode::Push20 => (0, 1), + Opcode::Push21 => (0, 1), + Opcode::Push22 => (0, 1), + Opcode::Push23 => (0, 1), + Opcode::Push24 => (0, 1), + Opcode::Push25 => (0, 1), + Opcode::Push26 => (0, 1), + Opcode::Push27 => (0, 1), + Opcode::Push28 => (0, 1), + Opcode::Push29 => (0, 1), + Opcode::Push30 => (0, 1), + Opcode::Push31 => (0, 1), + Opcode::Push32 => (0, 1), + Opcode::Dup1 => (1, 2), + Opcode::Dup2 => (2, 3), + Opcode::Dup3 => (3, 4), + Opcode::Dup4 => (4, 5), + Opcode::Dup5 => (5, 6), + Opcode::Dup6 => (6, 7), + Opcode::Dup7 => (7, 8), + Opcode::Dup8 => (8, 9), + Opcode::Dup9 => (9, 10), + Opcode::Dup10 => (10, 11), + Opcode::Dup11 => (11, 12), + Opcode::Dup12 => (12, 13), + Opcode::Dup13 => (13, 14), + Opcode::Dup14 => (14, 15), + Opcode::Dup15 => (15, 16), + Opcode::Dup16 => (16, 17), + Opcode::Swap1 => (2, 2), + Opcode::Swap2 => (3, 3), + Opcode::Swap3 => (4, 4), + Opcode::Swap4 => (5, 5), + Opcode::Swap5 => (6, 6), + Opcode::Swap6 => (7, 7), + Opcode::Swap7 => (8, 8), + Opcode::Swap8 => (9, 9), + Opcode::Swap9 => (10, 10), + Opcode::Swap10 => (11, 11), + Opcode::Swap11 => (12, 12), + Opcode::Swap12 => (13, 13), + Opcode::Swap13 => (14, 14), + Opcode::Swap14 => (15, 15), + Opcode::Swap15 => (16, 16), + Opcode::Swap16 => (17, 17), + Opcode::Log0 => (2, 0), + Opcode::Log1 => (3, 0), + Opcode::Log2 => (4, 0), + Opcode::Log3 => (5, 0), + Opcode::Log4 => (6, 0), + Opcode::TLoad => (1, 1), + Opcode::TStore => (2, 0), + Opcode::Create => (3, 1), + Opcode::Call => (7, 1), + Opcode::Callcode => (7, 1), + Opcode::Return => (2, 0), + Opcode::Delegatecall => (6, 1), + Opcode::Create2 => (4, 1), + Opcode::Staticcall => (6, 1), + Opcode::Revert => (2, 0), + Opcode::Invalid => (0, 0), + Opcode::Selfdestruct => (1, 0), + } + } } impl fmt::Display for Opcode {