From c562b98524703bdcb83135c41f903d57556937d4 Mon Sep 17 00:00:00 2001 From: Igor Line Date: Sun, 1 Oct 2023 15:05:46 +0200 Subject: [PATCH 1/5] feat: check real stack changes inside fn vs its definition --- huff_parser/src/lib.rs | 45 +++++++++++++ huff_utils/src/error.rs | 10 +++ huff_utils/src/evm.rs | 140 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 195 insertions(+) diff --git a/huff_parser/src/lib.rs b/huff_parser/src/lib.rs index 48e88c0e..1d2c59c4 100644 --- a/huff_parser/src/lib.rs +++ b/huff_parser/src/lib.rs @@ -521,6 +521,51 @@ impl Parser { let macro_statements: Vec = self.parse_body()?; + let (body_statements_take, body_statements_return) = + macro_statements.iter().fold((0i16, 0i16), |acc, st| { + let (statement_takes, statement_returns) = match st.ty { + StatementType::Literal(_) => (0i8, 1i8), + StatementType::Opcode(opcode) => { + let stack_changes = opcode.stack_changes(); + (stack_changes.0 as i8, stack_changes.1 as i8) + } + _ => (0i8, 0i8), + }; + + // acc.1 is always non negative + // acc.0 is always non positive + let stack_takes = acc.0 + acc.1 - statement_takes as i16; + let stack_returns = if statement_takes as i16 > acc.1 { + statement_returns as i16 + } else { + acc.1 - statement_takes as i16 + statement_returns as i16 + }; + (stack_takes, stack_returns) + }); + + if outlined { + if body_statements_take.abs() != macro_takes as i16 { + return Err(ParserError { + kind: ParserErrorKind::InvalidStackAnnotation(TokenKind::Takes), + hint: Some(format!( + "Fn {macro_name} specified to take {macro_takes} elements from the stack, but it takes {}", + body_statements_take.abs() + )), + spans: AstSpan(self.spans.clone()), + }); + } + if body_statements_return != macro_returns as i16 { + return Err(ParserError { + kind: ParserErrorKind::InvalidStackAnnotation(TokenKind::Returns), + hint: Some(format!( + "Fn {macro_name} specified to return {macro_returns} elements to the stack, but it returns {}", + body_statements_return + )), + spans: AstSpan(self.spans.clone()), + }); + } + } + Ok(MacroDefinition::new( macro_name, decorator, diff --git a/huff_utils/src/error.rs b/huff_utils/src/error.rs index 9ecb2631..a7e635fc 100644 --- a/huff_utils/src/error.rs +++ b/huff_utils/src/error.rs @@ -63,6 +63,8 @@ pub enum ParserErrorKind { InvalidDecoratorFlag(String), /// Invalid decorator flag argument InvalidDecoratorFlagArg(TokenKind), + /// Invalid stack annotation + InvalidStackAnnotation(TokenKind), } /// A Lexing Error @@ -488,6 +490,14 @@ impl fmt::Display for CompilerError { pe.spans.error(pe.hint.as_ref()) ) } + ParserErrorKind::InvalidStackAnnotation(rt) => { + write!( + f, + "\nError: Invalid stack {} annotation in function definition \n{}\n", + rt, + pe.spans.error(pe.hint.as_ref()) + ) + } }, CompilerError::PathBufRead(os_str) => { write!( diff --git a/huff_utils/src/evm.rs b/huff_utils/src/evm.rs index 8ad61a40..f485602c 100644 --- a/huff_utils/src/evm.rs +++ b/huff_utils/src/evm.rs @@ -836,6 +836,146 @@ impl Opcode { false } + + /// Returns stack changes by the given opcode + pub fn stack_changes(&self) -> (u8, u8) { + match self { + Opcode::Stop => (0, 0), + Opcode::Add | Opcode::Mul | Opcode::Sub | Opcode::Div => (2, 1), + Opcode::Sdiv | Opcode::Mod | Opcode::Smod => (2, 1), + Opcode::Addmod | Opcode::Mulmod => (3, 1), + Opcode::Exp => (2, 1), + Opcode::Signextend => (2, 1), + Opcode::Lt | Opcode::Gt => (2, 1), + Opcode::Slt | Opcode::Sgt => (2, 1), + Opcode::Eq => (2, 1), + Opcode::Iszero => (1, 1), + Opcode::And | Opcode::Or | Opcode::Xor => (2, 1), + Opcode::Not => (1, 1), + Opcode::Byte => (2, 1), + Opcode::Shl | Opcode::Shr | Opcode::Sar => (2, 1), + Opcode::Sha3 => (2, 1), + Opcode::Address => (0, 1), + Opcode::Balance => (1, 1), + Opcode::Origin => (0, 1), + Opcode::Caller => (0, 1), + Opcode::Callvalue => (0, 1), + Opcode::Calldataload => (1, 1), + Opcode::Calldatasize => (0, 1), + Opcode::Calldatacopy => (3, 0), + Opcode::Codesize => (0, 1), + Opcode::Codecopy => (3, 0), + Opcode::Gasprice => (0, 1), + Opcode::Extcodesize => (1, 1), + Opcode::Extcodecopy => (4, 0), + Opcode::Returndatasize => (0, 1), + Opcode::Returndatacopy => (3, 0), + Opcode::Extcodehash => (1, 1), + Opcode::Blockhash => (1, 1), + Opcode::Coinbase => (0, 1), + Opcode::Timestamp => (0, 1), + Opcode::Number => (0, 1), + Opcode::Difficulty => (0, 1), + Opcode::Prevrandao => (0, 1), + Opcode::Gaslimit => (0, 1), + Opcode::Chainid => (0, 1), + Opcode::Selfbalance => (0, 1), + Opcode::Basefee => (0, 1), + Opcode::Pop => (1, 0), + Opcode::Mload => (1, 1), + Opcode::Mstore | Opcode::Mstore8 => (2, 0), + Opcode::Sload => (1, 1), + Opcode::Sstore => (2, 0), + Opcode::Jump => (1, 0), + Opcode::Jumpi => (2, 0), + Opcode::Pc => (0, 1), + Opcode::Msize => (0, 1), + Opcode::Gas => (0, 1), + Opcode::Jumpdest => (0, 0), + Opcode::Push0 => (0, 1), + Opcode::Push1 => (0, 1), + Opcode::Push2 => (0, 1), + Opcode::Push3 => (0, 1), + Opcode::Push4 => (0, 1), + Opcode::Push5 => (0, 1), + Opcode::Push6 => (0, 1), + Opcode::Push7 => (0, 1), + Opcode::Push8 => (0, 1), + Opcode::Push9 => (0, 1), + Opcode::Push10 => (0, 1), + Opcode::Push11 => (0, 1), + Opcode::Push12 => (0, 1), + Opcode::Push13 => (0, 1), + Opcode::Push14 => (0, 1), + Opcode::Push15 => (0, 1), + Opcode::Push16 => (0, 1), + Opcode::Push17 => (0, 1), + Opcode::Push18 => (0, 1), + Opcode::Push19 => (0, 1), + Opcode::Push20 => (0, 1), + Opcode::Push21 => (0, 1), + Opcode::Push22 => (0, 1), + Opcode::Push23 => (0, 1), + Opcode::Push24 => (0, 1), + Opcode::Push25 => (0, 1), + Opcode::Push26 => (0, 1), + Opcode::Push27 => (0, 1), + Opcode::Push28 => (0, 1), + Opcode::Push29 => (0, 1), + Opcode::Push30 => (0, 1), + Opcode::Push31 => (0, 1), + Opcode::Push32 => (0, 1), + Opcode::Dup1 => (0, 1), + Opcode::Dup2 => (0, 1), + Opcode::Dup3 => (0, 1), + Opcode::Dup4 => (0, 1), + Opcode::Dup5 => (0, 1), + Opcode::Dup6 => (0, 1), + Opcode::Dup7 => (0, 1), + Opcode::Dup8 => (0, 1), + Opcode::Dup9 => (0, 1), + Opcode::Dup10 => (0, 1), + Opcode::Dup11 => (0, 1), + Opcode::Dup12 => (0, 1), + Opcode::Dup13 => (0, 1), + Opcode::Dup14 => (0, 1), + Opcode::Dup15 => (0, 1), + Opcode::Dup16 => (0, 1), + Opcode::Swap1 => (0, 0), + Opcode::Swap2 => (0, 0), + Opcode::Swap3 => (0, 0), + Opcode::Swap4 => (0, 0), + Opcode::Swap5 => (0, 0), + Opcode::Swap6 => (0, 0), + Opcode::Swap7 => (0, 0), + Opcode::Swap8 => (0, 0), + Opcode::Swap9 => (0, 0), + Opcode::Swap10 => (0, 0), + Opcode::Swap11 => (0, 0), + Opcode::Swap12 => (0, 0), + Opcode::Swap13 => (0, 0), + Opcode::Swap14 => (0, 0), + Opcode::Swap15 => (0, 0), + Opcode::Swap16 => (0, 0), + Opcode::Log0 => (2, 0), + Opcode::Log1 => (3, 0), + Opcode::Log2 => (4, 0), + Opcode::Log3 => (5, 0), + Opcode::Log4 => (6, 0), + Opcode::TLoad => (1, 1), + Opcode::TStore => (2, 0), + Opcode::Create => (3, 1), + Opcode::Call => (7, 1), + Opcode::Callcode => (7, 1), + Opcode::Return => (2, 0), + Opcode::Delegatecall => (6, 1), + Opcode::Create2 => (4, 1), + Opcode::Staticcall => (6, 1), + Opcode::Revert => (2, 0), + Opcode::Invalid => (0, 0), + Opcode::Selfdestruct => (1, 0), + } + } } impl fmt::Display for Opcode { From ce626039213b905c1d792a90de33ea89feb9b46d Mon Sep 17 00:00:00 2001 From: Igor Line Date: Tue, 3 Oct 2023 23:52:10 +0200 Subject: [PATCH 2/5] fix: add tests and update logic --- huff_parser/src/lib.rs | 7 +- huff_parser/tests/macro.rs | 186 +++++++++++++++++++++++++++++++++++-- 2 files changed, 181 insertions(+), 12 deletions(-) diff --git a/huff_parser/src/lib.rs b/huff_parser/src/lib.rs index 1d2c59c4..3791c563 100644 --- a/huff_parser/src/lib.rs +++ b/huff_parser/src/lib.rs @@ -534,11 +534,10 @@ impl Parser { // acc.1 is always non negative // acc.0 is always non positive - let stack_takes = acc.0 + acc.1 - statement_takes as i16; - let stack_returns = if statement_takes as i16 > acc.1 { - statement_returns as i16 + let (stack_takes, stack_returns) = if statement_takes as i16 > acc.1 { + (acc.0 + acc.1 - statement_takes as i16, statement_returns as i16) } else { - acc.1 - statement_takes as i16 + statement_returns as i16 + (acc.0, acc.1 - statement_takes as i16 + statement_returns as i16) }; (stack_takes, stack_returns) }); diff --git a/huff_parser/tests/macro.rs b/huff_parser/tests/macro.rs index bc34e903..a1539eff 100644 --- a/huff_parser/tests/macro.rs +++ b/huff_parser/tests/macro.rs @@ -874,7 +874,7 @@ fn macro_with_builtin_fn_call() { // difference besides the spans as well as the outlined flag. #[test] fn empty_outlined_macro() { - let source = "#define fn HELLO_WORLD() = takes(0) returns(4) {}"; + let source = "#define fn HELLO_WORLD() = takes(0) returns(0) {}"; let flattened_source = FullFileSource { source, file: None, spans: vec![] }; let lexer = Lexer::new(flattened_source.source); @@ -889,7 +889,7 @@ fn empty_outlined_macro() { parameters: vec![], statements: vec![], takes: 0, - returns: 4, + returns: 0, span: AstSpan(vec![ Span { start: 0, end: 6, file: None }, Span { start: 8, end: 9, file: None }, @@ -917,7 +917,7 @@ fn empty_outlined_macro() { #[test] fn outlined_macro_with_simple_body() { - let source = "#define fn HELLO_WORLD() = takes(3) returns(0) {\n0x00 mstore\n 0x01 0x02 add\n}"; + let source = "#define fn HELLO_WORLD() = takes(1) returns(1) {\n0x00 mstore\n 0x01 0x02 add\n}"; let flattened_source = FullFileSource { source, file: None, spans: vec![] }; let lexer = Lexer::new(flattened_source.source); let tokens = lexer.into_iter().map(|x| x.unwrap()).collect::>(); @@ -951,8 +951,8 @@ fn outlined_macro_with_simple_body() { span: AstSpan(vec![Span { start: 72, end: 74, file: None }]), }, ], - takes: 3, - returns: 0, + takes: 1, + returns: 1, span: AstSpan(vec![ Span { start: 0, end: 6, file: None }, Span { start: 8, end: 9, file: None }, @@ -983,6 +983,176 @@ fn outlined_macro_with_simple_body() { assert_eq!(parser.current_token.kind, TokenKind::Eof); } +#[test] +fn outlined_macro_revert_on_more_to_take() { + let source = "#define fn HELLO_WORLD() = takes(1) returns(0) {}"; + let flattened_source = FullFileSource { source, file: None, spans: vec![] }; + let lexer = Lexer::new(flattened_source.source); + let tokens = lexer.into_iter().map(|x| x.unwrap()).collect::>(); + let mut parser = Parser::new(tokens, None); + + // Grab the first macro + let expected_error = parser.parse().unwrap_err(); + + assert_eq!( + expected_error, + ParserError { + kind: ParserErrorKind::InvalidStackAnnotation(TokenKind::Takes), + hint: Some( + "Fn HELLO_WORLD specified to take 1 elements from the stack, but it takes 0" + .to_string() + ), + spans: AstSpan(vec![ + Span { start: 0, end: 6, file: None }, + Span { start: 8, end: 9, file: None }, + Span { start: 11, end: 21, file: None }, + Span { start: 22, end: 22, file: None }, + Span { start: 23, end: 23, file: None }, + Span { start: 25, end: 25, file: None }, + Span { start: 27, end: 31, file: None }, + Span { start: 32, end: 32, file: None }, + Span { start: 33, end: 33, file: None }, + Span { start: 34, end: 34, file: None }, + Span { start: 36, end: 42, file: None }, + Span { start: 43, end: 43, file: None }, + Span { start: 44, end: 44, file: None }, + Span { start: 45, end: 45, file: None }, + Span { start: 47, end: 47, file: None }, + Span { start: 48, end: 48, file: None } + ]) + } + ) +} + +#[test] +fn outlined_macro_revert_on_more_to_return() { + let source = "#define fn HELLO_WORLD() = takes(0) returns(1) {}"; + let flattened_source = FullFileSource { source, file: None, spans: vec![] }; + let lexer = Lexer::new(flattened_source.source); + let tokens = lexer.into_iter().map(|x| x.unwrap()).collect::>(); + let mut parser = Parser::new(tokens, None); + + // Grab the first macro + let expected_error = parser.parse().unwrap_err(); + + assert_eq!( + expected_error, + ParserError { + kind: ParserErrorKind::InvalidStackAnnotation(TokenKind::Returns), + hint: Some( + "Fn HELLO_WORLD specified to return 1 elements to the stack, but it returns 0" + .to_string() + ), + spans: AstSpan(vec![ + Span { start: 0, end: 6, file: None }, + Span { start: 8, end: 9, file: None }, + Span { start: 11, end: 21, file: None }, + Span { start: 22, end: 22, file: None }, + Span { start: 23, end: 23, file: None }, + Span { start: 25, end: 25, file: None }, + Span { start: 27, end: 31, file: None }, + Span { start: 32, end: 32, file: None }, + Span { start: 33, end: 33, file: None }, + Span { start: 34, end: 34, file: None }, + Span { start: 36, end: 42, file: None }, + Span { start: 43, end: 43, file: None }, + Span { start: 44, end: 44, file: None }, + Span { start: 45, end: 45, file: None }, + Span { start: 47, end: 47, file: None }, + Span { start: 48, end: 48, file: None } + ]) + } + ) +} + +#[test] +fn outlined_macro_revert_on_less_to_take() { + let source = "#define fn HELLO_WORLD() = takes(1) returns(0) { 0x01 add call }"; + let flattened_source = FullFileSource { source, file: None, spans: vec![] }; + let lexer = Lexer::new(flattened_source.source); + let tokens = lexer.into_iter().map(|x| x.unwrap()).collect::>(); + let mut parser = Parser::new(tokens, None); + + // Grab the first macro + let expected_error = parser.parse().unwrap_err(); + + assert_eq!( + expected_error, + ParserError { + kind: ParserErrorKind::InvalidStackAnnotation(TokenKind::Takes), + hint: Some( + "Fn HELLO_WORLD specified to take 1 elements from the stack, but it takes 7" + .to_string() + ), + spans: AstSpan(vec![ + Span { start: 0, end: 6, file: None }, + Span { start: 8, end: 9, file: None }, + Span { start: 11, end: 21, file: None }, + Span { start: 22, end: 22, file: None }, + Span { start: 23, end: 23, file: None }, + Span { start: 25, end: 25, file: None }, + Span { start: 27, end: 31, file: None }, + Span { start: 32, end: 32, file: None }, + Span { start: 33, end: 33, file: None }, + Span { start: 34, end: 34, file: None }, + Span { start: 36, end: 42, file: None }, + Span { start: 43, end: 43, file: None }, + Span { start: 44, end: 44, file: None }, + Span { start: 45, end: 45, file: None }, + Span { start: 47, end: 47, file: None }, + Span { start: 51, end: 52, file: None }, + Span { start: 54, end: 56, file: None }, + Span { start: 58, end: 61, file: None }, + Span { start: 63, end: 63, file: None } + ]) + } + ) +} + +#[test] +fn outlined_macro_revert_on_less_to_return() { + let source = "#define fn HELLO_WORLD() = takes(0) returns(1) { 0x01 0x01 dup1 }"; + let flattened_source = FullFileSource { source, file: None, spans: vec![] }; + let lexer = Lexer::new(flattened_source.source); + let tokens = lexer.into_iter().map(|x| x.unwrap()).collect::>(); + let mut parser = Parser::new(tokens, None); + + // Grab the first macro + let expected_error = parser.parse().unwrap_err(); + + assert_eq!( + expected_error, + ParserError { + kind: ParserErrorKind::InvalidStackAnnotation(TokenKind::Returns), + hint: Some( + "Fn HELLO_WORLD specified to return 1 elements to the stack, but it returns 3" + .to_string() + ), + spans: AstSpan(vec![ + Span { start: 0, end: 6, file: None }, + Span { start: 8, end: 9, file: None }, + Span { start: 11, end: 21, file: None }, + Span { start: 22, end: 22, file: None }, + Span { start: 23, end: 23, file: None }, + Span { start: 25, end: 25, file: None }, + Span { start: 27, end: 31, file: None }, + Span { start: 32, end: 32, file: None }, + Span { start: 33, end: 33, file: None }, + Span { start: 34, end: 34, file: None }, + Span { start: 36, end: 42, file: None }, + Span { start: 43, end: 43, file: None }, + Span { start: 44, end: 44, file: None }, + Span { start: 45, end: 45, file: None }, + Span { start: 47, end: 47, file: None }, + Span { start: 51, end: 52, file: None }, + Span { start: 56, end: 57, file: None }, + Span { start: 59, end: 62, file: None }, + Span { start: 64, end: 64, file: None } + ]) + } + ) +} + #[test] fn empty_test() { let source = "#define test HELLO_WORLD() = takes(0) returns(4) {}"; @@ -1028,7 +1198,7 @@ fn empty_test() { #[test] fn test_with_simple_body() { let source = - "#define test HELLO_WORLD() = takes(3) returns(0) {\n0x00 0x00 mstore\n 0x01 0x02 add\n}"; + "#define test HELLO_WORLD() = takes(0) returns(1) {\n0x00 0x00 mstore\n 0x01 0x02 add\n}"; let flattened_source = FullFileSource { source, file: None, spans: vec![] }; let lexer = Lexer::new(flattened_source.source); let tokens = lexer.into_iter().map(|x| x.unwrap()).collect::>(); @@ -1078,8 +1248,8 @@ fn test_with_simple_body() { span: AstSpan(vec![Span { start: 79, end: 81, file: None }]), }, ], - takes: 3, - returns: 0, + takes: 0, + returns: 1, span: AstSpan(vec![ Span { start: 0, end: 6, file: None }, Span { start: 8, end: 11, file: None }, From 94f5bc61b9a061c39f1a26e370e7fed236643c58 Mon Sep 17 00:00:00 2001 From: Igor Line Date: Thu, 5 Oct 2023 13:36:54 +0200 Subject: [PATCH 3/5] fix: evaluate body statement stack changes only for outlined macros --- huff_parser/src/lib.rs | 41 ++++++++++++++++++++--------------------- 1 file changed, 20 insertions(+), 21 deletions(-) diff --git a/huff_parser/src/lib.rs b/huff_parser/src/lib.rs index 3791c563..d991eb65 100644 --- a/huff_parser/src/lib.rs +++ b/huff_parser/src/lib.rs @@ -521,28 +521,27 @@ impl Parser { let macro_statements: Vec = self.parse_body()?; - let (body_statements_take, body_statements_return) = - macro_statements.iter().fold((0i16, 0i16), |acc, st| { - let (statement_takes, statement_returns) = match st.ty { - StatementType::Literal(_) => (0i8, 1i8), - StatementType::Opcode(opcode) => { - let stack_changes = opcode.stack_changes(); - (stack_changes.0 as i8, stack_changes.1 as i8) - } - _ => (0i8, 0i8), - }; - - // acc.1 is always non negative - // acc.0 is always non positive - let (stack_takes, stack_returns) = if statement_takes as i16 > acc.1 { - (acc.0 + acc.1 - statement_takes as i16, statement_returns as i16) - } else { - (acc.0, acc.1 - statement_takes as i16 + statement_returns as i16) - }; - (stack_takes, stack_returns) - }); - if outlined { + let (body_statements_take, body_statements_return) = + macro_statements.iter().fold((0i16, 0i16), |acc, st| { + let (statement_takes, statement_returns) = match st.ty { + StatementType::Literal(_) => (0i8, 1i8), + StatementType::Opcode(opcode) => { + let stack_changes = opcode.stack_changes(); + (stack_changes.0 as i8, stack_changes.1 as i8) + } + _ => (0i8, 0i8), + }; + + // acc.1 is always non negative + // acc.0 is always non positive + let (stack_takes, stack_returns) = if statement_takes as i16 > acc.1 { + (acc.0 + acc.1 - statement_takes as i16, statement_returns as i16) + } else { + (acc.0, acc.1 - statement_takes as i16 + statement_returns as i16) + }; + (stack_takes, stack_returns) + }); if body_statements_take.abs() != macro_takes as i16 { return Err(ParserError { kind: ParserErrorKind::InvalidStackAnnotation(TokenKind::Takes), From 9dc6f61bd77af7f624febe5a5ec4bbd6bee2b34a Mon Sep 17 00:00:00 2001 From: Igor Line Date: Thu, 5 Oct 2023 14:23:13 +0200 Subject: [PATCH 4/5] feat: handle other statement types and pushes/ --- huff_parser/src/lib.rs | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/huff_parser/src/lib.rs b/huff_parser/src/lib.rs index d991eb65..387874b8 100644 --- a/huff_parser/src/lib.rs +++ b/huff_parser/src/lib.rs @@ -525,12 +525,26 @@ impl Parser { let (body_statements_take, body_statements_return) = macro_statements.iter().fold((0i16, 0i16), |acc, st| { let (statement_takes, statement_returns) = match st.ty { - StatementType::Literal(_) => (0i8, 1i8), + StatementType::Literal(_) | + StatementType::Constant(_) | + StatementType::BuiltinFunctionCall(_) | + StatementType::ArgCall(_) | + StatementType::LabelCall(_) => (0i8, 1i8), StatementType::Opcode(opcode) => { - let stack_changes = opcode.stack_changes(); - (stack_changes.0 as i8, stack_changes.1 as i8) + if opcode.is_value_push() { + (0i8, 0i8) + } else { + let stack_changes = opcode.stack_changes(); + (stack_changes.0 as i8, stack_changes.1 as i8) + } + } + StatementType::Label(_) => (0i8, 0i8), + StatementType::MacroInvocation(_) => { + todo!() + } + StatementType::Code(_) => { + todo!("should throw error") } - _ => (0i8, 0i8), }; // acc.1 is always non negative From 6c32e661981c0f770b4c07b27270918706aec480 Mon Sep 17 00:00:00 2001 From: Igor Line Date: Mon, 9 Oct 2023 20:49:36 +0200 Subject: [PATCH 5/5] feat: add check for macro invocations before returning contract parse --- huff_parser/src/lib.rs | 151 +++++++++++++++++++++++++---------------- huff_utils/src/evm.rs | 64 ++++++++--------- 2 files changed, 126 insertions(+), 89 deletions(-) diff --git a/huff_parser/src/lib.rs b/huff_parser/src/lib.rs index 387874b8..f0b7fc80 100644 --- a/huff_parser/src/lib.rs +++ b/huff_parser/src/lib.rs @@ -4,6 +4,8 @@ #![forbid(unsafe_code)] #![forbid(where_clauses_object_safety)] +use std::collections::HashMap; + use huff_utils::{ ast::*, error::*, @@ -136,6 +138,8 @@ impl Parser { } } + validate_macros(&contract)?; + Ok(contract) } @@ -521,63 +525,6 @@ impl Parser { let macro_statements: Vec = self.parse_body()?; - if outlined { - let (body_statements_take, body_statements_return) = - macro_statements.iter().fold((0i16, 0i16), |acc, st| { - let (statement_takes, statement_returns) = match st.ty { - StatementType::Literal(_) | - StatementType::Constant(_) | - StatementType::BuiltinFunctionCall(_) | - StatementType::ArgCall(_) | - StatementType::LabelCall(_) => (0i8, 1i8), - StatementType::Opcode(opcode) => { - if opcode.is_value_push() { - (0i8, 0i8) - } else { - let stack_changes = opcode.stack_changes(); - (stack_changes.0 as i8, stack_changes.1 as i8) - } - } - StatementType::Label(_) => (0i8, 0i8), - StatementType::MacroInvocation(_) => { - todo!() - } - StatementType::Code(_) => { - todo!("should throw error") - } - }; - - // acc.1 is always non negative - // acc.0 is always non positive - let (stack_takes, stack_returns) = if statement_takes as i16 > acc.1 { - (acc.0 + acc.1 - statement_takes as i16, statement_returns as i16) - } else { - (acc.0, acc.1 - statement_takes as i16 + statement_returns as i16) - }; - (stack_takes, stack_returns) - }); - if body_statements_take.abs() != macro_takes as i16 { - return Err(ParserError { - kind: ParserErrorKind::InvalidStackAnnotation(TokenKind::Takes), - hint: Some(format!( - "Fn {macro_name} specified to take {macro_takes} elements from the stack, but it takes {}", - body_statements_take.abs() - )), - spans: AstSpan(self.spans.clone()), - }); - } - if body_statements_return != macro_returns as i16 { - return Err(ParserError { - kind: ParserErrorKind::InvalidStackAnnotation(TokenKind::Returns), - hint: Some(format!( - "Fn {macro_name} specified to return {macro_returns} elements to the stack, but it returns {}", - body_statements_return - )), - spans: AstSpan(self.spans.clone()), - }); - } - } - Ok(MacroDefinition::new( macro_name, decorator, @@ -1336,3 +1283,93 @@ impl Parser { } } } + +/// Function used to evaluate macro statements. Returns number of elements taken from the stack and +/// returned to the stack +pub fn evaluate_macro( + macro_name: &str, + macros: &[MacroDefinition], + evaluated_macros: &mut HashMap, +) -> Result<(i16, i16), ParserError> { + if let Some(macro_takes_returns) = evaluated_macros.get(macro_name) { + return Ok(*macro_takes_returns) + } + + let contract_macro = macros.iter().find(|m| m.name.as_str() == macro_name).unwrap(); + let (body_statements_take, body_statements_return) = + contract_macro.statements.iter().fold((0i16, 0i16), |acc, st| { + let (statement_takes, statement_returns) = match &st.ty { + StatementType::Literal(_) | + StatementType::Constant(_) | + StatementType::BuiltinFunctionCall(_) | + StatementType::ArgCall(_) => (0i8, 1i8), + StatementType::LabelCall(_) => (0i8, 1i8), + StatementType::Opcode(opcode) => { + if opcode.is_value_push() { + (0i8, 0i8) + } else { + let stack_changes = opcode.stack_changes(); + (stack_changes.0 as i8, stack_changes.1 as i8) + } + } + StatementType::Label(_) => (0i8, 0i8), + StatementType::MacroInvocation(MacroInvocation { + macro_name, + args: _, + span: _, + }) => { + let (takes, returns) = + evaluate_macro(macro_name, macros, evaluated_macros).unwrap(); + (takes.abs() as i8, returns as i8) + } + StatementType::Code(_) => { + todo!("should throw error") + } + }; + + // acc.1 is always non negative + // acc.0 is always non positive + let (stack_takes, stack_returns) = if statement_takes as i16 > acc.1 { + (acc.0 + acc.1 - statement_takes as i16, statement_returns as i16) + } else { + (acc.0, acc.1 - statement_takes as i16 + statement_returns as i16) + }; + (stack_takes, stack_returns) + }); + + evaluated_macros + .insert(contract_macro.name.clone(), (body_statements_take, body_statements_return)); + Ok((body_statements_take, body_statements_return)) +} + +/// Function used to validate takes and returns of outlined macros in the contract +pub fn validate_macros(contract: &Contract) -> Result<(), ParserError> { + let mut evaluated_macros = HashMap::with_capacity(contract.macros.len()); + for contract_macro in contract.macros.iter().filter(|m| m.outlined) { + let (body_statements_take, body_statements_return) = + evaluate_macro(&contract_macro.name, &contract.macros, &mut evaluated_macros)?; + if body_statements_take.abs() != contract_macro.takes as i16 { + return Err(ParserError { + kind: ParserErrorKind::InvalidStackAnnotation(TokenKind::Takes), + hint: Some(format!( + "Fn {} specified to take {} elements from the stack, but it takes {}", + contract_macro.name, + contract_macro.takes, + body_statements_take.abs() + )), + spans: contract_macro.span.clone(), + }) + } + if body_statements_return != contract_macro.returns as i16 { + return Err(ParserError { + kind: ParserErrorKind::InvalidStackAnnotation(TokenKind::Returns), + hint: Some(format!( + "Fn {} specified to return {} elements to the stack, but it returns {}", + contract_macro.name, contract_macro.returns, body_statements_return + )), + spans: contract_macro.span.clone(), + }) + } + } + Ok(()) +} diff --git a/huff_utils/src/evm.rs b/huff_utils/src/evm.rs index f485602c..96bdc276 100644 --- a/huff_utils/src/evm.rs +++ b/huff_utils/src/evm.rs @@ -925,38 +925,38 @@ impl Opcode { Opcode::Push30 => (0, 1), Opcode::Push31 => (0, 1), Opcode::Push32 => (0, 1), - Opcode::Dup1 => (0, 1), - Opcode::Dup2 => (0, 1), - Opcode::Dup3 => (0, 1), - Opcode::Dup4 => (0, 1), - Opcode::Dup5 => (0, 1), - Opcode::Dup6 => (0, 1), - Opcode::Dup7 => (0, 1), - Opcode::Dup8 => (0, 1), - Opcode::Dup9 => (0, 1), - Opcode::Dup10 => (0, 1), - Opcode::Dup11 => (0, 1), - Opcode::Dup12 => (0, 1), - Opcode::Dup13 => (0, 1), - Opcode::Dup14 => (0, 1), - Opcode::Dup15 => (0, 1), - Opcode::Dup16 => (0, 1), - Opcode::Swap1 => (0, 0), - Opcode::Swap2 => (0, 0), - Opcode::Swap3 => (0, 0), - Opcode::Swap4 => (0, 0), - Opcode::Swap5 => (0, 0), - Opcode::Swap6 => (0, 0), - Opcode::Swap7 => (0, 0), - Opcode::Swap8 => (0, 0), - Opcode::Swap9 => (0, 0), - Opcode::Swap10 => (0, 0), - Opcode::Swap11 => (0, 0), - Opcode::Swap12 => (0, 0), - Opcode::Swap13 => (0, 0), - Opcode::Swap14 => (0, 0), - Opcode::Swap15 => (0, 0), - Opcode::Swap16 => (0, 0), + Opcode::Dup1 => (1, 2), + Opcode::Dup2 => (2, 3), + Opcode::Dup3 => (3, 4), + Opcode::Dup4 => (4, 5), + Opcode::Dup5 => (5, 6), + Opcode::Dup6 => (6, 7), + Opcode::Dup7 => (7, 8), + Opcode::Dup8 => (8, 9), + Opcode::Dup9 => (9, 10), + Opcode::Dup10 => (10, 11), + Opcode::Dup11 => (11, 12), + Opcode::Dup12 => (12, 13), + Opcode::Dup13 => (13, 14), + Opcode::Dup14 => (14, 15), + Opcode::Dup15 => (15, 16), + Opcode::Dup16 => (16, 17), + Opcode::Swap1 => (2, 2), + Opcode::Swap2 => (3, 3), + Opcode::Swap3 => (4, 4), + Opcode::Swap4 => (5, 5), + Opcode::Swap5 => (6, 6), + Opcode::Swap6 => (7, 7), + Opcode::Swap7 => (8, 8), + Opcode::Swap8 => (9, 9), + Opcode::Swap9 => (10, 10), + Opcode::Swap10 => (11, 11), + Opcode::Swap11 => (12, 12), + Opcode::Swap12 => (13, 13), + Opcode::Swap13 => (14, 14), + Opcode::Swap14 => (15, 15), + Opcode::Swap15 => (16, 16), + Opcode::Swap16 => (17, 17), Opcode::Log0 => (2, 0), Opcode::Log1 => (3, 0), Opcode::Log2 => (4, 0),