Skip to content

Commit

Permalink
tests pass
Browse files Browse the repository at this point in the history
  • Loading branch information
doug-q committed Dec 11, 2024
1 parent 7d733eb commit 4bad836
Show file tree
Hide file tree
Showing 8 changed files with 382 additions and 46 deletions.
15 changes: 14 additions & 1 deletion hugr-llvm/src/emit/func/mailbox.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::{borrow::Cow, rc::Rc};

use anyhow::Result;
use anyhow::{bail, Result};
use delegate::delegate;
use inkwell::{
builder::Builder,
Expand Down Expand Up @@ -148,6 +148,19 @@ impl<'c> RowMailBox<'c> {
builder: &Builder<'c>,
vs: impl IntoIterator<Item = BasicValueEnum<'c>>,
) -> Result<()> {
let vs = vs.into_iter().collect_vec();
#[cfg(debug_assertions)]
{
let actual_types = vs.clone().into_iter().map(|x| x.get_type()).collect_vec();
let expected_types = self.get_types().collect_vec();
if actual_types != expected_types {
bail!(
"RowMailbox::write: Expected types {:?}, got {:?}",
expected_types,
actual_types
);
}
}
zip_eq(self.0.iter(), vs).try_for_each(|(mb, v)| mb.write(builder, v))
}

Expand Down
80 changes: 36 additions & 44 deletions hugr-llvm/src/emit/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -314,61 +314,53 @@ fn emit_tail_loop<'c, H: HugrView>(
) -> Result<()> {
// TODO: Switch on the tag in loop_body to see where to go next
// TODO: Handle "other" args
let node = args.node();


// Make a block to jump to when we `Break`
let out_bb = context.new_basic_block("loop_out", None);
// A block for the body of the loop
let body_bb = context.new_basic_block("loop_body", Some(out_bb));
// Pack input data into a sum type - do we need this?
let prep_bb = context.new_basic_block("loop_prep", Some(body_bb));

context.builder().build_unconditional_branch(prep_bb);

let sum_ty = SumType::new([args.node().just_inputs.clone(), args.node().just_outputs.clone()]);
let outs_rmb = context.node_outs_rmb(args.node)?;

{
let builder = context.builder();
builder.position_at_end(prep_bb);
let body_in_row = args.node.just_inputs.clone();
let body_in_len = body_in_row.len();
let body_in_tuple = context.llvm_sum_type(SumType::new_tuple(body_in_row))?;

let mut loop_inputs = args.inputs.clone();
let other_inputs = loop_inputs.split_off(body_in_len);
let loop_input_ptr = builder.build_alloca(body_in_tuple.clone(), "loop_input")?;
let (body_i_node, body_o_node) = node.get_io().unwrap();
let body_i_rmb = context.node_outs_rmb(body_i_node)?;
let body_o_rmb = context.node_ins_rmb(body_o_node)?;

let body_in_tup = body_in_tuple.build_tag(builder, 0, loop_inputs)?;
builder.build_store(loop_input_ptr, body_in_tup);
builder.build_unconditional_branch(body_bb);
body_i_rmb.write(context.builder(), args.inputs)?;
context.builder().build_unconditional_branch(body_bb)?;

builder.position_at_end(body_bb);
let control_llvm_sum_type = {
let sum_ty = SumType::new([node.just_inputs.clone(), node.just_outputs.clone()]);
context.llvm_sum_type(sum_ty)?
};


// Emit the body of the loop into the right block
let mut dfpe = DataflowParentEmitter::new(args);
dfpe.emit_children(context)?;

// After the body we need to unpack the row type, then jump to the right block
let builder = context.builder();
let output_vals: Vec<BasicValueEnum> = outs_rmb.read(builder, []).unwrap();
let output_types: Vec<_> = outs_rmb.get_types().collect();
let llvm_sum_ty = LLVMSumType::try_new(&context.typing_session(), sum_ty)?;

println!("{:?}", output_vals);
let sum_output = LLVMSumValue::try_new(output_vals[0], llvm_sum_ty)?;
let tag = sum_output.build_get_tag(builder)?;

let tag = IntValue::try_from(output_vals[0]).unwrap();
let continue_tag = context.iw_context().i64_type().const_int(0, false);
let break_tag = context.iw_context().i64_type().const_int(1, false);
// TODO: Make this a conditional branch instead of switch
builder.build_switch(tag, out_bb, &[(break_tag, out_bb), (continue_tag, prep_bb)]);
du
// Return Ok so we can see the insta emission with
// `cargo insta test` for debugging
context.build_positioned(body_bb, move |context| {
let inputs = body_i_rmb.read_vec(context.builder(), [])?;
emit_dataflow_parent(
context,
EmitOpArgs {
node,
inputs,
outputs: body_o_rmb.promise(),
},
)?;
let dataflow_outputs = body_o_rmb.read_vec(context.builder(), [])?;
let control_val = LLVMSumValue::try_new(dataflow_outputs[0], control_llvm_sum_type)?;
let mut outputs = Some(args.outputs);

control_val.build_destructure(context.builder(), |builder, tag, mut values| {
values.extend(dataflow_outputs[1..].iter().copied());
if tag == 0 {
body_i_rmb.write(builder, values)?;
builder.build_unconditional_branch(body_bb)?;
} else {
outputs.take().unwrap().finish(builder, values)?;
builder.build_unconditional_branch(out_bb)?;
}
Ok(())
})
})?;
context.builder().position_at_end(out_bb);
Ok(())
}

Expand Down
71 changes: 71 additions & 0 deletions hugr-llvm/src/emit/snapshots/[email protected]
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
---
source: hugr-llvm/src/emit/test.rs
expression: mod_str
---
; ModuleID = 'test_context'
source_filename = "test_context"

define { { {} }, i64 } @_hl.main.1(i64 %0, i64 %1) {
alloca_block:
br label %entry_block

entry_block: ; preds = %alloca_block
br label %loop_body

loop_body: ; preds = %20, %entry_block
%"5_0.0" = phi i64 [ %0, %entry_block ], [ %22, %20 ]
%"5_1.0" = phi i64 [ %1, %entry_block ], [ %"19.0", %20 ]
%2 = insertvalue { i64 } undef, i64 %"5_0.0", 0
%3 = insertvalue { i32, { i64 }, { { {} } } } { i32 0, { i64 } poison, { { {} } } poison }, { i64 } %2, 1
%4 = extractvalue { i32, { i64 }, { { {} } } } %3, 0
switch i32 %4, label %5 [
i32 1, label %8
]

5: ; preds = %loop_body
%6 = extractvalue { i32, { i64 }, { { {} } } } %3, 1
%7 = extractvalue { i64 } %6, 0
br label %cond_12_case_0

8: ; preds = %loop_body
%9 = extractvalue { i32, { i64 }, { { {} } } } %3, 2
%10 = extractvalue { { {} } } %9, 0
br label %cond_12_case_1

loop_out: ; preds = %23
%mrv = insertvalue { { {} }, i64 } undef, { {} } %25, 0
%mrv40 = insertvalue { { {} }, i64 } %mrv, i64 %"19.0", 1
ret { { {} }, i64 } %mrv40

cond_12_case_0: ; preds = %5
%11 = mul i64 %"5_1.0", 2
%12 = sub i64 %7, 1
%13 = insertvalue { i64 } undef, i64 %12, 0
%14 = insertvalue { i32, { i64 }, { { {} } } } { i32 0, { i64 } poison, { { {} } } poison }, { i64 } %13, 1
br label %cond_exit_12

cond_12_case_1: ; preds = %8
%15 = insertvalue { { {} } } undef, { {} } undef, 0
%16 = insertvalue { i32, { i64 }, { { {} } } } { i32 1, { i64 } poison, { { {} } } poison }, { { {} } } %15, 2
br label %cond_exit_12

cond_exit_12: ; preds = %cond_12_case_1, %cond_12_case_0
%"08.0" = phi { i32, { i64 }, { { {} } } } [ %14, %cond_12_case_0 ], [ %16, %cond_12_case_1 ]
%"19.0" = phi i64 [ %11, %cond_12_case_0 ], [ %"5_1.0", %cond_12_case_1 ]
%17 = icmp ule i64 %"5_0.0", 0
%18 = select i1 %17, { i32, {}, {} } { i32 1, {} poison, {} undef }, { i32, {}, {} } { i32 0, {} undef, {} poison }
%19 = extractvalue { i32, { i64 }, { { {} } } } %"08.0", 0
switch i32 %19, label %20 [
i32 1, label %23
]

20: ; preds = %cond_exit_12
%21 = extractvalue { i32, { i64 }, { { {} } } } %"08.0", 1
%22 = extractvalue { i64 } %21, 0
br label %loop_body

23: ; preds = %cond_exit_12
%24 = extractvalue { i32, { i64 }, { { {} } } } %"08.0", 2
%25 = extractvalue { { {} } } %24, 0
br label %loop_out
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
---
source: hugr-llvm/src/emit/test.rs
expression: mod_str
---
; ModuleID = 'test_context'
source_filename = "test_context"

define { { {} }, i64 } @_hl.main.1(i64 %0, i64 %1) {
alloca_block:
%"0" = alloca { {} }, align 8
%"1" = alloca i64, align 8
%"2_0" = alloca i64, align 8
%"2_1" = alloca i64, align 8
%"4_0" = alloca { {} }, align 8
%"4_1" = alloca i64, align 8
%"5_0" = alloca i64, align 8
%"5_1" = alloca i64, align 8
%"12_0" = alloca { i32, { i64 }, { { {} } } }, align 8
%"12_1" = alloca i64, align 8
%"8_0" = alloca i64, align 8
%"10_0" = alloca { i32, { i64 }, { { {} } } }, align 8
%"08" = alloca { i32, { i64 }, { { {} } } }, align 8
%"19" = alloca i64, align 8
%"012" = alloca i64, align 8
%"113" = alloca i64, align 8
%"19_0" = alloca i64, align 8
%"17_0" = alloca i64, align 8
%"14_0" = alloca i64, align 8
%"14_1" = alloca i64, align 8
%"20_0" = alloca i64, align 8
%"21_0" = alloca i64, align 8
%"22_0" = alloca { i32, { i64 }, { { {} } } }, align 8
%"023" = alloca { {} }, align 8
%"124" = alloca i64, align 8
%"27_0" = alloca { {} }, align 8
%"28_0" = alloca { i32, { i64 }, { { {} } } }, align 8
%"24_0" = alloca { {} }, align 8
%"24_1" = alloca i64, align 8
%"9_0" = alloca { i32, {}, {} }, align 8
br label %entry_block

entry_block: ; preds = %alloca_block
store i64 %0, i64* %"2_0", align 4
store i64 %1, i64* %"2_1", align 4
%"2_01" = load i64, i64* %"2_0", align 4
%"2_12" = load i64, i64* %"2_1", align 4
store i64 %"2_01", i64* %"5_0", align 4
store i64 %"2_12", i64* %"5_1", align 4
br label %loop_body

loop_body: ; preds = %20, %entry_block
%"5_03" = load i64, i64* %"5_0", align 4
%"5_14" = load i64, i64* %"5_1", align 4
store i64 0, i64* %"8_0", align 4
store i64 %"5_03", i64* %"5_0", align 4
store i64 %"5_14", i64* %"5_1", align 4
%"5_05" = load i64, i64* %"5_0", align 4
%2 = insertvalue { i64 } undef, i64 %"5_05", 0
%3 = insertvalue { i32, { i64 }, { { {} } } } { i32 0, { i64 } poison, { { {} } } poison }, { i64 } %2, 1
store { i32, { i64 }, { { {} } } } %3, { i32, { i64 }, { { {} } } }* %"10_0", align 4
%"10_06" = load { i32, { i64 }, { { {} } } }, { i32, { i64 }, { { {} } } }* %"10_0", align 4
%"5_17" = load i64, i64* %"5_1", align 4
%4 = extractvalue { i32, { i64 }, { { {} } } } %"10_06", 0
switch i32 %4, label %5 [
i32 1, label %8
]

5: ; preds = %loop_body
%6 = extractvalue { i32, { i64 }, { { {} } } } %"10_06", 1
%7 = extractvalue { i64 } %6, 0
store i64 %7, i64* %"012", align 4
store i64 %"5_17", i64* %"113", align 4
br label %cond_12_case_0

8: ; preds = %loop_body
%9 = extractvalue { i32, { i64 }, { { {} } } } %"10_06", 2
%10 = extractvalue { { {} } } %9, 0
store { {} } %10, { {} }* %"023", align 1
store i64 %"5_17", i64* %"124", align 4
br label %cond_12_case_1

loop_out: ; preds = %23
%"4_036" = load { {} }, { {} }* %"4_0", align 1
%"4_137" = load i64, i64* %"4_1", align 4
store { {} } %"4_036", { {} }* %"0", align 1
store i64 %"4_137", i64* %"1", align 4
%"038" = load { {} }, { {} }* %"0", align 1
%"139" = load i64, i64* %"1", align 4
%mrv = insertvalue { { {} }, i64 } undef, { {} } %"038", 0
%mrv40 = insertvalue { { {} }, i64 } %mrv, i64 %"139", 1
ret { { {} }, i64 } %mrv40

cond_12_case_0: ; preds = %5
%"014" = load i64, i64* %"012", align 4
%"115" = load i64, i64* %"113", align 4
store i64 2, i64* %"19_0", align 4
store i64 1, i64* %"17_0", align 4
store i64 %"014", i64* %"14_0", align 4
store i64 %"115", i64* %"14_1", align 4
%"14_116" = load i64, i64* %"14_1", align 4
%"19_017" = load i64, i64* %"19_0", align 4
%11 = mul i64 %"14_116", %"19_017"
store i64 %11, i64* %"20_0", align 4
%"14_018" = load i64, i64* %"14_0", align 4
%"17_019" = load i64, i64* %"17_0", align 4
%12 = sub i64 %"14_018", %"17_019"
store i64 %12, i64* %"21_0", align 4
%"21_020" = load i64, i64* %"21_0", align 4
%13 = insertvalue { i64 } undef, i64 %"21_020", 0
%14 = insertvalue { i32, { i64 }, { { {} } } } { i32 0, { i64 } poison, { { {} } } poison }, { i64 } %13, 1
store { i32, { i64 }, { { {} } } } %14, { i32, { i64 }, { { {} } } }* %"22_0", align 4
%"22_021" = load { i32, { i64 }, { { {} } } }, { i32, { i64 }, { { {} } } }* %"22_0", align 4
%"20_022" = load i64, i64* %"20_0", align 4
store { i32, { i64 }, { { {} } } } %"22_021", { i32, { i64 }, { { {} } } }* %"08", align 4
store i64 %"20_022", i64* %"19", align 4
br label %cond_exit_12

cond_12_case_1: ; preds = %8
%"025" = load { {} }, { {} }* %"023", align 1
%"126" = load i64, i64* %"124", align 4
store { {} } undef, { {} }* %"27_0", align 1
%"27_027" = load { {} }, { {} }* %"27_0", align 1
%15 = insertvalue { { {} } } undef, { {} } %"27_027", 0
%16 = insertvalue { i32, { i64 }, { { {} } } } { i32 1, { i64 } poison, { { {} } } poison }, { { {} } } %15, 2
store { i32, { i64 }, { { {} } } } %16, { i32, { i64 }, { { {} } } }* %"28_0", align 4
store { {} } %"025", { {} }* %"24_0", align 1
store i64 %"126", i64* %"24_1", align 4
%"28_028" = load { i32, { i64 }, { { {} } } }, { i32, { i64 }, { { {} } } }* %"28_0", align 4
%"24_129" = load i64, i64* %"24_1", align 4
store { i32, { i64 }, { { {} } } } %"28_028", { i32, { i64 }, { { {} } } }* %"08", align 4
store i64 %"24_129", i64* %"19", align 4
br label %cond_exit_12

cond_exit_12: ; preds = %cond_12_case_1, %cond_12_case_0
%"010" = load { i32, { i64 }, { { {} } } }, { i32, { i64 }, { { {} } } }* %"08", align 4
%"111" = load i64, i64* %"19", align 4
store { i32, { i64 }, { { {} } } } %"010", { i32, { i64 }, { { {} } } }* %"12_0", align 4
store i64 %"111", i64* %"12_1", align 4
%"12_030" = load { i32, { i64 }, { { {} } } }, { i32, { i64 }, { { {} } } }* %"12_0", align 4
%"12_131" = load i64, i64* %"12_1", align 4
store { i32, { i64 }, { { {} } } } %"12_030", { i32, { i64 }, { { {} } } }* %"12_0", align 4
store i64 %"12_131", i64* %"12_1", align 4
%"5_032" = load i64, i64* %"5_0", align 4
%"8_033" = load i64, i64* %"8_0", align 4
%17 = icmp ule i64 %"5_032", %"8_033"
%18 = select i1 %17, { i32, {}, {} } { i32 1, {} poison, {} undef }, { i32, {}, {} } { i32 0, {} undef, {} poison }
store { i32, {}, {} } %18, { i32, {}, {} }* %"9_0", align 4
%"12_034" = load { i32, { i64 }, { { {} } } }, { i32, { i64 }, { { {} } } }* %"12_0", align 4
%"12_135" = load i64, i64* %"12_1", align 4
%19 = extractvalue { i32, { i64 }, { { {} } } } %"12_034", 0
switch i32 %19, label %20 [
i32 1, label %23
]

20: ; preds = %cond_exit_12
%21 = extractvalue { i32, { i64 }, { { {} } } } %"12_034", 1
%22 = extractvalue { i64 } %21, 0
store i64 %22, i64* %"5_0", align 4
store i64 %"12_135", i64* %"5_1", align 4
br label %loop_body

23: ; preds = %cond_exit_12
%24 = extractvalue { i32, { i64 }, { { {} } } } %"12_034", 2
%25 = extractvalue { { {} } } %24, 0
store { {} } %25, { {} }* %"4_0", align 1
store i64 %"12_135", i64* %"4_1", align 4
br label %loop_out
}
Loading

0 comments on commit 4bad836

Please sign in to comment.