Skip to content
This repository has been archived by the owner on Sep 9, 2024. It is now read-only.

Enable use of macros as macro args #265

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 77 additions & 10 deletions huff_codegen/src/irgen/arg_calls.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use huff_utils::prelude::*;
use std::str::FromStr;

use crate::Codegen;

// Arguments can be literals, labels, opcodes, or constants
// !! IF THERE IS AMBIGUOUS NOMENCLATURE
// !! (E.G. BOTH OPCODE AND LABEL ARE THE SAME STRING)
Expand All @@ -18,6 +20,9 @@ pub fn bubble_arg_call(
// mis: Parent macro invocations and their indices
mis: &mut Vec<(usize, MacroInvocation)>,
jump_table: &mut JumpTable,
circular_codesize_invocations: &mut CircularCodeSizeIndices,
label_indices: &mut LabelIndices,
table_instances: &mut Jumps,
) -> Result<(), CodegenError> {
let starting_offset = *offset;

Expand Down Expand Up @@ -67,33 +72,46 @@ pub fn bubble_arg_call(
};
return if last_mi.1.macro_name.eq(&macro_def.name) {
bubble_arg_call(
arg_name,
ac,
bytes,
&bubbled_macro_invocation,
contract,
&mut new_scope,
offset,
&mut Vec::from(&mis[..mis.len().saturating_sub(1)]),
jump_table,
circular_codesize_invocations,
label_indices,
table_instances,
)
} else {
bubble_arg_call(
arg_name,
&ac.to_string(),
bytes,
&bubbled_macro_invocation,
contract,
&mut new_scope,
offset,
mis,
jump_table,
circular_codesize_invocations,
label_indices,
table_instances,
)
}
}
MacroArg::Ident(iden) => {
tracing::debug!(target: "codegen", "Found MacroArg::Ident IN \"{}\" Macro Invocation: \"{}\"!", macro_invoc.1.macro_name, iden);

// Check for a constant first
if let Some(constant) = contract
// The opcode check needs to happens before the constants lookup
// because otherwise the mutex can deadlock when bubbling up to
// resolve macros as arguments.
if let Ok(o) = Opcode::from_str(iden) {
tracing::debug!(target: "codegen", "Found Opcode: {}", o);
let b = Bytes(o.to_string());
*offset += b.0.len() / 2;
bytes.push((starting_offset, b));
} else if let Some(constant) = contract
.constants
.lock()
.map_err(|_| {
Expand Down Expand Up @@ -121,17 +139,66 @@ pub fn bubble_arg_call(
kind: CodegenErrorKind::StoragePointersNotDerived,
span: AstSpan(vec![]),
token: None,
})
});
}
};
*offset += push_bytes.len() / 2;
tracing::info!(target: "codegen", "OFFSET: {}, PUSH BYTES: {:?}", offset, push_bytes);
bytes.push((starting_offset, Bytes(push_bytes)));
} else if let Ok(o) = Opcode::from_str(iden) {
tracing::debug!(target: "codegen", "Found Opcode: {}", o);
let b = Bytes(o.to_string());
*offset += b.0.len() / 2;
bytes.push((starting_offset, b));
} else if let Some(ir_macro) = contract.find_macro_by_name(iden) {
let new_scope = ir_macro.clone();
scope.push(new_scope);
tracing::debug!(target: "codegen", "ARG CALL IS MACRO: {}", iden);
tracing::debug!(target: "codegen", "CURRENT MACRO DEF: {}", macro_def.name);

mis.push((
*offset,
MacroInvocation {
args: vec![],
span: AstSpan(vec![]),
macro_name: iden.to_string(),
},
));

let mut res: BytecodeRes = match Codegen::macro_to_bytecode(
ir_macro.clone(),
contract,
scope,
*offset,
mis,
false,
Some(circular_codesize_invocations),
) {
Ok(r) => r,
Err(e) => {
tracing::error!(
target: "codegen",
"FAILED TO RECURSE INTO MACRO \"{}\"",
ir_macro.name
);
return Err(e);
}
};

for j in res.unmatched_jumps.iter_mut() {
let new_index = j.bytecode_index;
j.bytecode_index = 0;
let mut new_jumps = if let Some(jumps) = jump_table.get(&new_index)
{
jumps.clone()
} else {
vec![]
};
new_jumps.push(j.clone());
jump_table.insert(new_index, new_jumps);
}
table_instances.extend(res.table_instances);
label_indices.extend(res.label_indices);

// Increase offset by byte length of recursed macro
*offset += res.bytes.iter().map(|(_, b)| b.0.len()).sum::<usize>() / 2;
// Add the macro's bytecode to the final result
res.bytes.iter().for_each(|(a, b)| bytes.push((*a, b.clone())));
} else {
tracing::debug!(target: "codegen", "Found Label Call: {}", iden);

Expand Down
3 changes: 3 additions & 0 deletions huff_codegen/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,9 @@ impl Codegen {
&mut offset,
mis,
&mut jump_table,
circular_codesize_invocations,
&mut label_indices,
&mut table_instances,
)?
}
}
Expand Down
120 changes: 120 additions & 0 deletions huff_core/tests/macro_invoc_args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -260,3 +260,123 @@ fn test_bubbled_constant_macro_arg() {
// Check the bytecode
assert_eq!(bytecode.to_lowercase(), expected_bytecode.to_lowercase());
}


#[test]
fn test_bubbled_arg_with_different_name() {
let source = r#"
#define macro MACRO_A(arg_a) = takes(0) returns(0) {
<arg_a>
}
#define macro MACRO_B(arg_b) =takes(0) returns(0) {
MACRO_A(<arg_b>)
}
#define macro MAIN() = takes(0) returns(0){
MACRO_B(0x01)
}
"#;

// Lex + Parse
let flattened_source = FullFileSource { source, file: None, spans: vec![] };
let lexer = Lexer::new(flattened_source);
let tokens = lexer.into_iter().map(|x| x.unwrap()).collect::<Vec<Token>>();
let mut parser = Parser::new(tokens, None);
let mut contract = parser.parse().unwrap();
contract.derive_storage_pointers();

// Create main and constructor bytecode
let main_bytecode = Codegen::generate_main_bytecode(&contract, None).unwrap();

// Full expected bytecode output (generated from huffc) (placed here as a reference)
let expected_bytecode = "6001";

// Check the bytecode
assert_eq!(main_bytecode, expected_bytecode);

}

#[test]
fn test_macro_macro_arg() {
let source = r#"
#define constant TWO = 0x02

#define macro MUL_BY_10() = takes(1) returns (1) {
0x0a mul
}

#define macro EXEC_WITH_VALUE(value, macro) = takes(0) returns(1) {
<value> <macro>
}

#define macro MAIN() = takes(0) returns(0) {
EXEC_WITH_VALUE(TWO, MUL_BY_10)
}
"#;

// Lex + Parse
let flattened_source = FullFileSource { source, file: None, spans: vec![] };
let lexer = Lexer::new(flattened_source);
let tokens = lexer.into_iter().map(|x| x.unwrap()).collect::<Vec<Token>>();
let mut parser = Parser::new(tokens, None);
let mut contract = parser.parse().unwrap();
contract.derive_storage_pointers();

// Create main and constructor bytecode
let main_bytecode = Codegen::generate_main_bytecode(&contract, None).unwrap();

// Full expected bytecode output (generated from huffc) (placed here as a reference)
let expected_bytecode = "6002600a02";

// Check the bytecode
assert_eq!(main_bytecode.to_lowercase(), expected_bytecode.to_lowercase());
}

#[test]
fn test_bubbled_macro_macro_arg() {
let source = r#"
#define constant TWO = 0x02

#define macro MUL_BY_10() = takes(1) returns (1) {
0x0a mul
}

#define macro DO_OP(op) = takes(0) returns(0) {
<op>
}

#define macro DIV_BY_5() = takes(1) returns (1) {
0x05 swap1 DO_OP(div)
}

#define macro EXEC_WITH_VALUE(value, macro) = takes(0) returns(1) {
<value> <macro>
}

#define macro SUM_RESULTS(value, macro1, macro2) = takes(0) returns (1) {
EXEC_WITH_VALUE(<value>, <macro1>)
EXEC_WITH_VALUE(<value>, <macro2>)
add
}

#define macro MAIN() = takes(0) returns(0) {
SUM_RESULTS(TWO, MUL_BY_10, DIV_BY_5)
}
"#;

// Lex + Parse
let flattened_source = FullFileSource { source, file: None, spans: vec![] };
let lexer = Lexer::new(flattened_source);
let tokens = lexer.into_iter().map(|x| x.unwrap()).collect::<Vec<Token>>();
let mut parser = Parser::new(tokens, None);
let mut contract = parser.parse().unwrap();
contract.derive_storage_pointers();

// Create main and constructor bytecode
let main_bytecode = Codegen::generate_main_bytecode(&contract, None).unwrap();

// Full expected bytecode output (generated from huffc) (placed here as a reference)
let expected_bytecode = "6002600a0260026005900401";

// Check the bytecode
assert_eq!(main_bytecode.to_lowercase(), expected_bytecode.to_lowercase());
}