Skip to content

Commit

Permalink
Merge pull request #11 from mark-koch/feat/call
Browse files Browse the repository at this point in the history
feat: Emission for Call nodes
  • Loading branch information
doug-q authored May 30, 2024
2 parents 8e30350 + dfa0b03 commit 7609d88
Show file tree
Hide file tree
Showing 4 changed files with 168 additions and 6 deletions.
50 changes: 45 additions & 5 deletions src/emit/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ use anyhow::{anyhow, Result};
use hugr::{
hugr::views::SiblingGraph,
ops::{
Case, Conditional, Const, Input, LoadConstant, MakeTuple, NamedOp, OpTag, OpTrait, OpType,
Output, Tag, UnpackTuple, Value,
Call, Case, Conditional, Const, Input, LoadConstant, MakeTuple, NamedOp, OpTag, OpTrait,
OpType, Output, Tag, UnpackTuple, Value,
},
types::{SumType, Type, TypeEnum},
HugrView, NodeIndex,
Expand Down Expand Up @@ -158,7 +158,7 @@ where
let o = self.take_output()?;
o.finish(self.builder(), args.inputs)
}
_ => emit_optype(self.context, args)
_ => emit_optype(self.context, args),
}
}
}
Expand Down Expand Up @@ -246,7 +246,10 @@ fn get_exactly_one_sum_type(ts: impl IntoIterator<Item = Type>) -> Result<SumTyp
Ok(sum_type)
}

fn emit_value<'c, H: HugrView>(context: &mut EmitFuncContext<'c, H>, v: &Value) -> Result<BasicValueEnum<'c>> {
fn emit_value<'c, H: HugrView>(
context: &mut EmitFuncContext<'c, H>,
v: &Value,
) -> Result<BasicValueEnum<'c>> {
match v {
Value::Extension { e } => {
let exts = context.extensions();
Expand Down Expand Up @@ -330,6 +333,43 @@ fn emit_load_constant<'c, H: HugrView>(
args.outputs.finish(context.builder(), [r])
}

fn emit_call<'c, H: HugrView>(
context: &mut EmitFuncContext<'c, H>,
args: EmitOpArgs<'c, Call, H>,
) -> Result<()> {
if !args.node.called_function_type().params().is_empty() {
todo!("Call of generic function");
}
let (func_node, _) = args
.node
.single_linked_output(args.node.called_function_port())
.unwrap();
let func = match func_node.get() {
OpType::FuncDecl(_) => context.get_func_decl(func_node.try_into_ot().unwrap()),
OpType::FuncDefn(_) => context.get_func_defn(func_node.try_into_ot().unwrap()),
_ => Err(anyhow!("emit_call: Not a Decl or Defn")),
};
let inputs: args.inputs.into_iter().map_into().collect_vec();
let builder = context.builder();
let call = builder
.build_call(func?, inputs.as_slice(), "")?
.try_as_basic_value();
let rets = match args.outputs.len() as u32 {
0 => {
call.expect_right("void");
vec![]
}
1 => vec![call.expect_left("non-void")],
n => {
let return_struct = call.expect_left("non-void").into_struct_value();
(0..n)
.map(|i| builder.build_extract_value(return_struct, i, ""))
.collect::<Result<Vec<_>, _>>()?
}
};
args.outputs.finish(builder, rets)
}

fn emit_optype<'c, H: HugrView>(
context: &mut EmitFuncContext<'c, H>,
args: EmitOpArgs<'c, OpType, H>,
Expand All @@ -348,10 +388,10 @@ fn emit_optype<'c, H: HugrView>(
}
OpType::Const(_) => Ok(()),
OpType::LoadConstant(ref lc) => emit_load_constant(context, args.into_ot(lc)),
OpType::Call(ref cl) => emit_call(context, args.into_ot(cl)),
OpType::Conditional(ref co) => emit_conditional(context, args.into_ot(co)),

// OpType::FuncDefn(fd) => self.emit(ot.into_ot(fd), context, inputs, outputs),
_ => todo!("Unimplemented OpTypeEmitter: {}", args.node().name()),
}
}

37 changes: 37 additions & 0 deletions src/emit/snapshots/[email protected]
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
---
source: src/emit/test.rs
expression: module.to_string()
---
; ModuleID = 'test_context'
source_filename = "test_context"

define void @_hl.main_void.1() {
alloca_block:
br label %entry_block

entry_block: ; preds = %alloca_block
call void @_hl.main_void.1()
ret void
}

define { {}, {} } @_hl.main_unary.5({ {}, {} } %0) {
alloca_block:
br label %entry_block

entry_block: ; preds = %alloca_block
%1 = call { {}, {} } @_hl.main_unary.5({ {}, {} } %0)
ret { {}, {} } %1
}

define { { {}, {} }, { {}, {} } } @_hl.main_binary.9({ {}, {} } %0, { {}, {} } %1) {
alloca_block:
br label %entry_block

entry_block: ; preds = %alloca_block
%2 = call { { {}, {} }, { {}, {} } } @_hl.main_binary.9({ {}, {} } %0, { {}, {} } %1)
%3 = extractvalue { { {}, {} }, { {}, {} } } %2, 0
%4 = extractvalue { { {}, {} }, { {}, {} } } %2, 1
%mrv = insertvalue { { {}, {} }, { {}, {} } } undef, { {}, {} } %3, 0
%mrv7 = insertvalue { { {}, {} }, { {}, {} } } %mrv, { {}, {} } %4, 1
ret { { {}, {} }, { {}, {} } } %mrv7
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
---
source: src/emit/test.rs
expression: module.to_string()
---
; ModuleID = 'test_context'
source_filename = "test_context"

define void @_hl.main_void.1() {
alloca_block:
br label %entry_block

entry_block: ; preds = %alloca_block
call void @_hl.main_void.1()
ret void
}

define { {}, {} } @_hl.main_unary.5({ {}, {} } %0) {
alloca_block:
%"0" = alloca { {}, {} }, align 8
%"6_0" = alloca { {}, {} }, align 8
%"8_0" = alloca { {}, {} }, align 8
br label %entry_block

entry_block: ; preds = %alloca_block
store { {}, {} } %0, { {}, {} }* %"6_0", align 1
%"6_01" = load { {}, {} }, { {}, {} }* %"6_0", align 1
%1 = call { {}, {} } @_hl.main_unary.5({ {}, {} } %"6_01")
store { {}, {} } %1, { {}, {} }* %"8_0", align 1
%"8_02" = load { {}, {} }, { {}, {} }* %"8_0", align 1
store { {}, {} } %"8_02", { {}, {} }* %"0", align 1
%"03" = load { {}, {} }, { {}, {} }* %"0", align 1
ret { {}, {} } %"03"
}

define { { {}, {} }, { {}, {} } } @_hl.main_binary.9({ {}, {} } %0, { {}, {} } %1) {
alloca_block:
%"0" = alloca { {}, {} }, align 8
%"1" = alloca { {}, {} }, align 8
%"10_0" = alloca { {}, {} }, align 8
%"10_1" = alloca { {}, {} }, align 8
%"12_0" = alloca { {}, {} }, align 8
%"12_1" = alloca { {}, {} }, align 8
br label %entry_block

entry_block: ; preds = %alloca_block
store { {}, {} } %0, { {}, {} }* %"10_0", align 1
store { {}, {} } %1, { {}, {} }* %"10_1", align 1
%"10_01" = load { {}, {} }, { {}, {} }* %"10_0", align 1
%"10_12" = load { {}, {} }, { {}, {} }* %"10_1", align 1
%2 = call { { {}, {} }, { {}, {} } } @_hl.main_binary.9({ {}, {} } %"10_01", { {}, {} } %"10_12")
%3 = extractvalue { { {}, {} }, { {}, {} } } %2, 0
%4 = extractvalue { { {}, {} }, { {}, {} } } %2, 1
store { {}, {} } %3, { {}, {} }* %"12_0", align 1
store { {}, {} } %4, { {}, {} }* %"12_1", align 1
%"12_03" = load { {}, {} }, { {}, {} }* %"12_0", align 1
%"12_14" = load { {}, {} }, { {}, {} }* %"12_1", align 1
store { {}, {} } %"12_03", { {}, {} }* %"0", align 1
store { {}, {} } %"12_14", { {}, {} }* %"1", align 1
%"05" = load { {}, {} }, { {}, {} }* %"0", align 1
%"16" = load { {}, {} }, { {}, {} }* %"1", align 1
%mrv = insertvalue { { {}, {} }, { {}, {} } } undef, { {}, {} } %"05", 0
%mrv7 = insertvalue { { {}, {} }, { {}, {} } } %mrv, { {}, {} } %"16", 1
ret { { {}, {} }, { {}, {} } } %mrv7
}
23 changes: 22 additions & 1 deletion src/emit/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ use hugr::ops::{Module, Tag, UnpackTuple, Value};
use hugr::std_extensions::arithmetic::int_ops::{self, INT_OPS_REGISTRY};
use hugr::std_extensions::arithmetic::int_types::ConstInt;
use hugr::types::{Type, TypeRow};
use hugr::Hugr;
use hugr::{builder::DataflowSubContainer, types::FunctionType};
use hugr::{type_row, Hugr};
use inkwell::passes::PassManager;
use insta::assert_snapshot;
use itertools::Itertools;
Expand Down Expand Up @@ -227,6 +227,27 @@ fn emit_hugr_load_constant(#[with(-1, add_int_extensions)] llvm_ctx: TestContext
check_emission!(hugr, llvm_ctx);
}

#[rstest]
fn emit_hugr_call(llvm_ctx: TestContext) {
fn build_recursive(mod_b: &mut ModuleBuilder<Hugr>, name: &str, io: TypeRow) {
let f_id = mod_b
.declare(name, FunctionType::new_endo(io).into())
.unwrap();
let mut func_b = mod_b.define_declaration(&f_id).unwrap();
let call = func_b
.call(&f_id, &[], func_b.input_wires(), &EMPTY_REG)
.unwrap();
func_b.finish_with_outputs(call.outputs()).unwrap();
}

let mut mod_b = ModuleBuilder::new();
build_recursive(&mut mod_b, "main_void", type_row![]);
build_recursive(&mut mod_b, "main_unary", type_row![BOOL_T]);
build_recursive(&mut mod_b, "main_binary", type_row![BOOL_T, BOOL_T]);
let hugr = mod_b.finish_hugr(&EMPTY_REG).unwrap();
check_emission!(hugr, llvm_ctx);
}

#[rstest]
fn emit_hugr_custom_op(#[with(-1, add_int_extensions)] llvm_ctx: TestContext) {
let v1 = ConstInt::new_s(4, -24).unwrap();
Expand Down

0 comments on commit 7609d88

Please sign in to comment.