diff --git a/huff_codegen/src/irgen/arg_calls.rs b/huff_codegen/src/irgen/arg_calls.rs index 09e81621..2d3343f4 100644 --- a/huff_codegen/src/irgen/arg_calls.rs +++ b/huff_codegen/src/irgen/arg_calls.rs @@ -1,6 +1,8 @@ use huff_utils::prelude::*; use std::str::FromStr; +use crate::Codegen; + // Arguments can be literals, labels, opcodes, or constants // !! IF THERE IS AMBIGUOUS NOMENCLATURE // !! (E.G. BOTH OPCODE AND LABEL ARE THE SAME STRING) @@ -9,6 +11,7 @@ use std::str::FromStr; /// Arg Call Bubbling #[allow(clippy::too_many_arguments)] pub fn bubble_arg_call( + evm_version: &EVMVersion, arg_name: &str, bytes: &mut Vec<(usize, Bytes)>, macro_def: &MacroDefinition, @@ -18,6 +21,9 @@ pub fn bubble_arg_call( // mis: Parent macro invocations and their indices mis: &mut [(usize, MacroInvocation)], jump_table: &mut JumpTable, + circular_codesize_invocations: &mut CircularCodeSizeIndices, + label_indices: &mut LabelIndices, + table_instances: &mut Jumps, ) -> Result<(), CodegenError> { let starting_offset = *offset; @@ -70,6 +76,7 @@ pub fn bubble_arg_call( let ac_ = &ac.to_string(); return if last_mi.1.macro_name.eq(&macro_def.name) { bubble_arg_call( + evm_version, ac_, bytes, bubbled_macro_invocation, @@ -78,9 +85,13 @@ pub fn bubble_arg_call( offset, &mut mis[..mis_len.saturating_sub(1)], jump_table, + circular_codesize_invocations, + label_indices, + table_instances, ) } else { bubble_arg_call( + evm_version, ac_, bytes, bubbled_macro_invocation, @@ -89,14 +100,24 @@ pub fn bubble_arg_call( offset, mis, jump_table, + circular_codesize_invocations, + label_indices, + table_instances, ) } } MacroArg::Ident(iden) => { tracing::debug!(target: "codegen", "Found MacroArg::Ident IN \"{}\" Macro Invocation: \"{}\"!", macro_invoc.1.macro_name, iden); - // Check for a constant first - if let Some(constant) = contract + // The opcode check needs to happen before the constants lookup + // because otherwise the mutex can deadlock when bubbling up to + // resolve macros as 
arguments. + if let Ok(o) = Opcode::from_str(iden) { + tracing::debug!(target: "codegen", "Found Opcode: {}", o); + let b = Bytes(o.to_string()); + *offset += b.0.len() / 2; + bytes.push((starting_offset, b)); + } else if let Some(constant) = contract .constants .lock() .map_err(|_| { @@ -130,11 +151,62 @@ pub fn bubble_arg_call( *offset += push_bytes.len() / 2; tracing::info!(target: "codegen", "OFFSET: {}, PUSH BYTES: {:?}", offset, push_bytes); bytes.push((starting_offset, Bytes(push_bytes))); - } else if let Ok(o) = Opcode::from_str(iden) { - tracing::debug!(target: "codegen", "Found Opcode: {}", o); - let b = Bytes(o.to_string()); - *offset += b.0.len() / 2; - bytes.push((starting_offset, b)); + } else if let Some(ir_macro) = contract.find_macro_by_name(iden) { + tracing::debug!(target: "codegen", "ARG CALL IS MACRO: {}", iden); + tracing::debug!(target: "codegen", "CURRENT MACRO DEF: {}", macro_def.name); + + let mut new_scopes = scope.to_vec(); + new_scopes.push(ir_macro); + let mut new_mis = mis.to_vec(); + new_mis.push(( + *offset, + MacroInvocation { + macro_name: iden.to_string(), + args: vec![], + span: AstSpan(vec![]), + }, + )); + + let mut res: BytecodeRes = match Codegen::macro_to_bytecode( + evm_version, + ir_macro, + contract, + &mut new_scopes, + *offset, + &mut new_mis, + false, + Some(circular_codesize_invocations), + ) { + Ok(r) => r, + Err(e) => { + tracing::error!( + target: "codegen", + "FAILED TO RECURSE INTO MACRO \"{}\"", + ir_macro.name + ); + return Err(e) + } + }; + + for j in res.unmatched_jumps.iter_mut() { + let new_index = j.bytecode_index; + j.bytecode_index = 0; + let mut new_jumps = if let Some(jumps) = jump_table.get(&new_index) + { + jumps.clone() + } else { + vec![] + }; + new_jumps.push(j.clone()); + jump_table.insert(new_index, new_jumps); + } + table_instances.extend(res.table_instances); + label_indices.extend(res.label_indices); + + // Increase offset by byte length of recursed macro + *offset += 
res.bytes.iter().map(|(_, b)| b.0.len()).sum::<usize>() / 2; + // Add the macro's bytecode to the final result + res.bytes.iter().for_each(|(a, b)| bytes.push((*a, b.clone()))); } else { tracing::debug!(target: "codegen", "Found Label Call: {}", iden); diff --git a/huff_codegen/src/lib.rs b/huff_codegen/src/lib.rs index e281235f..84657b14 100644 --- a/huff_codegen/src/lib.rs +++ b/huff_codegen/src/lib.rs @@ -340,6 +340,7 @@ impl Codegen { // Bubble up arg call by looking through the previous scopes. // Once the arg value is found, add it to `bytes` bubble_arg_call( + evm_version, arg_name, &mut bytes, macro_def, @@ -348,6 +349,9 @@ impl Codegen { &mut offset, mis, &mut jump_table, + circular_codesize_invocations, + &mut label_indices, + &mut table_instances, )? } } diff --git a/huff_core/tests/macro_invoc_args.rs b/huff_core/tests/macro_invoc_args.rs index 1ddc850f..575f0b66 100644 --- a/huff_core/tests/macro_invoc_args.rs +++ b/huff_core/tests/macro_invoc_args.rs @@ -308,3 +308,93 @@ fn test_bubbled_arg_with_different_name() { // Check the bytecode assert_eq!(main_bytecode, expected_bytecode); } + +#[test] +fn test_macro_macro_arg() { + let source = r#" + #define constant TWO = 0x02 + + #define macro MUL_BY_10() = takes(1) returns (1) { + 0x0a mul + } + + #define macro EXEC_WITH_VALUE(value, macro) = takes(0) returns(1) { + <value> <macro> + } + + #define macro MAIN() = takes(0) returns(0) { + EXEC_WITH_VALUE(TWO, MUL_BY_10) + } + "#; + + // Lex + Parse + let flattened_source = FullFileSource { source, file: None, spans: vec![] }; + let lexer = Lexer::new(flattened_source.source); + let tokens = lexer.into_iter().map(|x| x.unwrap()).collect::<Vec<Token>>(); + let mut parser = Parser::new(tokens, None); + let mut contract = parser.parse().unwrap(); + contract.derive_storage_pointers(); + + let evm_version = EVMVersion::default(); + + // Create main and constructor bytecode + let main_bytecode = Codegen::generate_main_bytecode(&evm_version, &contract, None).unwrap(); + + // Full expected bytecode 
output (generated from huffc) (placed here as a reference) + let expected_bytecode = "6002600a02"; + + // Check the bytecode + assert_eq!(main_bytecode.to_lowercase(), expected_bytecode.to_lowercase()); +} + +#[test] +fn test_bubbled_macro_macro_arg() { + let source = r#" + #define constant TWO = 0x02 + + #define macro MUL_BY_10() = takes(1) returns (1) { + 0x0a mul + } + + #define macro DO_OP(op) = takes(0) returns(0) { + <op> + } + + #define macro DIV_BY_5() = takes(1) returns (1) { + 0x05 swap1 DO_OP(div) + } + + #define macro EXEC_WITH_VALUE(value, macro) = takes(0) returns(1) { + <value> <macro> + } + + #define macro SUM_RESULTS(value, macro1, macro2) = takes(0) returns (1) { + EXEC_WITH_VALUE(<value>, <macro1>) + EXEC_WITH_VALUE(<value>, <macro2>) + add + } + + #define macro MAIN() = takes(0) returns(0) { + SUM_RESULTS(TWO, MUL_BY_10, DIV_BY_5) + } + "#; + + // Lex + Parse + let flattened_source = FullFileSource { source, file: None, spans: vec![] }; + let lexer = Lexer::new(flattened_source.source); + let tokens = lexer.into_iter().map(|x| x.unwrap()).collect::<Vec<Token>>(); + let mut parser = Parser::new(tokens, None); + let mut contract = parser.parse().unwrap(); + contract.derive_storage_pointers(); + + let evm_version = EVMVersion::default(); + + // Create main and constructor bytecode + let main_bytecode = Codegen::generate_main_bytecode(&evm_version, &contract, None).unwrap(); + + // Full expected bytecode output (generated from huffc) (placed here as a reference) + let expected_bytecode = "6002600a0260026005900401"; + + // Check the bytecode + assert_eq!(main_bytecode.to_lowercase(), expected_bytecode.to_lowercase()); +}