Skip to content
This repository has been archived by the owner on Sep 9, 2024. It is now read-only.

Commit

Permalink
feat: add check for macro invocations before returning contract parse
Browse files Browse the repository at this point in the history
  • Loading branch information
igorline committed Oct 9, 2023
1 parent 9dc6f61 commit ab98df2
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 89 deletions.
150 changes: 93 additions & 57 deletions huff_parser/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
#![forbid(unsafe_code)]
#![forbid(where_clauses_object_safety)]

use std::collections::HashMap;

use huff_utils::{
ast::*,
error::*,
Expand Down Expand Up @@ -136,6 +138,8 @@ impl Parser {
}
}

validate_macros(&contract)?;

Ok(contract)
}

Expand Down Expand Up @@ -521,63 +525,6 @@ impl Parser {

let macro_statements: Vec<Statement> = self.parse_body()?;

if outlined {
let (body_statements_take, body_statements_return) =
macro_statements.iter().fold((0i16, 0i16), |acc, st| {
let (statement_takes, statement_returns) = match st.ty {
StatementType::Literal(_) |
StatementType::Constant(_) |
StatementType::BuiltinFunctionCall(_) |
StatementType::ArgCall(_) |
StatementType::LabelCall(_) => (0i8, 1i8),
StatementType::Opcode(opcode) => {
if opcode.is_value_push() {
(0i8, 0i8)
} else {
let stack_changes = opcode.stack_changes();
(stack_changes.0 as i8, stack_changes.1 as i8)
}
}
StatementType::Label(_) => (0i8, 0i8),
StatementType::MacroInvocation(_) => {
todo!()
}
StatementType::Code(_) => {
todo!("should throw error")
}
};

// acc.1 is always non negative
// acc.0 is always non positive
let (stack_takes, stack_returns) = if statement_takes as i16 > acc.1 {
(acc.0 + acc.1 - statement_takes as i16, statement_returns as i16)
} else {
(acc.0, acc.1 - statement_takes as i16 + statement_returns as i16)
};
(stack_takes, stack_returns)
});
if body_statements_take.abs() != macro_takes as i16 {
return Err(ParserError {
kind: ParserErrorKind::InvalidStackAnnotation(TokenKind::Takes),
hint: Some(format!(
"Fn {macro_name} specified to take {macro_takes} elements from the stack, but it takes {}",
body_statements_take.abs()
)),
spans: AstSpan(self.spans.clone()),
});
}
if body_statements_return != macro_returns as i16 {
return Err(ParserError {
kind: ParserErrorKind::InvalidStackAnnotation(TokenKind::Returns),
hint: Some(format!(
"Fn {macro_name} specified to return {macro_returns} elements to the stack, but it returns {}",
body_statements_return
)),
spans: AstSpan(self.spans.clone()),
});
}
}

Ok(MacroDefinition::new(
macro_name,
decorator,
Expand Down Expand Up @@ -1336,3 +1283,92 @@ impl Parser {
}
}
}

/// Function used to evaluate macro statements. Returns number of elements taken from the stack and
/// returned to the stack
pub fn evaluate_macro(
_macro_name: &str,
_macros: &[MacroDefinition],
_evaluated_macros: &mut HashMap<String, (i16, i16)>,
) -> Result<(i16, i16), ParserError> {
if _evaluated_macros.contains_key(_macro_name) {
return Ok(*_evaluated_macros.get(_macro_name).unwrap())
}

let _macro = _macros.iter().find(|m| m.name.as_str() == _macro_name).unwrap();
let (body_statements_take, body_statements_return) =
_macro.statements.iter().fold((0i16, 0i16), |acc, st| {
let (statement_takes, statement_returns) = match &st.ty {
StatementType::Literal(_) |
StatementType::Constant(_) |
StatementType::BuiltinFunctionCall(_) |
StatementType::ArgCall(_) => (0i8, 1i8),
StatementType::LabelCall(_) => (0i8, 1i8),
StatementType::Opcode(opcode) => {
if opcode.is_value_push() {
(0i8, 0i8)
} else {
let stack_changes = opcode.stack_changes();
(stack_changes.0 as i8, stack_changes.1 as i8)
}
}
StatementType::Label(_) => (0i8, 0i8),
StatementType::MacroInvocation(MacroInvocation {
macro_name,
args: _,
span: _,
}) => {
let (takes, returns) =
evaluate_macro(macro_name, _macros, _evaluated_macros).unwrap();
(takes.abs() as i8, returns as i8)
}
StatementType::Code(_) => {
todo!("should throw error")
}
};

// acc.1 is always non negative
// acc.0 is always non positive
let (stack_takes, stack_returns) = if statement_takes as i16 > acc.1 {
(acc.0 + acc.1 - statement_takes as i16, statement_returns as i16)
} else {
(acc.0, acc.1 - statement_takes as i16 + statement_returns as i16)
};
(stack_takes, stack_returns)
});

_evaluated_macros.insert(_macro.name.clone(), (body_statements_take, body_statements_return));
Ok((body_statements_take, body_statements_return))
}

/// Function used to validate takes and returns of outlined macros in the contract
pub fn validate_macros(contract: &Contract) -> Result<(), ParserError> {
let mut evaluated_macros = HashMap::new();
for _macro in contract.macros.iter().filter(|m| m.outlined) {
let (body_statements_take, body_statements_return) =
evaluate_macro(&_macro.name, &contract.macros, &mut evaluated_macros)?;
if body_statements_take.abs() != _macro.takes as i16 {
return Err(ParserError {
kind: ParserErrorKind::InvalidStackAnnotation(TokenKind::Takes),
hint: Some(format!(
"Fn {} specified to take {} elements from the stack, but it takes {}",
_macro.name,
_macro.takes,
body_statements_take.abs()
)),
spans: _macro.span.clone(),
})
}
if body_statements_return != _macro.returns as i16 {
return Err(ParserError {
kind: ParserErrorKind::InvalidStackAnnotation(TokenKind::Returns),
hint: Some(format!(
"Fn {} specified to return {} elements to the stack, but it returns {}",
_macro.name, _macro.returns, body_statements_return
)),
spans: _macro.span.clone(),
})
}
}
Ok(())
}
64 changes: 32 additions & 32 deletions huff_utils/src/evm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -925,38 +925,38 @@ impl Opcode {
Opcode::Push30 => (0, 1),
Opcode::Push31 => (0, 1),
Opcode::Push32 => (0, 1),
Opcode::Dup1 => (0, 1),
Opcode::Dup2 => (0, 1),
Opcode::Dup3 => (0, 1),
Opcode::Dup4 => (0, 1),
Opcode::Dup5 => (0, 1),
Opcode::Dup6 => (0, 1),
Opcode::Dup7 => (0, 1),
Opcode::Dup8 => (0, 1),
Opcode::Dup9 => (0, 1),
Opcode::Dup10 => (0, 1),
Opcode::Dup11 => (0, 1),
Opcode::Dup12 => (0, 1),
Opcode::Dup13 => (0, 1),
Opcode::Dup14 => (0, 1),
Opcode::Dup15 => (0, 1),
Opcode::Dup16 => (0, 1),
Opcode::Swap1 => (0, 0),
Opcode::Swap2 => (0, 0),
Opcode::Swap3 => (0, 0),
Opcode::Swap4 => (0, 0),
Opcode::Swap5 => (0, 0),
Opcode::Swap6 => (0, 0),
Opcode::Swap7 => (0, 0),
Opcode::Swap8 => (0, 0),
Opcode::Swap9 => (0, 0),
Opcode::Swap10 => (0, 0),
Opcode::Swap11 => (0, 0),
Opcode::Swap12 => (0, 0),
Opcode::Swap13 => (0, 0),
Opcode::Swap14 => (0, 0),
Opcode::Swap15 => (0, 0),
Opcode::Swap16 => (0, 0),
Opcode::Dup1 => (1, 2),
Opcode::Dup2 => (2, 3),
Opcode::Dup3 => (3, 4),
Opcode::Dup4 => (4, 5),
Opcode::Dup5 => (5, 6),
Opcode::Dup6 => (6, 7),
Opcode::Dup7 => (7, 8),
Opcode::Dup8 => (8, 9),
Opcode::Dup9 => (9, 10),
Opcode::Dup10 => (10, 11),
Opcode::Dup11 => (11, 12),
Opcode::Dup12 => (12, 13),
Opcode::Dup13 => (13, 14),
Opcode::Dup14 => (14, 15),
Opcode::Dup15 => (15, 16),
Opcode::Dup16 => (16, 17),
Opcode::Swap1 => (2, 2),
Opcode::Swap2 => (3, 3),
Opcode::Swap3 => (4, 4),
Opcode::Swap4 => (5, 5),
Opcode::Swap5 => (6, 6),
Opcode::Swap6 => (7, 7),
Opcode::Swap7 => (8, 8),
Opcode::Swap8 => (9, 9),
Opcode::Swap9 => (10, 10),
Opcode::Swap10 => (11, 11),
Opcode::Swap11 => (12, 12),
Opcode::Swap12 => (13, 13),
Opcode::Swap13 => (14, 14),
Opcode::Swap14 => (15, 15),
Opcode::Swap15 => (16, 16),
Opcode::Swap16 => (17, 17),
Opcode::Log0 => (2, 0),
Opcode::Log1 => (3, 0),
Opcode::Log2 => (4, 0),
Expand Down

0 comments on commit ab98df2

Please sign in to comment.