Skip to content

Commit

Permalink
WIP: Tail loop emission
Browse files Browse the repository at this point in the history
  • Loading branch information
croyzor committed Dec 10, 2024
1 parent 37d4e51 commit 7d733eb
Showing 1 changed file with 72 additions and 6 deletions.
78 changes: 72 additions & 6 deletions hugr-llvm/src/emit/ops.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
use anyhow::{anyhow, bail, Result};
use hugr_core::ops::{
constant::Sum, Call, CallIndirect, Case, Conditional, Const, ExtensionOp, Input, LoadConstant,
LoadFunction, OpTag, OpTrait, OpType, Output, Tag, Value, CFG,
LoadFunction, OpTag, OpTrait, OpType, Output, Tag, TailLoop, Value, CFG,
};
use hugr_core::{
hugr::views::SiblingGraph,
types::{SumType, Type, TypeEnum},
HugrView, NodeIndex,
};
use inkwell::types::BasicTypeEnum;
use inkwell::values::{BasicValueEnum, CallableValue};
use inkwell::types::{BasicTypeEnum, IntType};
use inkwell::values::{BasicValueEnum, CallableValue, IntValue};
use itertools::{zip_eq, Itertools};
use petgraph::visit::Walker;

Expand All @@ -21,7 +21,7 @@ use crate::{

use super::{
deaggregate_call_result,
func::{EmitFuncContext, RowPromise},
func::{EmitFuncContext, RowMailBox, RowPromise},
EmitOpArgs,
};

Expand All @@ -31,6 +31,7 @@ struct DataflowParentEmitter<'c, 'hugr, OT, H> {
node: FatNode<'hugr, OT, H>,
inputs: Option<Vec<BasicValueEnum<'c>>>,
outputs: Option<RowPromise<'c>>,
output_vals: Option<Vec<BasicValueEnum<'c>>>,
}

impl<'c, 'hugr, OT: OpTrait, H: HugrView> DataflowParentEmitter<'c, 'hugr, OT, H>
Expand All @@ -42,6 +43,7 @@ where
node: args.node,
inputs: Some(args.inputs),
outputs: Some(args.outputs),
output_vals: None,
}
}

Expand All @@ -58,7 +60,7 @@ where
.ok_or(anyhow!("DataflowParentEmitter: Output taken twice"))
}

pub fn emit_children(mut self, context: &mut EmitFuncContext<'c, '_, H>) -> Result<()> {
pub fn emit_children(&mut self, context: &mut EmitFuncContext<'c, '_, H>) -> Result<()> {
use petgraph::visit::Topo;
let node = self.node;
if !OpTag::DataflowParent.is_superset(node.tag()) {
Expand Down Expand Up @@ -306,6 +308,70 @@ fn emit_cfg<'c, H: HugrView>(
cfg::CfgEmitter::new(context, args)?.emit_children(context)
}

fn emit_tail_loop<'c, H: HugrView>(
context: &mut EmitFuncContext<'c, '_, H>,
args: EmitOpArgs<'c, '_, TailLoop, H>
) -> Result<()> {
// TODO: Switch on the tag in loop_body to see where to go next
// TODO: Handle "other" args


// Make a block to jump to when we `Break`
let out_bb = context.new_basic_block("loop_out", None);
// A block for the body of the loop
let body_bb = context.new_basic_block("loop_body", Some(out_bb));
// Pack input data into a sum type - do we need this?
let prep_bb = context.new_basic_block("loop_prep", Some(body_bb));

context.builder().build_unconditional_branch(prep_bb);

let sum_ty = SumType::new([args.node().just_inputs.clone(), args.node().just_outputs.clone()]);
let outs_rmb = context.node_outs_rmb(args.node)?;

{
let builder = context.builder();
builder.position_at_end(prep_bb);
let body_in_row = args.node.just_inputs.clone();
let body_in_len = body_in_row.len();
let body_in_tuple = context.llvm_sum_type(SumType::new_tuple(body_in_row))?;

let mut loop_inputs = args.inputs.clone();
let other_inputs = loop_inputs.split_off(body_in_len);
let loop_input_ptr = builder.build_alloca(body_in_tuple.clone(), "loop_input")?;

let body_in_tup = body_in_tuple.build_tag(builder, 0, loop_inputs)?;
builder.build_store(loop_input_ptr, body_in_tup);
builder.build_unconditional_branch(body_bb);

builder.position_at_end(body_bb);
};


// Emit the body of the loop into the right block
let mut dfpe = DataflowParentEmitter::new(args);
dfpe.emit_children(context)?;

// After the body we need to unpack the row type, then jump to the right block
let builder = context.builder();
let output_vals: Vec<BasicValueEnum> = outs_rmb.read(builder, []).unwrap();
let output_types: Vec<_> = outs_rmb.get_types().collect();
let llvm_sum_ty = LLVMSumType::try_new(&context.typing_session(), sum_ty)?;

println!("{:?}", output_vals);
let sum_output = LLVMSumValue::try_new(output_vals[0], llvm_sum_ty)?;
let tag = sum_output.build_get_tag(builder)?;

let tag = IntValue::try_from(output_vals[0]).unwrap();
let continue_tag = context.iw_context().i64_type().const_int(0, false);
let break_tag = context.iw_context().i64_type().const_int(1, false);
// TODO: Make this a conditional branch instead of switch
builder.build_switch(tag, out_bb, &[(break_tag, out_bb), (continue_tag, prep_bb)]);
du
// Return Ok so we can see the insta emission with
// `cargo insta test` for debugging
Ok(())
}

fn emit_optype<'c, H: HugrView>(
context: &mut EmitFuncContext<'c, '_, H>,
args: EmitOpArgs<'c, '_, OpType, H>,
Expand All @@ -330,7 +396,7 @@ fn emit_optype<'c, H: HugrView>(
context.push_todo_func(node.into_ot(fd));
Ok(())
}

OpType::TailLoop(x) => emit_tail_loop(context, args.into_ot(x)),
_ => Err(anyhow!("Invalid child for Dataflow Parent: {node}")),
}
}
Expand Down

0 comments on commit 7d733eb

Please sign in to comment.