Skip to content

Commit

Permalink
feature: implement cross-context Call op, rename IR Exec back to Call
Browse files Browse the repository at this point in the history
  • Loading branch information
greenhat committed Nov 28, 2024
1 parent bf7b2ab commit 2a28cd2
Show file tree
Hide file tree
Showing 9 changed files with 193 additions and 22 deletions.
35 changes: 35 additions & 0 deletions codegen/masm/src/codegen/emit/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,11 @@ impl<'a> InstOpEmitter<'a> {
self.emitter.exec(import, span);
}

pub fn call(&mut self, callee: hir::FunctionIdent, span: SourceSpan) {
let import = self.dfg.get_import(&callee).unwrap();
self.emitter.call(import, span);
}

pub fn syscall(&mut self, callee: hir::FunctionIdent, span: SourceSpan) {
let import = self.dfg.get_import(&callee).unwrap();
self.emitter.syscall(import, span);
Expand Down Expand Up @@ -1890,6 +1895,36 @@ mod tests {
assert_eq!(emitter.stack()[0], return_ty);
}

#[test]
fn op_emitter_u32_call_test() {
use midenc_hir::ExternalFunction;

let mut function = setup();
let entry = function.body.id();
let mut stack = OperandStack::default();
let mut emitter = OpEmitter::new(&mut function, entry, &mut stack);

let return_ty = Type::Array(Box::new(Type::U32), 1);
let callee = ExternalFunction {
id: "test::add".parse().unwrap(),
signature: Signature::new(
[AbiParam::new(Type::U32), AbiParam::new(Type::I1)],
[AbiParam::new(return_ty.clone())],
),
};

let t = Immediate::I1(true);
let one = Immediate::U32(1);

emitter.literal(t, SourceSpan::default());
emitter.literal(one, SourceSpan::default());
assert_eq!(emitter.stack_len(), 2);

emitter.call(&callee, SourceSpan::default());
assert_eq!(emitter.stack_len(), 1);
assert_eq!(emitter.stack()[0], return_ty);
}

#[test]
fn op_emitter_u32_load_test() {
let mut function = setup();
Expand Down
107 changes: 107 additions & 0 deletions codegen/masm/src/codegen/emit/primop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,113 @@ impl<'a> OpEmitter<'a> {
self.emit(Op::Exec(callee), span);
}

/// Calls the given procedure via a `call` instruction.
///
/// A function called using this operation is invoked in the **new** memory context
pub fn call(&mut self, callee: &hir::ExternalFunction, span: SourceSpan) {
let import = callee;
let callee = import.id;
let signature = &import.signature;
for i in 0..signature.arity() {
let param = &signature.params[i];
let arg = self.stack.pop().expect("operand stack is empty");
let ty = arg.ty();
// Validate the purpose matches
match param.purpose {
ArgumentPurpose::StructReturn => {
panic!(
"sret parameters are not supported in call instructions since they don't \
make sense in a new memory context"
);
}
ArgumentPurpose::Default => (),
}
// Validate that the argument type is valid for the parameter ABI
match param.extension {
// Types must match exactly
ArgumentExtension::None => {
assert_eq!(
ty, param.ty,
"invalid call to {callee}: invalid argument type for parameter at index \
{i}"
);
}
// Caller can provide a smaller type which will be zero-extended to the expected
// type
//
// However, the argument must be an unsigned integer, and of smaller or equal size
// in order for the types to differ
ArgumentExtension::Zext if ty != param.ty => {
assert!(
param.ty.is_unsigned_integer(),
"invalid function signature: zero-extension is only valid for unsigned \
integer types"
);
assert!(
ty.is_unsigned_integer(),
"invalid call to {callee}: invalid argument type for parameter at index \
{i}, expected unsigned integer type, got {ty}"
);
let expected_size = param.ty.size_in_bits();
let provided_size = param.ty.size_in_bits();
assert!(
provided_size <= expected_size,
"invalid call to {callee}: invalid argument type for parameter at index \
{i}, expected integer width to be <= {expected_size} bits"
);
// Zero-extend this argument
self.stack.push(arg);
self.zext(&param.ty, span);
self.stack.drop();
}
// Caller can provide a smaller type which will be sign-extended to the expected
// type
//
// However, the argument must be an integer which can fit in the range of the
// expected type
ArgumentExtension::Sext if ty != param.ty => {
assert!(
param.ty.is_signed_integer(),
"invalid function signature: sign-extension is only valid for signed \
integer types"
);
assert!(
ty.is_integer(),
"invalid call to {callee}: invalid argument type for parameter at index \
{i}, expected integer type, got {ty}"
);
let expected_size = param.ty.size_in_bits();
let provided_size = param.ty.size_in_bits();
if ty.is_unsigned_integer() {
assert!(
provided_size < expected_size,
"invalid call to {callee}: invalid argument type for parameter at \
index {i}, expected unsigned integer width to be < {expected_size} \
bits"
);
} else {
assert!(
provided_size <= expected_size,
"invalid call to {callee}: invalid argument type for parameter at \
index {i}, expected integer width to be <= {expected_size} bits"
);
}
// Push the operand back on the stack for `sext`
self.stack.push(arg);
self.sext(&param.ty, span);
self.stack.drop();
}
ArgumentExtension::Zext | ArgumentExtension::Sext => (),
}
}

for result in signature.results.iter().rev() {
self.push(result.ty.clone());
}

self.emit(Op::Call(callee), span);
}

/// Execute the given procedure as a syscall.
///
/// A function called using this operation is invoked in the same memory context as the caller.
Expand Down
5 changes: 3 additions & 2 deletions codegen/masm/src/codegen/emitter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ impl<'b, 'f: 'b> BlockEmitter<'b, 'f> {
Instruction::Load(op) => self.emit_load_op(inst_info, op),
Instruction::PrimOp(op) => self.emit_primop(inst_info, op),
Instruction::PrimOpImm(op) => self.emit_primop_imm(inst_info, op),
Instruction::Exec(op) => self.emit_call_op(inst_info, op),
Instruction::Call(op) => self.emit_call_op(inst_info, op),
Instruction::InlineAsm(op) => self.emit_inline_asm(inst_info, op),
Instruction::Switch(_) => {
panic!("expected switch instructions to have been rewritten before stackification")
Expand Down Expand Up @@ -811,14 +811,15 @@ impl<'b, 'f: 'b> BlockEmitter<'b, 'f> {
}
}

fn emit_call_op(&mut self, inst_info: &InstInfo, op: &hir::Exec) {
fn emit_call_op(&mut self, inst_info: &InstInfo, op: &hir::Call) {
assert_ne!(op.callee, self.function.f.id, "unexpected recursive call");

let span = self.function.f.dfg.inst_span(inst_info.inst);
let mut emitter = self.inst_emitter(inst_info.inst);
match op.op {
hir::Opcode::Syscall => emitter.syscall(op.callee, span),
hir::Opcode::Exec => emitter.exec(op.callee, span),
hir::Opcode::Call => emitter.call(op.callee, span),
opcode => unimplemented!("unrecognized procedure call opcode: '{opcode}'"),
}
}
Expand Down
6 changes: 3 additions & 3 deletions hir-analysis/src/validation/typecheck.rs
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ impl<'a> Rule<BlockData> for TypeCheck<'a> {
| Instruction::PrimOp(_)
| Instruction::Test(_)
| Instruction::InlineAsm(_)
| Instruction::Exec(_) => {
| Instruction::Call(_) => {
let args = node.arguments(&self.dfg.value_lists);
typechecker.check(args, results)?;
}
Expand Down Expand Up @@ -1233,8 +1233,8 @@ impl<'a> InstTypeChecker<'a> {
}
Opcode::IsOdd => InstPattern::Exact(vec![TypePattern::Int], vec![Type::I1.into()]),
Opcode::Min | Opcode::Max => InstPattern::BinaryMatching(TypePattern::Int),
Opcode::Exec | Opcode::Syscall => match node.as_ref() {
Instruction::Exec(Exec { ref callee, .. }) => {
Opcode::Exec | Opcode::Call | Opcode::Syscall => match node.as_ref() {
Instruction::Call(Call { ref callee, .. }) => {
if let Some(import) = dfg.get_import(callee) {
let args = import
.signature
Expand Down
22 changes: 18 additions & 4 deletions hir/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1356,7 +1356,21 @@ pub trait InstBuilder<'f>: InstBuilderBase<'f> {
);
vlist.extend(args.iter().copied(), &mut dfg.value_lists);
}
self.Exec(Opcode::Exec, callee, vlist, span).0
self.Call(Opcode::Exec, callee, vlist, span).0
}

fn call(mut self, callee: FunctionIdent, args: &[Value], span: SourceSpan) -> Inst {
let mut vlist = ValueList::default();
{
let dfg = self.data_flow_graph_mut();
assert!(
dfg.get_import(&callee).is_some(),
"must import callee ({}) before calling it",
&callee
);
vlist.extend(args.iter().copied(), &mut dfg.value_lists);
}
self.Call(Opcode::Call, callee, vlist, span).0
}

fn syscall(mut self, callee: FunctionIdent, args: &[Value], span: SourceSpan) -> Inst {
Expand All @@ -1370,7 +1384,7 @@ pub trait InstBuilder<'f>: InstBuilderBase<'f> {
);
vlist.extend(args.iter().copied(), &mut dfg.value_lists);
}
self.Exec(Opcode::Syscall, callee, vlist, span).0
self.Call(Opcode::Syscall, callee, vlist, span).0
}

fn select(mut self, cond: Value, a: Value, b: Value, span: SourceSpan) -> Value {
Expand Down Expand Up @@ -1524,14 +1538,14 @@ pub trait InstBuilder<'f>: InstBuilderBase<'f> {
}

#[allow(non_snake_case)]
fn Exec(
fn Call(
self,
op: Opcode,
callee: FunctionIdent,
args: ValueList,
span: SourceSpan,
) -> (Inst, &'f mut DataFlowGraph) {
let data = Instruction::Exec(Exec { op, callee, args });
let data = Instruction::Call(Call { op, callee, args });
self.build(data, Type::Unit, span)
}

Expand Down
27 changes: 15 additions & 12 deletions hir/src/instruction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ pub enum Instruction {
BinaryOpImm(BinaryOpImm),
UnaryOp(UnaryOp),
UnaryOpImm(UnaryOpImm),
Exec(Exec),
Call(Call),
Br(Br),
CondBr(CondBr),
Switch(Switch),
Expand All @@ -123,7 +123,7 @@ impl Instruction {
Self::BinaryOpImm(op) => Self::BinaryOpImm(op.clone()),
Self::UnaryOp(op) => Self::UnaryOp(op.clone()),
Self::UnaryOpImm(op) => Self::UnaryOpImm(op.clone()),
Self::Exec(call) => Self::Exec(Exec {
Self::Call(call) => Self::Call(Call {
args: call.args.deep_clone(value_lists),
..call.clone()
}),
Expand Down Expand Up @@ -178,7 +178,7 @@ impl Instruction {
| Self::BinaryOpImm(BinaryOpImm { ref op, .. })
| Self::UnaryOp(UnaryOp { ref op, .. })
| Self::UnaryOpImm(UnaryOpImm { ref op, .. })
| Self::Exec(Exec { ref op, .. })
| Self::Call(Call { ref op, .. })
| Self::Br(Br { ref op, .. })
| Self::CondBr(CondBr { ref op, .. })
| Self::Switch(Switch { ref op, .. })
Expand Down Expand Up @@ -234,7 +234,7 @@ impl Instruction {
Self::BinaryOp(BinaryOp { ref args, .. }) => args.as_slice(),
Self::BinaryOpImm(BinaryOpImm { ref arg, .. }) => core::slice::from_ref(arg),
Self::UnaryOp(UnaryOp { ref arg, .. }) => core::slice::from_ref(arg),
Self::Exec(Exec { ref args, .. }) => args.as_slice(pool),
Self::Call(Call { ref args, .. }) => args.as_slice(pool),
Self::CondBr(CondBr { ref cond, .. }) => core::slice::from_ref(cond),
Self::Switch(Switch { ref arg, .. }) => core::slice::from_ref(arg),
Self::Ret(Ret { ref args, .. }) => args.as_slice(pool),
Expand All @@ -253,7 +253,7 @@ impl Instruction {
Self::BinaryOp(BinaryOp { ref mut args, .. }) => args.as_mut_slice(),
Self::BinaryOpImm(BinaryOpImm { ref mut arg, .. }) => core::slice::from_mut(arg),
Self::UnaryOp(UnaryOp { ref mut arg, .. }) => core::slice::from_mut(arg),
Self::Exec(Exec { ref mut args, .. }) => args.as_mut_slice(pool),
Self::Call(Call { ref mut args, .. }) => args.as_mut_slice(pool),
Self::CondBr(CondBr { ref mut cond, .. }) => core::slice::from_mut(cond),
Self::Switch(Switch { ref mut arg, .. }) => core::slice::from_mut(arg),
Self::Ret(Ret { ref mut args, .. }) => args.as_mut_slice(pool),
Expand Down Expand Up @@ -298,7 +298,7 @@ impl Instruction {

pub fn analyze_call<'a>(&'a self, pool: &'a ValueListPool) -> CallInfo<'a> {
match self {
Self::Exec(ref c) => CallInfo::Direct(c.callee, c.args.as_slice(pool)),
Self::Call(ref c) => CallInfo::Direct(c.callee, c.args.as_slice(pool)),
_ => CallInfo::NotACall,
}
}
Expand Down Expand Up @@ -458,6 +458,7 @@ pub enum Opcode {
Min,
Max,
Exec,
Call,
Syscall,
Br,
CondBr,
Expand Down Expand Up @@ -485,7 +486,7 @@ impl Opcode {
}

pub fn is_call(&self) -> bool {
matches!(self, Self::Exec | Self::Syscall)
matches!(self, Self::Exec | Self::Call | Self::Syscall)
}

pub fn is_commutative(&self) -> bool {
Expand Down Expand Up @@ -550,6 +551,7 @@ impl Opcode {
| Self::MemSet
| Self::MemCpy
| Self::Exec
| Self::Call
| Self::Syscall
| Self::Br
| Self::CondBr
Expand Down Expand Up @@ -701,7 +703,7 @@ impl Opcode {
// memcpy requires source, destination, and arity
Self::MemSet | Self::MemCpy => 3,
// Calls are entirely variable
Self::Exec | Self::Syscall => 0,
Self::Exec | Self::Call | Self::Syscall => 0,
// Unconditional branches have no fixed arguments
Self::Br => 0,
// Ifs have a single argument, the conditional
Expand Down Expand Up @@ -811,7 +813,7 @@ impl Opcode {
smallvec![ctrl_ty.pointee().expect("expected pointer type").clone()]
}
// Call results are handled separately
Self::Exec | Self::Syscall | Self::InlineAsm => unreachable!(),
Self::Exec | Self::Call | Self::Syscall | Self::InlineAsm => unreachable!(),
}
}
}
Expand Down Expand Up @@ -853,6 +855,7 @@ impl fmt::Display for Opcode {
Self::CondBr => f.write_str("condbr"),
Self::Switch => f.write_str("switch"),
Self::Exec => f.write_str("exec"),
Self::Call => f.write_str("call"),
Self::Syscall => f.write_str("syscall"),
Self::Ret => f.write_str("ret"),
Self::Test => f.write_str("test"),
Expand Down Expand Up @@ -1002,7 +1005,7 @@ pub struct UnaryOpImm {
}

#[derive(Debug, Clone)]
pub struct Exec {
pub struct Call {
pub op: Opcode,
pub callee: FunctionIdent,
/// NOTE: Call arguments are always in stack order, i.e. the top operand on
Expand Down Expand Up @@ -1191,7 +1194,7 @@ impl<'a> PartialEq for InstructionWithValueListPool<'a> {
(Instruction::UnaryOpImm(l), Instruction::UnaryOpImm(r)) => {
l.imm == r.imm && l.overflow == r.overflow
}
(Instruction::Exec(l), Instruction::Exec(r)) => {
(Instruction::Call(l), Instruction::Call(r)) => {
l.callee == r.callee
&& l.args.as_slice(self.value_lists) == r.args.as_slice(self.value_lists)
}
Expand Down Expand Up @@ -1354,7 +1357,7 @@ impl<'a> formatter::PrettyPrint for InstPrettyPrinter<'a> {
(vec![], args)
}
Instruction::RetImm(RetImm { arg, .. }) => (vec![], vec![display(*arg)]),
Instruction::Exec(Exec { callee, args, .. }) => {
Instruction::Call(Call { callee, args, .. }) => {
let mut operands = if callee.module == self.current_function.module {
vec![display(callee.function)]
} else {
Expand Down
2 changes: 1 addition & 1 deletion hir/src/parser/ast/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -702,7 +702,7 @@ fn try_insert_inst(
operands.iter().map(|arg| arg.into_inner()),
&mut function.dfg.value_lists,
);
Some(Instruction::Exec(crate::Exec { op, callee, args }))
Some(Instruction::Call(crate::Call { op, callee, args }))
} else {
None
}
Expand Down
Loading

0 comments on commit 2a28cd2

Please sign in to comment.