diff --git a/huff_parser/src/lib.rs b/huff_parser/src/lib.rs index 387874b8..f0b7fc80 100644 --- a/huff_parser/src/lib.rs +++ b/huff_parser/src/lib.rs @@ -4,6 +4,8 @@ #![forbid(unsafe_code)] #![forbid(where_clauses_object_safety)] +use std::collections::HashMap; + use huff_utils::{ ast::*, error::*, @@ -136,6 +138,8 @@ impl Parser { } } + validate_macros(&contract)?; + Ok(contract) } @@ -521,63 +525,6 @@ impl Parser { let macro_statements: Vec = self.parse_body()?; - if outlined { - let (body_statements_take, body_statements_return) = - macro_statements.iter().fold((0i16, 0i16), |acc, st| { - let (statement_takes, statement_returns) = match st.ty { - StatementType::Literal(_) | - StatementType::Constant(_) | - StatementType::BuiltinFunctionCall(_) | - StatementType::ArgCall(_) | - StatementType::LabelCall(_) => (0i8, 1i8), - StatementType::Opcode(opcode) => { - if opcode.is_value_push() { - (0i8, 0i8) - } else { - let stack_changes = opcode.stack_changes(); - (stack_changes.0 as i8, stack_changes.1 as i8) - } - } - StatementType::Label(_) => (0i8, 0i8), - StatementType::MacroInvocation(_) => { - todo!() - } - StatementType::Code(_) => { - todo!("should throw error") - } - }; - - // acc.1 is always non negative - // acc.0 is always non positive - let (stack_takes, stack_returns) = if statement_takes as i16 > acc.1 { - (acc.0 + acc.1 - statement_takes as i16, statement_returns as i16) - } else { - (acc.0, acc.1 - statement_takes as i16 + statement_returns as i16) - }; - (stack_takes, stack_returns) - }); - if body_statements_take.abs() != macro_takes as i16 { - return Err(ParserError { - kind: ParserErrorKind::InvalidStackAnnotation(TokenKind::Takes), - hint: Some(format!( - "Fn {macro_name} specified to take {macro_takes} elements from the stack, but it takes {}", - body_statements_take.abs() - )), - spans: AstSpan(self.spans.clone()), - }); - } - if body_statements_return != macro_returns as i16 { - return Err(ParserError { - kind: ParserErrorKind::InvalidStackAnnotation(TokenKind::Returns), - hint: Some(format!( - "Fn {macro_name} specified to return {macro_returns} elements to the stack, but it returns {}", - body_statements_return - )), - spans: AstSpan(self.spans.clone()), - }); - } - } - Ok(MacroDefinition::new( macro_name, decorator, @@ -1336,3 +1283,93 @@ impl Parser { } } } + +/// Function used to evaluate macro statements. Returns number of elements taken from the stack and +/// returned to the stack +pub fn evaluate_macro( + macro_name: &str, + macros: &[MacroDefinition], + evaluated_macros: &mut HashMap, +) -> Result<(i16, i16), ParserError> { + if let Some(macro_takes_returns) = evaluated_macros.get(macro_name) { + return Ok(*macro_takes_returns) + } + + let contract_macro = macros.iter().find(|m| m.name.as_str() == macro_name).unwrap(); + let (body_statements_take, body_statements_return) = + contract_macro.statements.iter().fold((0i16, 0i16), |acc, st| { + let (statement_takes, statement_returns) = match &st.ty { + StatementType::Literal(_) | + StatementType::Constant(_) | + StatementType::BuiltinFunctionCall(_) | + StatementType::ArgCall(_) => (0i8, 1i8), + StatementType::LabelCall(_) => (0i8, 1i8), + StatementType::Opcode(opcode) => { + if opcode.is_value_push() { + (0i8, 0i8) + } else { + let stack_changes = opcode.stack_changes(); + (stack_changes.0 as i8, stack_changes.1 as i8) + } + } + StatementType::Label(_) => (0i8, 0i8), + StatementType::MacroInvocation(MacroInvocation { + macro_name, + args: _, + span: _, + }) => { + let (takes, returns) = + evaluate_macro(macro_name, macros, evaluated_macros).unwrap(); + (takes.abs() as i8, returns as i8) + } + StatementType::Code(_) => { + todo!("should throw error") + } + }; + + // acc.1 is always non negative + // acc.0 is always non positive + let (stack_takes, stack_returns) = if statement_takes as i16 > acc.1 { + (acc.0 + acc.1 - statement_takes as i16, statement_returns as i16) + } else { + (acc.0, acc.1 - statement_takes as i16 + statement_returns as i16) + }; + (stack_takes, stack_returns) + }); + + evaluated_macros + .insert(contract_macro.name.clone(), (body_statements_take, body_statements_return)); + Ok((body_statements_take, body_statements_return)) +} + +/// Function used to validate takes and returns of outlined macros in the contract +pub fn validate_macros(contract: &Contract) -> Result<(), ParserError> { + let mut evaluated_macros = HashMap::with_capacity(contract.macros.len()); + for contract_macro in contract.macros.iter().filter(|m| m.outlined) { + let (body_statements_take, body_statements_return) = + evaluate_macro(&contract_macro.name, &contract.macros, &mut evaluated_macros)?; + if body_statements_take.abs() != contract_macro.takes as i16 { + return Err(ParserError { + kind: ParserErrorKind::InvalidStackAnnotation(TokenKind::Takes), + hint: Some(format!( + "Fn {} specified to take {} elements from the stack, but it takes {}", + contract_macro.name, + contract_macro.takes, + body_statements_take.abs() + )), + spans: contract_macro.span.clone(), + }) + } + if body_statements_return != contract_macro.returns as i16 { + return Err(ParserError { + kind: ParserErrorKind::InvalidStackAnnotation(TokenKind::Returns), + hint: Some(format!( + "Fn {} specified to return {} elements to the stack, but it returns {}", + contract_macro.name, contract_macro.returns, body_statements_return + )), + spans: contract_macro.span.clone(), + }) + } + } + Ok(()) +} diff --git a/huff_utils/src/evm.rs b/huff_utils/src/evm.rs index f485602c..96bdc276 100644 --- a/huff_utils/src/evm.rs +++ b/huff_utils/src/evm.rs @@ -925,38 +925,38 @@ impl Opcode { Opcode::Push30 => (0, 1), Opcode::Push31 => (0, 1), Opcode::Push32 => (0, 1), - Opcode::Dup1 => (0, 1), - Opcode::Dup2 => (0, 1), - Opcode::Dup3 => (0, 1), - Opcode::Dup4 => (0, 1), - Opcode::Dup5 => (0, 1), - Opcode::Dup6 => (0, 1), - Opcode::Dup7 => (0, 1), - Opcode::Dup8 => (0, 1), - Opcode::Dup9 => (0, 1), - Opcode::Dup10 => (0, 1), - Opcode::Dup11 => (0, 1), - Opcode::Dup12 => (0, 1), - Opcode::Dup13 => (0, 1), - Opcode::Dup14 => (0, 1), - Opcode::Dup15 => (0, 1), - Opcode::Dup16 => (0, 1), - Opcode::Swap1 => (0, 0), - Opcode::Swap2 => (0, 0), - Opcode::Swap3 => (0, 0), - Opcode::Swap4 => (0, 0), - Opcode::Swap5 => (0, 0), - Opcode::Swap6 => (0, 0), - Opcode::Swap7 => (0, 0), - Opcode::Swap8 => (0, 0), - Opcode::Swap9 => (0, 0), - Opcode::Swap10 => (0, 0), - Opcode::Swap11 => (0, 0), - Opcode::Swap12 => (0, 0), - Opcode::Swap13 => (0, 0), - Opcode::Swap14 => (0, 0), - Opcode::Swap15 => (0, 0), - Opcode::Swap16 => (0, 0), + Opcode::Dup1 => (1, 2), + Opcode::Dup2 => (2, 3), + Opcode::Dup3 => (3, 4), + Opcode::Dup4 => (4, 5), + Opcode::Dup5 => (5, 6), + Opcode::Dup6 => (6, 7), + Opcode::Dup7 => (7, 8), + Opcode::Dup8 => (8, 9), + Opcode::Dup9 => (9, 10), + Opcode::Dup10 => (10, 11), + Opcode::Dup11 => (11, 12), + Opcode::Dup12 => (12, 13), + Opcode::Dup13 => (13, 14), + Opcode::Dup14 => (14, 15), + Opcode::Dup15 => (15, 16), + Opcode::Dup16 => (16, 17), + Opcode::Swap1 => (2, 2), + Opcode::Swap2 => (3, 3), + Opcode::Swap3 => (4, 4), + Opcode::Swap4 => (5, 5), + Opcode::Swap5 => (6, 6), + Opcode::Swap6 => (7, 7), + Opcode::Swap7 => (8, 8), + Opcode::Swap8 => (9, 9), + Opcode::Swap9 => (10, 10), + Opcode::Swap10 => (11, 11), + Opcode::Swap11 => (12, 12), + Opcode::Swap12 => (13, 13), + Opcode::Swap13 => (14, 14), + Opcode::Swap14 => (15, 15), + Opcode::Swap15 => (16, 16), + Opcode::Swap16 => (17, 17), Opcode::Log0 => (2, 0), Opcode::Log1 => (3, 0), Opcode::Log2 => (4, 0),