From f85f2ddb01dad7931290001cebcffbd39d48d26f Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Wed, 29 May 2024 15:46:29 +0100 Subject: [PATCH 1/4] feat: Emission for Call nodes --- src/emit/ops.rs | 48 ++++++++++++-- ...vm__emit__test__emit_hugr_call@llvm14.snap | 35 +++++++++++ ...st__emit_hugr_call@pre-mem2reg@llvm14.snap | 62 +++++++++++++++++++ src/emit/test.rs | 23 ++++++- 4 files changed, 162 insertions(+), 6 deletions(-) create mode 100644 src/emit/snapshots/hugr_llvm__emit__test__emit_hugr_call@llvm14.snap create mode 100644 src/emit/snapshots/hugr_llvm__emit__test__emit_hugr_call@pre-mem2reg@llvm14.snap diff --git a/src/emit/ops.rs b/src/emit/ops.rs index 27d02e0..fbe9950 100644 --- a/src/emit/ops.rs +++ b/src/emit/ops.rs @@ -2,8 +2,8 @@ use anyhow::{anyhow, Result}; use hugr::{ hugr::views::SiblingGraph, ops::{ - Case, Conditional, Const, Input, LoadConstant, MakeTuple, NamedOp, OpTag, OpTrait, OpType, - Output, Tag, UnpackTuple, Value, + Call, Case, Conditional, Const, Input, LoadConstant, MakeTuple, NamedOp, OpTag, OpTrait, + OpType, Output, Tag, UnpackTuple, Value, }, types::{SumType, Type, TypeEnum}, HugrView, NodeIndex, @@ -158,7 +158,7 @@ where let o = self.take_output()?; o.finish(self.builder(), args.inputs) } - _ => emit_optype(self.context, args) + _ => emit_optype(self.context, args), } } } @@ -246,7 +246,10 @@ fn get_exactly_one_sum_type(ts: impl IntoIterator) -> Result(context: &mut EmitFuncContext<'c, H>, v: &Value) -> Result> { +fn emit_value<'c, H: HugrView>( + context: &mut EmitFuncContext<'c, H>, + v: &Value, +) -> Result> { match v { Value::Extension { e } => { let exts = context.extensions(); @@ -330,6 +333,41 @@ fn emit_load_constant<'c, H: HugrView>( args.outputs.finish(context.builder(), [r]) } +fn emit_call<'c, H: HugrView>( + context: &mut EmitFuncContext<'c, H>, + args: EmitOpArgs<'c, Call, H>, +) -> Result<()> { + if !args.node.called_function_type().params().is_empty() { + todo!("Call of generic function"); + } + let (func_node, _) = args + .node + .single_linked_output(args.node.called_function_port()) + .unwrap(); + let func = match func_node.get() { + OpType::FuncDecl(_) => context.get_func_decl(func_node.try_into_ot().unwrap()), + OpType::FuncDefn(_) => context.get_func_defn(func_node.try_into_ot().unwrap()), + _ => Err(anyhow!("emit_call: Not a Decl or Defn")), + }; + let inputs: Vec<_> = args.inputs.iter().map(|&x| x.into()).collect(); + let call = context.builder().build_call(func?, inputs.as_slice(), "")?; + let rets = match args.outputs.len() { + 0 => vec![], + 1 => vec![call.try_as_basic_value().expect_left("non-void")], + n => call + .try_as_basic_value() + .expect_left("non-void") + .into_struct_value() + .get_fields() + // For some reason `get_fields()` returns an extra field at the end with the type of + // a pointer to the struct??? Just take the first n fields until we figure out what's + // going on... + .take(n) + .collect(), + }; + args.outputs.finish(context.builder(), rets) +} + fn emit_optype<'c, H: HugrView>( context: &mut EmitFuncContext<'c, H>, args: EmitOpArgs<'c, OpType, H>, @@ -348,10 +386,10 @@ fn emit_optype<'c, H: HugrView>( } OpType::Const(_) => Ok(()), OpType::LoadConstant(ref lc) => emit_load_constant(context, args.into_ot(lc)), + OpType::Call(ref cl) => emit_call(context, args.into_ot(cl)), OpType::Conditional(ref co) => emit_conditional(context, args.into_ot(co)), // OpType::FuncDefn(fd) => self.emit(ot.into_ot(fd), context, inputs, outputs), _ => todo!("Unimplemented OpTypeEmitter: {}", args.node().name()), } } - diff --git a/src/emit/snapshots/hugr_llvm__emit__test__emit_hugr_call@llvm14.snap b/src/emit/snapshots/hugr_llvm__emit__test__emit_hugr_call@llvm14.snap new file mode 100644 index 0000000..781978a --- /dev/null +++ b/src/emit/snapshots/hugr_llvm__emit__test__emit_hugr_call@llvm14.snap @@ -0,0 +1,35 @@ +--- +source: src/emit/test.rs +expression: module.to_string() +--- +; ModuleID = 'test_context' +source_filename = "test_context" + +define void @_hl.main_void.1() { +alloca_block: + br label %entry_block + +entry_block: ; preds = %alloca_block + call void @_hl.main_void.1() + ret void +} + +define { {}, {} } @_hl.main_unary.5({ {}, {} } %0) { +alloca_block: + br label %entry_block + +entry_block: ; preds = %alloca_block + %1 = call { {}, {} } @_hl.main_unary.5({ {}, {} } %0) + ret { {}, {} } %1 +} + +define { { {}, {} }, { {}, {} } } @_hl.main_binary.9({ {}, {} } %0, { {}, {} } %1) { +alloca_block: + br label %entry_block + +entry_block: ; preds = %alloca_block + %2 = call { { {}, {} }, { {}, {} } } @_hl.main_binary.9({ {}, {} } %0, { {}, {} } %1) + %mrv = insertvalue { { {}, {} }, { {}, {} } } undef, { {}, {} } %0, 0 + %mrv7 = insertvalue { { {}, {} }, { {}, {} } } %mrv, { {}, {} } %1, 1 + ret { { {}, {} }, { {}, {} } } %mrv7 +} diff --git a/src/emit/snapshots/hugr_llvm__emit__test__emit_hugr_call@pre-mem2reg@llvm14.snap b/src/emit/snapshots/hugr_llvm__emit__test__emit_hugr_call@pre-mem2reg@llvm14.snap new file mode 100644 index 0000000..a50543f --- /dev/null +++ b/src/emit/snapshots/hugr_llvm__emit__test__emit_hugr_call@pre-mem2reg@llvm14.snap @@ -0,0 +1,62 @@ +--- +source: src/emit/test.rs +expression: module.to_string() +--- +; ModuleID = 'test_context' +source_filename = "test_context" + +define void @_hl.main_void.1() { +alloca_block: + br label %entry_block + +entry_block: ; preds = %alloca_block + call void @_hl.main_void.1() + ret void +} + +define { {}, {} } @_hl.main_unary.5({ {}, {} } %0) { +alloca_block: + %"0" = alloca { {}, {} }, align 8 + %"6_0" = alloca { {}, {} }, align 8 + %"8_0" = alloca { {}, {} }, align 8 + br label %entry_block + +entry_block: ; preds = %alloca_block + store { {}, {} } %0, { {}, {} }* %"6_0", align 1 + %"6_01" = load { {}, {} }, { {}, {} }* %"6_0", align 1 + %1 = call { {}, {} } @_hl.main_unary.5({ {}, {} } %"6_01") + store { {}, {} } %1, { {}, {} }* %"8_0", align 1 + %"8_02" = load { {}, {} }, { {}, {} }* %"8_0", align 1 + store { {}, {} } %"8_02", { {}, {} }* %"0", align 1 + %"03" = load { {}, {} }, { {}, {} }* %"0", align 1 + ret { {}, {} } %"03" +} + +define { { {}, {} }, { {}, {} } } @_hl.main_binary.9({ {}, {} } %0, { {}, {} } %1) { +alloca_block: + %"0" = alloca { {}, {} }, align 8 + %"1" = alloca { {}, {} }, align 8 + %"10_0" = alloca { {}, {} }, align 8 + %"10_1" = alloca { {}, {} }, align 8 + %"12_0" = alloca { {}, {} }, align 8 + %"12_1" = alloca { {}, {} }, align 8 + br label %entry_block + +entry_block: ; preds = %alloca_block + store { {}, {} } %0, { {}, {} }* %"10_0", align 1 + store { {}, {} } %1, { {}, {} }* %"10_1", align 1 + %"10_01" = load { {}, {} }, { {}, {} }* %"10_0", align 1 + %"10_12" = load { {}, {} }, { {}, {} }* %"10_1", align 1 + %2 = call { { {}, {} }, { {}, {} } } @_hl.main_binary.9({ {}, {} } %"10_01", { {}, {} } %"10_12") + store { {}, {} } %"10_01", { {}, {} }* %"12_0", align 1 + store { {}, {} } %"10_12", { {}, {} }* %"12_1", align 1 + %"12_03" = load { {}, {} }, { {}, {} }* %"12_0", align 1 + %"12_14" = load { {}, {} }, { {}, {} }* %"12_1", align 1 + store { {}, {} } %"12_03", { {}, {} }* %"0", align 1 + store { {}, {} } %"12_14", { {}, {} }* %"1", align 1 + %"05" = load { {}, {} }, { {}, {} }* %"0", align 1 + %"16" = load { {}, {} }, { {}, {} }* %"1", align 1 + %mrv = insertvalue { { {}, {} }, { {}, {} } } undef, { {}, {} } %"05", 0 + %mrv7 = insertvalue { { {}, {} }, { {}, {} } } %mrv, { {}, {} } %"16", 1 + ret { { {}, {} }, { {}, {} } } %mrv7 +} diff --git a/src/emit/test.rs b/src/emit/test.rs index 573da09..ce82689 100644 --- a/src/emit/test.rs +++ b/src/emit/test.rs @@ -11,8 +11,8 @@ use hugr::ops::{Module, Tag, UnpackTuple, Value}; use hugr::std_extensions::arithmetic::int_ops::{self, INT_OPS_REGISTRY}; use hugr::std_extensions::arithmetic::int_types::ConstInt; use hugr::types::{Type, TypeRow}; -use hugr::Hugr; use hugr::{builder::DataflowSubContainer, types::FunctionType}; +use hugr::{type_row, Hugr}; use inkwell::passes::PassManager; use insta::assert_snapshot; use itertools::Itertools; @@ -227,6 +227,27 @@ fn emit_hugr_load_constant(#[with(-1, add_int_extensions)] llvm_ctx: TestContext check_emission!(hugr, llvm_ctx); } +#[rstest] +fn emit_hugr_call(llvm_ctx: TestContext) { + fn build_recursive(mod_b: &mut ModuleBuilder, name: &str, io: TypeRow) { + let f_id = mod_b + .declare(name, FunctionType::new_endo(io).into()) + .unwrap(); + let mut func_b = mod_b.define_declaration(&f_id).unwrap(); + let call = func_b + .call(&f_id, &[], func_b.input_wires(), &EMPTY_REG) + .unwrap(); + func_b.finish_with_outputs(call.outputs()).unwrap(); + } + + let mut mod_b = ModuleBuilder::new(); + build_recursive(&mut mod_b, "main_void", type_row![]); + build_recursive(&mut mod_b, "main_unary", type_row![BOOL_T]); + build_recursive(&mut mod_b, "main_binary", type_row![BOOL_T, BOOL_T]); + let hugr = mod_b.finish_hugr(&EMPTY_REG).unwrap(); + check_emission!(hugr, llvm_ctx); +} + #[rstest] fn emit_hugr_custom_op(#[with(-1, add_int_extensions)] llvm_ctx: TestContext) { let v1 = ConstInt::new_s(4, -24).unwrap(); From 554a0f67e7a50030d57be8565248e59a4027e282 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Wed, 29 May 2024 17:19:11 +0100 Subject: [PATCH 2/4] Expect right in void case --- src/emit/ops.rs | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/emit/ops.rs b/src/emit/ops.rs index fbe9950..3a955f0 100644 --- a/src/emit/ops.rs +++ b/src/emit/ops.rs @@ -350,12 +350,17 @@ fn emit_call<'c, H: HugrView>( _ => Err(anyhow!("emit_call: Not a Decl or Defn")), }; let inputs: Vec<_> = args.inputs.iter().map(|&x| x.into()).collect(); - let call = context.builder().build_call(func?, inputs.as_slice(), "")?; + let call = context + .builder() + .build_call(func?, inputs.as_slice(), "")? + .try_as_basic_value(); let rets = match args.outputs.len() { - 0 => vec![], - 1 => vec![call.try_as_basic_value().expect_left("non-void")], + 0 => { + call.expect_right("void"); + vec![] + } + 1 => vec![call.expect_left("non-void")], n => call - .try_as_basic_value() .expect_left("non-void") .into_struct_value() .get_fields() From cbc46ce60c83b82d644480a3f5774a1d8a13f1aa Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Wed, 29 May 2024 17:46:45 +0100 Subject: [PATCH 3/4] Use build_extract_value to unpack return struct --- src/emit/ops.rs | 23 ++++++++----------- ...vm__emit__test__emit_hugr_call@llvm14.snap | 6 +++-- ...st__emit_hugr_call@pre-mem2reg@llvm14.snap | 6 +++-- 3 files changed, 18 insertions(+), 17 deletions(-) diff --git a/src/emit/ops.rs b/src/emit/ops.rs index 3a955f0..a784a9c 100644 --- a/src/emit/ops.rs +++ b/src/emit/ops.rs @@ -350,27 +350,24 @@ fn emit_call<'c, H: HugrView>( _ => Err(anyhow!("emit_call: Not a Decl or Defn")), }; let inputs: Vec<_> = args.inputs.iter().map(|&x| x.into()).collect(); - let call = context - .builder() + let builder = context.builder(); + let call = builder .build_call(func?, inputs.as_slice(), "")? .try_as_basic_value(); - let rets = match args.outputs.len() { + let rets = match args.outputs.len() as u32 { 0 => { call.expect_right("void"); vec![] } 1 => vec![call.expect_left("non-void")], - n => call - .expect_left("non-void") - .into_struct_value() - .get_fields() - // For some reason `get_fields()` returns an extra field at the end with the type of - // a pointer to the struct??? Just take the first n fields until we figure out what's - // going on... - .take(n) - .collect(), + n => { + let return_struct = call.expect_left("non-void").into_struct_value(); + (0..n) + .map(|i| builder.build_extract_value(return_struct, i, "")) + .collect::, _>>()? + } }; - args.outputs.finish(context.builder(), rets) + args.outputs.finish(builder, rets) } fn emit_optype<'c, H: HugrView>( diff --git a/src/emit/snapshots/hugr_llvm__emit__test__emit_hugr_call@llvm14.snap b/src/emit/snapshots/hugr_llvm__emit__test__emit_hugr_call@llvm14.snap index 781978a..10e5d44 100644 --- a/src/emit/snapshots/hugr_llvm__emit__test__emit_hugr_call@llvm14.snap +++ b/src/emit/snapshots/hugr_llvm__emit__test__emit_hugr_call@llvm14.snap @@ -29,7 +29,9 @@ alloca_block: entry_block: ; preds = %alloca_block %2 = call { { {}, {} }, { {}, {} } } @_hl.main_binary.9({ {}, {} } %0, { {}, {} } %1) - %mrv = insertvalue { { {}, {} }, { {}, {} } } undef, { {}, {} } %0, 0 - %mrv7 = insertvalue { { {}, {} }, { {}, {} } } %mrv, { {}, {} } %1, 1 + %3 = extractvalue { { {}, {} }, { {}, {} } } %2, 0 + %4 = extractvalue { { {}, {} }, { {}, {} } } %2, 1 + %mrv = insertvalue { { {}, {} }, { {}, {} } } undef, { {}, {} } %3, 0 + %mrv7 = insertvalue { { {}, {} }, { {}, {} } } %mrv, { {}, {} } %4, 1 ret { { {}, {} }, { {}, {} } } %mrv7 } diff --git a/src/emit/snapshots/hugr_llvm__emit__test__emit_hugr_call@pre-mem2reg@llvm14.snap b/src/emit/snapshots/hugr_llvm__emit__test__emit_hugr_call@pre-mem2reg@llvm14.snap index a50543f..6d95704 100644 --- a/src/emit/snapshots/hugr_llvm__emit__test__emit_hugr_call@pre-mem2reg@llvm14.snap +++ b/src/emit/snapshots/hugr_llvm__emit__test__emit_hugr_call@pre-mem2reg@llvm14.snap @@ -48,8 +48,10 @@ entry_block: ; preds = %alloca_block %"10_01" = load { {}, {} }, { {}, {} }* %"10_0", align 1 %"10_12" = load { {}, {} }, { {}, {} }* %"10_1", align 1 %2 = call { { {}, {} }, { {}, {} } } @_hl.main_binary.9({ {}, {} } %"10_01", { {}, {} } %"10_12") - store { {}, {} } %"10_01", { {}, {} }* %"12_0", align 1 - store { {}, {} } %"10_12", { {}, {} }* %"12_1", align 1 + %3 = extractvalue { { {}, {} }, { {}, {} } } %2, 0 + %4 = extractvalue { { {}, {} }, { {}, {} } } %2, 1 + store { {}, {} } %3, { {}, {} }* %"12_0", align 1 + store { {}, {} } %4, { {}, {} }* %"12_1", align 1 %"12_03" = load { {}, {} }, { {}, {} }* %"12_0", align 1 %"12_14" = load { {}, {} }, { {}, {} }* %"12_1", align 1 store { {}, {} } %"12_03", { {}, {} }* %"0", align 1 From dfa0b03f5c86c7d55c0de59059ee6b0580a3318e Mon Sep 17 00:00:00 2001 From: Mark Koch <48097969+mark-koch@users.noreply.github.com> Date: Thu, 30 May 2024 09:05:49 +0100 Subject: [PATCH 4/4] Refactor inputs Co-authored-by: doug-q <141026920+doug-q@users.noreply.github.com> --- src/emit/ops.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/emit/ops.rs b/src/emit/ops.rs index a784a9c..e7ce208 100644 --- a/src/emit/ops.rs +++ b/src/emit/ops.rs @@ -349,7 +349,7 @@ fn emit_call<'c, H: HugrView>( OpType::FuncDefn(_) => context.get_func_defn(func_node.try_into_ot().unwrap()), _ => Err(anyhow!("emit_call: Not a Decl or Defn")), }; - let inputs: Vec<_> = args.inputs.iter().map(|&x| x.into()).collect(); + let inputs: args.inputs.into_iter().map_into().collect_vec(); let builder = context.builder(); let call = builder .build_call(func?, inputs.as_slice(), "")?