diff --git a/huff_codegen/src/irgen/arg_calls.rs b/huff_codegen/src/irgen/arg_calls.rs index 58ccbe51..e4ac4da3 100644 --- a/huff_codegen/src/irgen/arg_calls.rs +++ b/huff_codegen/src/irgen/arg_calls.rs @@ -1,6 +1,8 @@ use huff_utils::prelude::*; use std::str::FromStr; +use crate::Codegen; + // Arguments can be literals, labels, opcodes, or constants // !! IF THERE IS AMBIGUOUS NOMENCLATURE // !! (E.G. BOTH OPCODE AND LABEL ARE THE SAME STRING) @@ -18,6 +20,9 @@ pub fn bubble_arg_call( // mis: Parent macro invocations and their indices mis: &mut Vec<(usize, MacroInvocation)>, jump_table: &mut JumpTable, + circular_codesize_invocations: &mut CircularCodeSizeIndices, + label_indices: &mut LabelIndices, + table_instances: &mut Jumps, ) -> Result<(), CodegenError> { let starting_offset = *offset; @@ -67,7 +72,7 @@ pub fn bubble_arg_call( }; return if last_mi.1.macro_name.eq(¯o_def.name) { bubble_arg_call( - arg_name, + ac, bytes, &bubbled_macro_invocation, contract, @@ -75,10 +80,13 @@ pub fn bubble_arg_call( offset, &mut Vec::from(&mis[..mis.len().saturating_sub(1)]), jump_table, + circular_codesize_invocations, + label_indices, + table_instances, ) } else { bubble_arg_call( - arg_name, + &ac.to_string(), bytes, &bubbled_macro_invocation, contract, @@ -86,14 +94,24 @@ pub fn bubble_arg_call( offset, mis, jump_table, + circular_codesize_invocations, + label_indices, + table_instances, ) } } MacroArg::Ident(iden) => { tracing::debug!(target: "codegen", "Found MacroArg::Ident IN \"{}\" Macro Invocation: \"{}\"!", macro_invoc.1.macro_name, iden); - // Check for a constant first - if let Some(constant) = contract + // The opcode check needs to happens before the constants lookup + // because otherwise the mutex can deadlock when bubbling up to + // resolve macros as arguments. + if let Ok(o) = Opcode::from_str(iden) { + tracing::debug!(target: "codegen", "Found Opcode: {}", o); + let b = Bytes(o.to_string()); + *offset += b.0.len() / 2; + bytes.push((starting_offset, b)); + } else if let Some(constant) = contract .constants .lock() .map_err(|_| { @@ -121,17 +139,66 @@ pub fn bubble_arg_call( kind: CodegenErrorKind::StoragePointersNotDerived, span: AstSpan(vec![]), token: None, - }) + }); } }; *offset += push_bytes.len() / 2; tracing::info!(target: "codegen", "OFFSET: {}, PUSH BYTES: {:?}", offset, push_bytes); bytes.push((starting_offset, Bytes(push_bytes))); - } else if let Ok(o) = Opcode::from_str(iden) { - tracing::debug!(target: "codegen", "Found Opcode: {}", o); - let b = Bytes(o.to_string()); - *offset += b.0.len() / 2; - bytes.push((starting_offset, b)); + } else if let Some(ir_macro) = contract.find_macro_by_name(iden) { + let new_scope = ir_macro.clone(); + scope.push(new_scope); + tracing::debug!(target: "codegen", "ARG CALL IS MACRO: {}", iden); + tracing::debug!(target: "codegen", "CURRENT MACRO DEF: {}", macro_def.name); + + mis.push(( + *offset, + MacroInvocation { + args: vec![], + span: AstSpan(vec![]), + macro_name: iden.to_string(), + }, + )); + + let mut res: BytecodeRes = match Codegen::macro_to_bytecode( + ir_macro.clone(), + contract, + scope, + *offset, + mis, + false, + Some(circular_codesize_invocations), + ) { + Ok(r) => r, + Err(e) => { + tracing::error!( + target: "codegen", + "FAILED TO RECURSE INTO MACRO \"{}\"", + ir_macro.name + ); + return Err(e); + } + }; + + for j in res.unmatched_jumps.iter_mut() { + let new_index = j.bytecode_index; + j.bytecode_index = 0; + let mut new_jumps = if let Some(jumps) = jump_table.get(&new_index) + { + jumps.clone() + } else { + vec![] + }; + new_jumps.push(j.clone()); + jump_table.insert(new_index, new_jumps); + } + table_instances.extend(res.table_instances); + label_indices.extend(res.label_indices); + + // Increase offset by byte length of recursed macro + *offset += res.bytes.iter().map(|(_, b)| b.0.len()).sum::() / 2; + // Add the macro's bytecode to the final result + res.bytes.iter().for_each(|(a, b)| bytes.push((*a, b.clone()))); } else { tracing::debug!(target: "codegen", "Found Label Call: {}", iden); diff --git a/huff_codegen/src/lib.rs b/huff_codegen/src/lib.rs index 3c77f813..bd980a61 100644 --- a/huff_codegen/src/lib.rs +++ b/huff_codegen/src/lib.rs @@ -341,6 +341,9 @@ impl Codegen { &mut offset, mis, &mut jump_table, + circular_codesize_invocations, + &mut label_indices, + &mut table_instances, )? } } diff --git a/huff_core/tests/macro_invoc_args.rs b/huff_core/tests/macro_invoc_args.rs index 0b498432..dbdb1562 100644 --- a/huff_core/tests/macro_invoc_args.rs +++ b/huff_core/tests/macro_invoc_args.rs @@ -260,3 +260,123 @@ fn test_bubbled_constant_macro_arg() { // Check the bytecode assert_eq!(bytecode.to_lowercase(), expected_bytecode.to_lowercase()); } + + +#[test] +fn test_bubbled_arg_with_different_name() { + let source = r#" + #define macro MACRO_A(arg_a) = takes(0) returns(0) { + + } + #define macro MACRO_B(arg_b) =takes(0) returns(0) { + MACRO_A() + } + #define macro MAIN() = takes(0) returns(0){ + MACRO_B(0x01) + } + "#; + + // Lex + Parse + let flattened_source = FullFileSource { source, file: None, spans: vec![] }; + let lexer = Lexer::new(flattened_source); + let tokens = lexer.into_iter().map(|x| x.unwrap()).collect::>(); + let mut parser = Parser::new(tokens, None); + let mut contract = parser.parse().unwrap(); + contract.derive_storage_pointers(); + + // Create main and constructor bytecode + let main_bytecode = Codegen::generate_main_bytecode(&contract, None).unwrap(); + + // Full expected bytecode output (generated from huffc) (placed here as a reference) + let expected_bytecode = "6001"; + + // Check the bytecode + assert_eq!(main_bytecode, expected_bytecode); + +} + +#[test] +fn test_macro_macro_arg() { + let source = r#" + #define constant TWO = 0x02 + + #define macro MUL_BY_10() = takes(1) returns (1) { + 0x0a mul + } + + #define macro EXEC_WITH_VALUE(value, macro) = takes(0) returns(1) { + + } + + #define macro MAIN() = takes(0) returns(0) { + EXEC_WITH_VALUE(TWO, MUL_BY_10) + } + "#; + + // Lex + Parse + let flattened_source = FullFileSource { source, file: None, spans: vec![] }; + let lexer = Lexer::new(flattened_source); + let tokens = lexer.into_iter().map(|x| x.unwrap()).collect::>(); + let mut parser = Parser::new(tokens, None); + let mut contract = parser.parse().unwrap(); + contract.derive_storage_pointers(); + + // Create main and constructor bytecode + let main_bytecode = Codegen::generate_main_bytecode(&contract, None).unwrap(); + + // Full expected bytecode output (generated from huffc) (placed here as a reference) + let expected_bytecode = "6002600a02"; + + // Check the bytecode + assert_eq!(main_bytecode.to_lowercase(), expected_bytecode.to_lowercase()); +} + +#[test] +fn test_bubbled_macro_macro_arg() { + let source = r#" + #define constant TWO = 0x02 + + #define macro MUL_BY_10() = takes(1) returns (1) { + 0x0a mul + } + + #define macro DO_OP(op) = takes(0) returns(0) { + + } + + #define macro DIV_BY_5() = takes(1) returns (1) { + 0x05 swap1 DO_OP(div) + } + + #define macro EXEC_WITH_VALUE(value, macro) = takes(0) returns(1) { + + } + + #define macro SUM_RESULTS(value, macro1, macro2) = takes(0) returns (1) { + EXEC_WITH_VALUE(, ) + EXEC_WITH_VALUE(, ) + add + } + + #define macro MAIN() = takes(0) returns(0) { + SUM_RESULTS(TWO, MUL_BY_10, DIV_BY_5) + } + "#; + + // Lex + Parse + let flattened_source = FullFileSource { source, file: None, spans: vec![] }; + let lexer = Lexer::new(flattened_source); + let tokens = lexer.into_iter().map(|x| x.unwrap()).collect::>(); + let mut parser = Parser::new(tokens, None); + let mut contract = parser.parse().unwrap(); + contract.derive_storage_pointers(); + + // Create main and constructor bytecode + let main_bytecode = Codegen::generate_main_bytecode(&contract, None).unwrap(); + + // Full expected bytecode output (generated from huffc) (placed here as a reference) + let expected_bytecode = "6002600a0260026005900401"; + + // Check the bytecode + assert_eq!(main_bytecode.to_lowercase(), expected_bytecode.to_lowercase()); +}