Skip to content

Commit

Permalink
Merge pull request #358 from 0xPolygonMiden/greenhat/i303-add-call-op
Browse files Browse the repository at this point in the history
[4/x] Implement VM's `call` op in IR and codegen
  • Loading branch information
bitwalker authored Dec 16, 2024
2 parents 249271d + 2a28cd2 commit 7c99e32
Show file tree
Hide file tree
Showing 21 changed files with 320 additions and 149 deletions.
35 changes: 35 additions & 0 deletions codegen/masm/src/codegen/emit/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,11 @@ impl<'a> InstOpEmitter<'a> {
self.emitter.exec(import, span);
}

pub fn call(&mut self, callee: hir::FunctionIdent, span: SourceSpan) {
let import = self.dfg.get_import(&callee).unwrap();
self.emitter.call(import, span);
}

pub fn syscall(&mut self, callee: hir::FunctionIdent, span: SourceSpan) {
let import = self.dfg.get_import(&callee).unwrap();
self.emitter.syscall(import, span);
Expand Down Expand Up @@ -1890,6 +1895,36 @@ mod tests {
assert_eq!(emitter.stack()[0], return_ty);
}

#[test]
fn op_emitter_u32_call_test() {
use midenc_hir::ExternalFunction;

let mut function = setup();
let entry = function.body.id();
let mut stack = OperandStack::default();
let mut emitter = OpEmitter::new(&mut function, entry, &mut stack);

let return_ty = Type::Array(Box::new(Type::U32), 1);
let callee = ExternalFunction {
id: "test::add".parse().unwrap(),
signature: Signature::new(
[AbiParam::new(Type::U32), AbiParam::new(Type::I1)],
[AbiParam::new(return_ty.clone())],
),
};

let t = Immediate::I1(true);
let one = Immediate::U32(1);

emitter.literal(t, SourceSpan::default());
emitter.literal(one, SourceSpan::default());
assert_eq!(emitter.stack_len(), 2);

emitter.call(&callee, SourceSpan::default());
assert_eq!(emitter.stack_len(), 1);
assert_eq!(emitter.stack()[0], return_ty);
}

#[test]
fn op_emitter_u32_load_test() {
let mut function = setup();
Expand Down
107 changes: 107 additions & 0 deletions codegen/masm/src/codegen/emit/primop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,113 @@ impl<'a> OpEmitter<'a> {
self.emit(Op::Exec(callee), span);
}

/// Calls the given procedure via a `call` instruction.
///
/// A function called using this operation is invoked in the **new** memory context
pub fn call(&mut self, callee: &hir::ExternalFunction, span: SourceSpan) {
let import = callee;
let callee = import.id;
let signature = &import.signature;
for i in 0..signature.arity() {
let param = &signature.params[i];
let arg = self.stack.pop().expect("operand stack is empty");
let ty = arg.ty();
// Validate the purpose matches
match param.purpose {
ArgumentPurpose::StructReturn => {
panic!(
"sret parameters are not supported in call instructions since they don't \
make sense in a new memory context"
);
}
ArgumentPurpose::Default => (),
}
// Validate that the argument type is valid for the parameter ABI
match param.extension {
// Types must match exactly
ArgumentExtension::None => {
assert_eq!(
ty, param.ty,
"invalid call to {callee}: invalid argument type for parameter at index \
{i}"
);
}
// Caller can provide a smaller type which will be zero-extended to the expected
// type
//
// However, the argument must be an unsigned integer, and of smaller or equal size
// in order for the types to differ
ArgumentExtension::Zext if ty != param.ty => {
assert!(
param.ty.is_unsigned_integer(),
"invalid function signature: zero-extension is only valid for unsigned \
integer types"
);
assert!(
ty.is_unsigned_integer(),
"invalid call to {callee}: invalid argument type for parameter at index \
{i}, expected unsigned integer type, got {ty}"
);
let expected_size = param.ty.size_in_bits();
let provided_size = param.ty.size_in_bits();
assert!(
provided_size <= expected_size,
"invalid call to {callee}: invalid argument type for parameter at index \
{i}, expected integer width to be <= {expected_size} bits"
);
// Zero-extend this argument
self.stack.push(arg);
self.zext(&param.ty, span);
self.stack.drop();
}
// Caller can provide a smaller type which will be sign-extended to the expected
// type
//
// However, the argument must be an integer which can fit in the range of the
// expected type
ArgumentExtension::Sext if ty != param.ty => {
assert!(
param.ty.is_signed_integer(),
"invalid function signature: sign-extension is only valid for signed \
integer types"
);
assert!(
ty.is_integer(),
"invalid call to {callee}: invalid argument type for parameter at index \
{i}, expected integer type, got {ty}"
);
let expected_size = param.ty.size_in_bits();
let provided_size = param.ty.size_in_bits();
if ty.is_unsigned_integer() {
assert!(
provided_size < expected_size,
"invalid call to {callee}: invalid argument type for parameter at \
index {i}, expected unsigned integer width to be < {expected_size} \
bits"
);
} else {
assert!(
provided_size <= expected_size,
"invalid call to {callee}: invalid argument type for parameter at \
index {i}, expected integer width to be <= {expected_size} bits"
);
}
// Push the operand back on the stack for `sext`
self.stack.push(arg);
self.sext(&param.ty, span);
self.stack.drop();
}
ArgumentExtension::Zext | ArgumentExtension::Sext => (),
}
}

for result in signature.results.iter().rev() {
self.push(result.ty.clone());
}

self.emit(Op::Call(callee), span);
}

/// Execute the given procedure as a syscall.
///
/// A function called using this operation is invoked in the same memory context as the caller.
Expand Down
3 changes: 2 additions & 1 deletion codegen/masm/src/codegen/emitter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -818,7 +818,8 @@ impl<'b, 'f: 'b> BlockEmitter<'b, 'f> {
let mut emitter = self.inst_emitter(inst_info.inst);
match op.op {
hir::Opcode::Syscall => emitter.syscall(op.callee, span),
hir::Opcode::Call => emitter.exec(op.callee, span),
hir::Opcode::Exec => emitter.exec(op.callee, span),
hir::Opcode::Call => emitter.call(op.callee, span),
opcode => unimplemented!("unrecognized procedure call opcode: '{opcode}'"),
}
}
Expand Down
2 changes: 1 addition & 1 deletion frontend-wasm/src/code_translator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -702,7 +702,7 @@ fn translate_call(
func_state.pushn(&results);
} else {
// no transformation needed
let call = builder.ins().call(func_id, args, span);
let call = builder.ins().exec(func_id, args, span);
let results = builder.inst_results(call);
func_state.popn(num_wasm_args);
func_state.pushn(results);
Expand Down
2 changes: 1 addition & 1 deletion frontend-wasm/src/intrinsics/mem.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ pub(crate) fn convert_mem_intrinsics(
signature,
);
}
let call = builder.ins().call(func_id, &[], span);
let call = builder.ins().exec(func_id, &[], span);
let value = builder.data_flow_graph().first_result(call);
vec![value]
}
Expand Down
6 changes: 3 additions & 3 deletions frontend-wasm/src/miden_abi/transform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ pub fn no_transform(
span: SourceSpan,
_diagnostics: &DiagnosticsHandler,
) -> Vec<Value> {
let call = builder.ins().call(func_id, args, span);
let call = builder.ins().exec(func_id, args, span);
let results = builder.inst_results(call);
results.to_vec()
}
Expand All @@ -99,7 +99,7 @@ pub fn list_return(
span: SourceSpan,
_diagnostics: &DiagnosticsHandler,
) -> Vec<Value> {
let call = builder.ins().call(func_id, args, span);
let call = builder.ins().exec(func_id, args, span);
let results = builder.inst_results(call);
assert_eq!(results.len(), 2, "List return strategy expects 2 results: length and pointer");
// Return the first result (length) only
Expand All @@ -116,7 +116,7 @@ pub fn return_via_pointer(
) -> Vec<Value> {
// Omit the last argument (pointer)
let args_wo_pointer = &args[0..args.len() - 1];
let call = builder.ins().call(func_id, args_wo_pointer, span);
let call = builder.ins().exec(func_id, args_wo_pointer, span);
let results = builder.inst_results(call).to_vec();
let ptr_arg = *args.last().unwrap();
let ptr_arg_ty = builder.data_flow_graph().value_type(ptr_arg).clone();
Expand Down
4 changes: 2 additions & 2 deletions hir-analysis/src/spill.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1392,7 +1392,7 @@ mod tests {
builder.ins().inttoptr(v5, Type::Ptr(Box::new(Type::U128)), SourceSpan::UNKNOWN);
let v7 = builder.ins().load(v6, SourceSpan::UNKNOWN);
let v8 = builder.ins().u64(1, SourceSpan::UNKNOWN);
call = builder.ins().call(example, &[v6, v4, v7, v7, v8], SourceSpan::UNKNOWN);
call = builder.ins().exec(example, &[v6, v4, v7, v7, v8], SourceSpan::UNKNOWN);
let v10 = builder.ins().add_imm_unchecked(v1, Immediate::U32(72), SourceSpan::UNKNOWN);
store = builder.ins().store(v3, v7, SourceSpan::UNKNOWN);
let v11 =
Expand Down Expand Up @@ -1519,7 +1519,7 @@ mod tests {
// block1
builder.switch_to_block(block1);
let v9 = builder.ins().u64(1, SourceSpan::UNKNOWN);
call = builder.ins().call(example, &[v6, v4, v7, v7, v9], SourceSpan::UNKNOWN);
call = builder.ins().exec(example, &[v6, v4, v7, v7, v9], SourceSpan::UNKNOWN);
let v10 = builder.func.dfg.first_result(call);
builder.ins().br(block3, &[v10], SourceSpan::UNKNOWN);

Expand Down
2 changes: 1 addition & 1 deletion hir-analysis/src/validation/typecheck.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1233,7 +1233,7 @@ impl<'a> InstTypeChecker<'a> {
}
Opcode::IsOdd => InstPattern::Exact(vec![TypePattern::Int], vec![Type::I1.into()]),
Opcode::Min | Opcode::Max => InstPattern::BinaryMatching(TypePattern::Int),
Opcode::Call | Opcode::Syscall => match node.as_ref() {
Opcode::Exec | Opcode::Call | Opcode::Syscall => match node.as_ref() {
Instruction::Call(Call { ref callee, .. }) => {
if let Some(import) = dfg.get_import(callee) {
let args = import
Expand Down
8 changes: 4 additions & 4 deletions hir-transform/src/spill.rs
Original file line number Diff line number Diff line change
Expand Up @@ -737,7 +737,7 @@ mod tests {
builder.ins().inttoptr(v5, Type::Ptr(Box::new(Type::U128)), SourceSpan::UNKNOWN);
let v7 = builder.ins().load(v6, SourceSpan::UNKNOWN);
let v8 = builder.ins().u64(1, SourceSpan::UNKNOWN);
builder.ins().call(example, &[v6, v4, v7, v7, v8], SourceSpan::UNKNOWN);
builder.ins().exec(example, &[v6, v4, v7, v7, v8], SourceSpan::UNKNOWN);
let v10 = builder.ins().add_imm_unchecked(v1, Immediate::U32(72), SourceSpan::UNKNOWN);
builder.ins().store(v3, v7, SourceSpan::UNKNOWN);
let v11 =
Expand Down Expand Up @@ -773,7 +773,7 @@ mod tests {
(let (v8 u64) (const.u64 1))
(store.local local0 v2)
(store.local local1 v3)
(let (v9 u32) (call (#foo #example) v6 v4 v7 v7 v8))
(let (v9 u32) (exec (#foo #example) v6 v4 v7 v7 v8))
(let (v10 u32) (add.unchecked v1 72))
(let (v13 (ptr u128)) (load.local local1))
(store v13 v7)
Expand Down Expand Up @@ -844,7 +844,7 @@ mod tests {
// block1
builder.switch_to_block(block1);
let v9 = builder.ins().u64(1, SourceSpan::UNKNOWN);
let call = builder.ins().call(example, &[v6, v4, v7, v7, v9], SourceSpan::UNKNOWN);
let call = builder.ins().exec(example, &[v6, v4, v7, v7, v9], SourceSpan::UNKNOWN);
let v10 = builder.func.dfg.first_result(call);
builder.ins().br(block3, &[v10], SourceSpan::UNKNOWN);

Expand Down Expand Up @@ -896,7 +896,7 @@ mod tests {
(let (v9 u64) (const.u64 1))
(store.local local0 v2)
(store.local local1 v3)
(let (v10 u32) (call (#foo #example) v6 v4 v7 v7 v9))
(let (v10 u32) (exec (#foo #example) v6 v4 v7 v7 v9))
(br (block 4)))
(block 2
Expand Down
14 changes: 14 additions & 0 deletions hir/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1345,6 +1345,20 @@ pub trait InstBuilder<'f>: InstBuilderBase<'f> {
into_first_result!(self.Unary(Opcode::IsOdd, Type::I1, value, span))
}

fn exec(mut self, callee: FunctionIdent, args: &[Value], span: SourceSpan) -> Inst {
let mut vlist = ValueList::default();
{
let dfg = self.data_flow_graph_mut();
assert!(
dfg.get_import(&callee).is_some(),
"must import callee ({}) before calling it",
&callee
);
vlist.extend(args.iter().copied(), &mut dfg.value_lists);
}
self.Call(Opcode::Exec, callee, vlist, span).0
}

fn call(mut self, callee: FunctionIdent, args: &[Value], span: SourceSpan) -> Inst {
let mut vlist = ValueList::default();
{
Expand Down
13 changes: 8 additions & 5 deletions hir/src/instruction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,7 @@ pub enum Opcode {
IsOdd,
Min,
Max,
Exec,
Call,
Syscall,
Br,
Expand Down Expand Up @@ -485,7 +486,7 @@ impl Opcode {
}

pub fn is_call(&self) -> bool {
matches!(self, Self::Call | Self::Syscall)
matches!(self, Self::Exec | Self::Call | Self::Syscall)
}

pub fn is_commutative(&self) -> bool {
Expand Down Expand Up @@ -517,7 +518,7 @@ impl Opcode {
| Self::MemCpy
| Self::MemSize
| Self::Load
| Self::Call
| Self::Exec
| Self::Syscall
| Self::InlineAsm
| Self::Reload
Expand All @@ -531,7 +532,7 @@ impl Opcode {
| Self::MemSet
| Self::MemCpy
| Self::Store
| Self::Call
| Self::Exec
| Self::Syscall
| Self::InlineAsm
| Self::Spill
Expand All @@ -549,6 +550,7 @@ impl Opcode {
| Self::MemGrow
| Self::MemSet
| Self::MemCpy
| Self::Exec
| Self::Call
| Self::Syscall
| Self::Br
Expand Down Expand Up @@ -701,7 +703,7 @@ impl Opcode {
// memcpy requires source, destination, and arity
Self::MemSet | Self::MemCpy => 3,
// Calls are entirely variable
Self::Call | Self::Syscall => 0,
Self::Exec | Self::Call | Self::Syscall => 0,
// Unconditional branches have no fixed arguments
Self::Br => 0,
// Ifs have a single argument, the conditional
Expand Down Expand Up @@ -811,7 +813,7 @@ impl Opcode {
smallvec![ctrl_ty.pointee().expect("expected pointer type").clone()]
}
// Call results are handled separately
Self::Call | Self::Syscall | Self::InlineAsm => unreachable!(),
Self::Exec | Self::Call | Self::Syscall | Self::InlineAsm => unreachable!(),
}
}
}
Expand Down Expand Up @@ -852,6 +854,7 @@ impl fmt::Display for Opcode {
Self::Br => f.write_str("br"),
Self::CondBr => f.write_str("condbr"),
Self::Switch => f.write_str("switch"),
Self::Exec => f.write_str("exec"),
Self::Call => f.write_str("call"),
Self::Syscall => f.write_str("syscall"),
Self::Ret => f.write_str("ret"),
Expand Down
Loading

0 comments on commit 7c99e32

Please sign in to comment.