From ab265843e3aed68c43ac9452795d9f2152cb31b3 Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Tue, 11 Jun 2024 15:34:05 +0100 Subject: [PATCH] feat: lower CFGs --- .gitignore | 1 + Cargo.lock | 44 ++- Cargo.toml | 2 +- src/custom/int.rs | 4 +- src/emit.rs | 5 +- src/emit/func/mailbox.rs | 4 + src/emit/ops.rs | 14 +- src/emit/ops/cfg.rs | 290 ++++++++++++++++++ ...emit__ops__cfg__test__emit_cfg@llvm14.snap | 49 +++ ...fg__test__emit_cfg@pre-mem2reg@llvm14.snap | 96 ++++++ src/emit/test.rs | 16 +- src/fat.rs | 52 +++- src/types.rs | 16 +- 13 files changed, 553 insertions(+), 40 deletions(-) create mode 100644 src/emit/ops/cfg.rs create mode 100644 src/emit/ops/snapshots/hugr_llvm__emit__ops__cfg__test__emit_cfg@llvm14.snap create mode 100644 src/emit/ops/snapshots/hugr_llvm__emit__ops__cfg__test__emit_cfg@pre-mem2reg@llvm14.snap diff --git a/.gitignore b/.gitignore index 1eb22f5..41d393a 100644 --- a/.gitignore +++ b/.gitignore @@ -14,3 +14,4 @@ devenv.local.nix /target /.envrc /result +*.snap.new diff --git a/Cargo.lock b/Cargo.lock index a83ba85..b3e60b9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -298,9 +298,19 @@ dependencies = [ [[package]] name = "hugr" -version = "0.4.0" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8e15eaffd64f1cceac13429e5ceaf20017691e647cc56b8b7c53d73ad9f8714" +checksum = "a20246e5f1a0aae160b80b71bc4d5c1dcccc5605dbe6458ac6f9af73137339c3" +dependencies = [ + "hugr-core", + "hugr-passes", +] + +[[package]] +name = "hugr-core" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd64c708ba21cedae58e37f3557d139ee6eaeaaa7df0ab9541385249c1ee6065" dependencies = [ "bitvec", "cgmath", @@ -310,7 +320,7 @@ dependencies = [ "downcast-rs", "enum_dispatch", "html-escape", - "itertools", + "itertools 0.13.0", "lazy_static", "num-rational", "paste", @@ -337,13 +347,26 @@ dependencies = [ "hugr", "inkwell", "insta", - "itertools", + "itertools 0.12.1", "lazy_static", "llvm-sys", "petgraph", "rstest", ] +[[package]] +name = "hugr-passes" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7522d3a5299a49ee453f6d0e3c3a6e8e7ed08a38cbd11e1d19163b97098ce9be" +dependencies = [ + "hugr-core", + "itertools 0.13.0", + "lazy_static", + "paste", + "thiserror", +] + [[package]] name = "indexmap" version = "2.2.6" @@ -406,6 +429,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.11" @@ -525,9 +557,9 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" [[package]] name = "portgraph" -version = "0.12.0" +version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8ad1ebc029f8dfab4f023f14c7e41b3b3a257cf28b2928949153e563c9bf02c" +checksum = "0d83271bd5c2249831ff13227e6b3a968e115b6d4f4a2027a8ce02838b789ff8" dependencies = [ "bitvec", "context-iterators", diff --git a/Cargo.toml b/Cargo.toml index c58118a..c4ec546 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,7 +12,7 @@ llvm14 = ["dep:llvm-sys-140", "inkwell/llvm14-0"] [dependencies] inkwell = { version = "0.4.0", default-features=false } llvm-sys-140 = { package = "llvm-sys", version = "140.1.3", features = ["prefer-static"], optional = true} -hugr = "0.4.0" +hugr = "0.5.1" anyhow = "1.0.83" itertools = "0.12.1" delegate = "0.12.0" diff --git a/src/custom/int.rs b/src/custom/int.rs index fef9283..aca603d 100644 --- a/src/custom/int.rs +++ b/src/custom/int.rs @@ -4,7 +4,7 @@ use hugr::{ extension::{simple_op::MakeExtensionOp, ExtensionId}, ops::{constant::CustomConst, CustomOp, NamedOp}, std_extensions::arithmetic::{ - int_ops::{self, IntOpType}, + int_ops::{self, ConcreteIntOp}, int_types::{self, ConstInt}, }, types::{CustomType, TypeArg}, @@ -27,7 +27,7 @@ struct IntOpEmitter<'c, 'd, H: HugrView>(&'d mut EmitFuncContext<'c, H>); impl<'c, H: HugrView> EmitOp<'c, CustomOp, H> for IntOpEmitter<'c, '_, H> { fn emit(&mut self, args: EmitOpArgs<'c, CustomOp, H>) -> Result<()> { - let iot = IntOpType::from_optype(&args.node().generalise()) + let iot = ConcreteIntOp::from_optype(&args.node().generalise()) .ok_or(anyhow!("IntOpEmitter from_optype_failed"))?; match iot.name().as_str() { "iadd" => { diff --git a/src/emit.rs b/src/emit.rs index 63f7b26..da0b4b9 100644 --- a/src/emit.rs +++ b/src/emit.rs @@ -70,10 +70,9 @@ impl<'c, H: HugrView> EmitOpArgs<'c, OpType, H> { /// Specialise the internal [FatNode]. /// /// Panics if `OT` is not the `get_optype` of the internal [Node]. - pub fn into_ot(self, ot: &'c OTInto) -> EmitOpArgs<'c, OTInto, H> + pub fn into_ot<'b, OTInto: PartialEq + 'c>(self, ot: &'b OTInto) -> EmitOpArgs<'c, OTInto, H> where - &'c OpType: TryInto<&'c OTInto>, - <&'c OpType as TryInto<&'c OTInto>>::Error: std::fmt::Debug, + for<'a> &'a OpType: TryInto<&'a OTInto>, { let EmitOpArgs { node, diff --git a/src/emit/func/mailbox.rs b/src/emit/func/mailbox.rs index 06bb727..a5c8b2e 100644 --- a/src/emit/func/mailbox.rs +++ b/src/emit/func/mailbox.rs @@ -93,6 +93,10 @@ impl<'c> ValuePromise<'c> { pub struct RowMailBox<'c>(Rc>>, Cow<'static, str>); impl<'c> RowMailBox<'c> { + pub fn new_empty() -> Self { + Self::new(std::iter::empty(), None) + } + pub(super) fn new( mbs: impl IntoIterator>, name: Option, diff --git a/src/emit/ops.rs b/src/emit/ops.rs index ea51b7d..734adcb 100644 --- a/src/emit/ops.rs +++ b/src/emit/ops.rs @@ -3,7 +3,7 @@ use hugr::{ hugr::views::SiblingGraph, ops::{ Call, Case, Conditional, Const, Input, LoadConstant, MakeTuple, NamedOp, OpTag, OpTrait, - OpType, Output, Tag, UnpackTuple, Value, + OpType, Output, Tag, UnpackTuple, Value, CFG, }, types::{SumType, Type, TypeEnum}, HugrView, NodeIndex, @@ -20,6 +20,8 @@ use super::{ EmitOp, EmitOpArgs, }; +mod cfg; + struct SumOpEmitter<'c, 'd, H: HugrView>(&'d mut EmitFuncContext<'c, H>, LLVMSumType<'c>); impl<'c, 'd, H: HugrView> SumOpEmitter<'c, 'd, H> { @@ -228,7 +230,7 @@ impl<'c, H: HugrView> EmitOp<'c, Conditional, H> for ConditionalEmitter<'c, '_, }) .collect::>>()?; - builder.build_switch(tag.into_int_value(), switches[0].1, &switches[1..])?; + builder.build_switch(tag, switches[0].1, &switches[1..])?; builder.position_at_end(exit_block); Ok(()) } @@ -370,6 +372,13 @@ fn emit_call<'c, H: HugrView>( args.outputs.finish(builder, rets) } +fn emit_cfg<'c, H: HugrView>( + context: &mut EmitFuncContext<'c, H>, + args: EmitOpArgs<'c, CFG, H>, +) -> Result<()> { + cfg::CfgEmitter::new(context, args)?.emit_children() +} + fn emit_optype<'c, H: HugrView>( context: &mut EmitFuncContext<'c, H>, args: EmitOpArgs<'c, OpType, H>, @@ -390,6 +399,7 @@ fn emit_optype<'c, H: HugrView>( OpType::LoadConstant(ref lc) => emit_load_constant(context, args.into_ot(lc)), OpType::Call(ref cl) => emit_call(context, args.into_ot(cl)), OpType::Conditional(ref co) => emit_conditional(context, args.into_ot(co)), + OpType::CFG(ref cfg) => emit_cfg(context, args.into_ot(cfg)), // OpType::FuncDefn(fd) => self.emit(ot.into_ot(fd), context, inputs, outputs), _ => todo!("Unimplemented OpTypeEmitter: {}", args.node().name()), diff --git a/src/emit/ops/cfg.rs b/src/emit/ops/cfg.rs new file mode 100644 index 0000000..4606cb0 --- /dev/null +++ b/src/emit/ops/cfg.rs @@ -0,0 +1,290 @@ +use std::collections::HashMap; + +use anyhow::{anyhow, Result}; +use hugr::{ + ops::{DataflowBlock, ExitBlock, OpType, CFG}, + types::SumType, + HugrView, NodeIndex, +}; +use inkwell::{basic_block::BasicBlock, values::BasicValueEnum}; +use itertools::Itertools as _; + +use crate::{ + emit::{ + func::{EmitFuncContext, RowMailBox, RowPromise}, + EmitOp, EmitOpArgs, + }, + fat::FatNode, +}; + +use super::emit_dataflow_parent; + +pub struct CfgEmitter<'c, 'd, H: HugrView> { + context: &'d mut EmitFuncContext<'c, H>, + bbs: HashMap, (BasicBlock<'c>, RowMailBox<'c>)>, + inputs: Option>>, + outputs: Option>, + node: FatNode<'c, CFG, H>, + entry_node: FatNode<'c, DataflowBlock, H>, + exit_node: FatNode<'c, ExitBlock, H>, +} + +impl<'c, 'd, H: HugrView> CfgEmitter<'c, 'd, H> { + // Constructs a new CfgEmitter. Creates a basic block for each of + // the children in the llvm function. Note that this does not move the + // position of the builder. + pub fn new( + context: &'d mut EmitFuncContext<'c, H>, + args: EmitOpArgs<'c, CFG, H>, + ) -> Result { + let node = args.node(); + let (inputs, outputs) = (Some(args.inputs), Some(args.outputs)); + + // create this now so that it will be the last block and we can use it + // to crate the other blocks immediately before it. This is just for + // nice block ordering. + let exit_block = context.new_basic_block("", None); + let bbs = node + .children() + .map(|child| { + if child.is_exit_block() { + let output_row = { + let out_types = node.out_value_types().map(|x| x.1).collect_vec(); + context.new_row_mail_box(out_types.iter(), "")? + }; + Ok((child, (exit_block, output_row))) + } else { + let bb = context.new_basic_block("", Some(exit_block)); + let (i, _) = child.get_io().unwrap(); + Ok((child, (bb, context.node_outs_rmb(i)?))) + } + }) + .collect::>>()?; + let (entry_node, exit_node) = node.get_entry_exit().unwrap(); + Ok(CfgEmitter { + context, + bbs, + node, + inputs, + outputs, + entry_node, + exit_node, + }) + } + + fn take_inputs(&mut self) -> Result>> { + self.inputs.take().ok_or(anyhow!("Couldn't take inputs")) + } + + fn take_outputs(&mut self) -> Result> { + self.outputs.take().ok_or(anyhow!("Couldn't take inputs")) + } + + fn get_block_data( + &self, + node: &FatNode<'c, OT, H>, + ) -> Result<&(BasicBlock<'c>, RowMailBox<'c>)> + where + for<'a> &'a OpType: TryInto<&'a OT>, + { + self.bbs + .get(&node.clone().generalise()) + .ok_or(anyhow!("Couldn't get block data for: {}", node.index())) + } + + /// Consume the emitter by emitting each child of the node. + /// After returning the builder will be at the end of the exit block. + pub fn emit_children(mut self) -> Result<()> { + // write the inputs of the cfg node into the inputs of the entry + // dataflowblock node, and then branch to the basic block of that entry + // node. + let inputs = self.take_inputs()?; + let (entry_bb, inputs_rmb) = self.get_block_data(&self.entry_node).cloned()?; + let builder = self.context.builder(); + inputs_rmb.write(builder, inputs)?; + builder.build_unconditional_branch(entry_bb)?; + + // emit each child by delegating to the `impl EmitOp<_>` of self. + for c in self.node.children() { + let (inputs, outputs) = (vec![], RowMailBox::new_empty().promise()); + if let Some(node) = c.try_into_ot::() { + self.emit(EmitOpArgs { + node, + inputs, + outputs, + })?; + } else if let Some(node) = c.try_into_ot::() { + self.emit(EmitOpArgs { + node, + inputs, + outputs, + })?; + } else { + Err(anyhow!("unknown optype: {c}"))?; + } + } + + // move the builder to the end of the exit block + let (exit_bb, _) = self.get_block_data(&self.exit_node).cloned()?; + self.context.builder().position_at_end(exit_bb); + Ok(()) + } +} + +impl<'c, H: HugrView> EmitOp<'c, OpType, H> for CfgEmitter<'c, '_, H> { + fn emit(&mut self, args: EmitOpArgs<'c, OpType, H>) -> Result<()> { + match args.node().as_ref() { + OpType::DataflowBlock(ref dfb) => self.emit(args.into_ot(dfb)), + OpType::ExitBlock(ref eb) => self.emit(args.into_ot(eb)), + ot => Err(anyhow!("unknown optype: {ot:?}")), + } + } +} +impl<'c, H: HugrView> EmitOp<'c, DataflowBlock, H> for CfgEmitter<'c, '_, H> { + fn emit( + &mut self, + EmitOpArgs { + node, + inputs: _, + outputs: _, + }: EmitOpArgs<'c, DataflowBlock, H>, + ) -> Result<()> { + // our entry basic block and our input RowMailBox + let (bb, inputs_rmb) = self.bbs.get(&node.clone().generalise()).unwrap(); + // the basic block and mailbox of each of our successors + let successor_data = node + .output_neighbours() + .map(|succ| self.get_block_data(&succ).map(|x| x.clone())) + .collect::>>()?; + + self.context.build_positioned(*bb, |context| { + let (_, o) = node.get_io().unwrap(); + // get the rowmailbox for our output node + let outputs_rmb = context.node_ins_rmb(o)?; + // read the values from our input node + let inputs = inputs_rmb.read_vec(context.builder(), [])?; + + // emit all our children and read the values from the rowmailbox of our output node + emit_dataflow_parent( + context, + EmitOpArgs { + node: node.clone(), + inputs, + outputs: outputs_rmb.promise(), + }, + )?; + let outputs = outputs_rmb.read_vec(context.builder(), [])?; + + let branch_sum_type = SumType::new(node.sum_rows.clone()); + let llvm_sum_type = context.llvm_sum_type(branch_sum_type)?; + let tag_bbs = successor_data + .into_iter() + .enumerate() + .map(|(tag, (target_bb, target_rmb))| { + let bb = context.build_positioned_new_block("", Some(*bb), |context, bb| { + let builder = context.builder(); + let mut vals = + llvm_sum_type.build_untag(builder, tag as u32, outputs[0])?; + vals.extend(&outputs[1..]); + target_rmb.write(builder, vals)?; + builder.build_unconditional_branch(target_bb)?; + Ok::<_, anyhow::Error>(bb) + })?; + Ok(( + llvm_sum_type.get_tag_type().const_int(tag as u64, false), + bb, + )) + }) + .collect::>>()?; + let tag_v = llvm_sum_type.build_get_tag(context.builder(), outputs[0])?; + context + .builder() + .build_switch(tag_v, tag_bbs[0].1, &tag_bbs[1..])?; + Ok(()) + }) + } +} + +impl<'c, H: HugrView> EmitOp<'c, ExitBlock, H> for CfgEmitter<'c, '_, H> { + fn emit(&mut self, args: EmitOpArgs<'c, ExitBlock, H>) -> Result<()> { + let outputs = self.take_outputs()?; + let (bb, inputs_rmb) = self.bbs.get(&args.node().generalise()).unwrap(); + self.context.build_positioned(*bb, |context| { + let builder = context.builder(); + outputs.finish(builder, inputs_rmb.read_vec(builder, [])?) + }) + } +} + +#[cfg(test)] +mod test { + use hugr::builder::{Dataflow, DataflowSubContainer, SubContainer}; + use hugr::extension::{ExtensionRegistry, ExtensionSet}; + use hugr::std_extensions::arithmetic::int_types::{self, INT_TYPES}; + use hugr::type_row; + + use rstest::rstest; + + use crate::custom::int::add_int_extensions; + use crate::emit::test::SimpleHugrConfig; + use crate::test::{llvm_ctx, TestContext}; + + use crate::check_emission; + + #[rstest] + fn emit_cfg(#[with(-1, add_int_extensions)] llvm_ctx: TestContext) { + let t1 = INT_TYPES[0].clone(); + let t2 = INT_TYPES[1].clone(); + let es = ExtensionSet::singleton(&int_types::EXTENSION_ID); + let hugr = SimpleHugrConfig::new() + .with_ins(vec![t1.clone(), t2.clone()]) + .with_outs(t2.clone()) + .with_extensions(ExtensionRegistry::try_new([int_types::extension()]).unwrap()) + .finish(|mut builder| { + let [in1, in2] = builder.input_wires_arr(); + let mut cfg_builder = builder + .cfg_builder( + [(t1.clone(), in1), (t2.clone(), in2)], + None, + t2.clone().into(), + es.clone(), + ) + .unwrap(); + + // entry block takes (t1,t2) and unconditionally branches to b1 with no other outputs + let mut entry_builder = cfg_builder + .entry_builder( + [vec![t1.clone(), t2.clone()].into()], + type_row![], + es.clone(), + ) + .unwrap(); + let [entry_in1, entry_in2] = entry_builder.input_wires_arr(); + let r = entry_builder.make_tuple([entry_in1, entry_in2]).unwrap(); + let entry_block = entry_builder.finish_with_outputs(r, []).unwrap(); + + // b1 takes (t1,t2) and branches to either entry or exit, with sum type [(t1) + ()] and other outputs [t2] + let variants = vec![t1.clone().into(), type_row![]]; + let mut b1_builder = cfg_builder + .block_builder( + vec![t1.clone(), t2.clone()].into(), + variants.clone(), + es, + t2.clone().into(), + ) + .unwrap(); + let [b1_in1, b1_in2] = b1_builder.input_wires_arr(); + let r = b1_builder.make_sum(0, variants, [b1_in1]).unwrap(); + let b1 = b1_builder.finish_with_outputs(r, [b1_in2]).unwrap(); + + let exit_block = cfg_builder.exit_block(); + cfg_builder.branch(&entry_block, 0, &b1).unwrap(); + cfg_builder.branch(&b1, 0, &entry_block).unwrap(); + cfg_builder.branch(&b1, 1, &exit_block).unwrap(); + let cfg = cfg_builder.finish_sub_container().unwrap(); + let [cfg_out] = cfg.outputs_arr(); + builder.finish_with_outputs([cfg_out]).unwrap() + }); + check_emission!(hugr, llvm_ctx); + } +} diff --git a/src/emit/ops/snapshots/hugr_llvm__emit__ops__cfg__test__emit_cfg@llvm14.snap b/src/emit/ops/snapshots/hugr_llvm__emit__ops__cfg__test__emit_cfg@llvm14.snap new file mode 100644 index 0000000..f4df285 --- /dev/null +++ b/src/emit/ops/snapshots/hugr_llvm__emit__ops__cfg__test__emit_cfg@llvm14.snap @@ -0,0 +1,49 @@ +--- +source: src/emit/ops/cfg.rs +expression: module.to_string() +--- +; ModuleID = 'test_context' +source_filename = "test_context" + +define i8 @_hl.main.1(i8 %0, i8 %1) { +alloca_block: + br label %entry_block + +entry_block: ; preds = %alloca_block + br label %6 + +2: ; preds = %6 + %3 = extractvalue { { i8, i8 } } %9, 0 + %4 = extractvalue { i8, i8 } %3, 0 + %5 = extractvalue { i8, i8 } %3, 1 + br label %15 + +6: ; preds = %10, %entry_block + %"7_0.0" = phi i8 [ %0, %entry_block ], [ %12, %10 ] + %"7_1.0" = phi i8 [ %1, %entry_block ], [ %5, %10 ] + %7 = insertvalue { i8, i8 } undef, i8 %"7_0.0", 0 + %8 = insertvalue { i8, i8 } %7, i8 %"7_1.0", 1 + %9 = insertvalue { { i8, i8 } } poison, { i8, i8 } %8, 0 + switch i32 0, label %2 [ + ] + +10: ; preds = %15 + %11 = extractvalue { i32, { i8 }, {} } %17, 1 + %12 = extractvalue { i8 } %11, 0 + br label %6 + +13: ; preds = %15 + %14 = extractvalue { i32, { i8 }, {} } %17, 2 + br label %19 + +15: ; preds = %2 + %16 = insertvalue { i8 } undef, i8 %4, 0 + %17 = insertvalue { i32, { i8 }, {} } { i32 0, { i8 } poison, {} poison }, { i8 } %16, 1 + %18 = extractvalue { i32, { i8 }, {} } %17, 0 + switch i32 %18, label %10 [ + i32 1, label %13 + ] + +19: ; preds = %13 + ret i8 %5 +} diff --git a/src/emit/ops/snapshots/hugr_llvm__emit__ops__cfg__test__emit_cfg@pre-mem2reg@llvm14.snap b/src/emit/ops/snapshots/hugr_llvm__emit__ops__cfg__test__emit_cfg@pre-mem2reg@llvm14.snap new file mode 100644 index 0000000..f2fccbc --- /dev/null +++ b/src/emit/ops/snapshots/hugr_llvm__emit__ops__cfg__test__emit_cfg@pre-mem2reg@llvm14.snap @@ -0,0 +1,96 @@ +--- +source: src/emit/ops/cfg.rs +expression: module.to_string() +--- +; ModuleID = 'test_context' +source_filename = "test_context" + +define i8 @_hl.main.1(i8 %0, i8 %1) { +alloca_block: + %"0" = alloca i8, align 1 + %"2_0" = alloca i8, align 1 + %"2_1" = alloca i8, align 1 + %"4_0" = alloca i8, align 1 + %"7_0" = alloca i8, align 1 + %"7_1" = alloca i8, align 1 + %"03" = alloca i8, align 1 + %"11_0" = alloca i8, align 1 + %"11_1" = alloca i8, align 1 + %"9_0" = alloca { { i8, i8 } }, align 8 + %"13_0" = alloca { i32, { i8 }, {} }, align 8 + br label %entry_block + +entry_block: ; preds = %alloca_block + store i8 %0, i8* %"2_0", align 1 + store i8 %1, i8* %"2_1", align 1 + %"2_01" = load i8, i8* %"2_0", align 1 + %"2_12" = load i8, i8* %"2_1", align 1 + store i8 %"2_01", i8* %"7_0", align 1 + store i8 %"2_12", i8* %"7_1", align 1 + br label %6 + +2: ; preds = %6 + %3 = extractvalue { { i8, i8 } } %"9_09", 0 + %4 = extractvalue { i8, i8 } %3, 0 + %5 = extractvalue { i8, i8 } %3, 1 + store i8 %4, i8* %"11_0", align 1 + store i8 %5, i8* %"11_1", align 1 + br label %15 + +6: ; preds = %10, %entry_block + %"7_04" = load i8, i8* %"7_0", align 1 + %"7_15" = load i8, i8* %"7_1", align 1 + store i8 %"7_04", i8* %"7_0", align 1 + store i8 %"7_15", i8* %"7_1", align 1 + %"7_06" = load i8, i8* %"7_0", align 1 + %"7_17" = load i8, i8* %"7_1", align 1 + %7 = insertvalue { i8, i8 } undef, i8 %"7_06", 0 + %8 = insertvalue { i8, i8 } %7, i8 %"7_17", 1 + %9 = insertvalue { { i8, i8 } } poison, { i8, i8 } %8, 0 + store { { i8, i8 } } %9, { { i8, i8 } }* %"9_0", align 1 + %"9_08" = load { { i8, i8 } }, { { i8, i8 } }* %"9_0", align 1 + store { { i8, i8 } } %"9_08", { { i8, i8 } }* %"9_0", align 1 + %"9_09" = load { { i8, i8 } }, { { i8, i8 } }* %"9_0", align 1 + switch i32 0, label %2 [ + ] + +10: ; preds = %15 + %11 = extractvalue { i32, { i8 }, {} } %"13_016", 1 + %12 = extractvalue { i8 } %11, 0 + store i8 %12, i8* %"7_0", align 1 + store i8 %"11_117", i8* %"7_1", align 1 + br label %6 + +13: ; preds = %15 + %14 = extractvalue { i32, { i8 }, {} } %"13_016", 2 + store i8 %"11_117", i8* %"03", align 1 + br label %19 + +15: ; preds = %2 + %"11_011" = load i8, i8* %"11_0", align 1 + %"11_112" = load i8, i8* %"11_1", align 1 + store i8 %"11_011", i8* %"11_0", align 1 + store i8 %"11_112", i8* %"11_1", align 1 + %"11_013" = load i8, i8* %"11_0", align 1 + %16 = insertvalue { i8 } undef, i8 %"11_013", 0 + %17 = insertvalue { i32, { i8 }, {} } { i32 0, { i8 } poison, {} poison }, { i8 } %16, 1 + store { i32, { i8 }, {} } %17, { i32, { i8 }, {} }* %"13_0", align 4 + %"13_014" = load { i32, { i8 }, {} }, { i32, { i8 }, {} }* %"13_0", align 4 + %"11_115" = load i8, i8* %"11_1", align 1 + store { i32, { i8 }, {} } %"13_014", { i32, { i8 }, {} }* %"13_0", align 4 + store i8 %"11_115", i8* %"11_1", align 1 + %"13_016" = load { i32, { i8 }, {} }, { i32, { i8 }, {} }* %"13_0", align 4 + %"11_117" = load i8, i8* %"11_1", align 1 + %18 = extractvalue { i32, { i8 }, {} } %"13_016", 0 + switch i32 %18, label %10 [ + i32 1, label %13 + ] + +19: ; preds = %13 + %"010" = load i8, i8* %"03", align 1 + store i8 %"010", i8* %"4_0", align 1 + %"4_018" = load i8, i8* %"4_0", align 1 + store i8 %"4_018", i8* %"0", align 1 + %"019" = load i8, i8* %"0", align 1 + ret i8 %"019" +} diff --git a/src/emit/test.rs b/src/emit/test.rs index ce82689..8c57193 100644 --- a/src/emit/test.rs +++ b/src/emit/test.rs @@ -1,5 +1,4 @@ use crate::custom::int::add_int_extensions; -use crate::fat::FatExt as _; use hugr::builder::{ BuildHandle, Container, DFGWrapper, Dataflow, HugrBuilder, ModuleBuilder, SubContainer, }; @@ -7,14 +6,12 @@ use hugr::extension::prelude::BOOL_T; use hugr::extension::{ExtensionRegistry, ExtensionSet, EMPTY_REG}; use hugr::ops::constant::CustomConst; use hugr::ops::handle::FuncID; -use hugr::ops::{Module, Tag, UnpackTuple, Value}; +use hugr::ops::{Tag, UnpackTuple, Value}; use hugr::std_extensions::arithmetic::int_ops::{self, INT_OPS_REGISTRY}; use hugr::std_extensions::arithmetic::int_types::ConstInt; use hugr::types::{Type, TypeRow}; use hugr::{builder::DataflowSubContainer, types::FunctionType}; use hugr::{type_row, Hugr}; -use inkwell::passes::PassManager; -use insta::assert_snapshot; use itertools::Itertools; use rstest::rstest; @@ -23,7 +20,7 @@ use crate::test::*; #[allow(clippy::upper_case_acronyms)] type DFGW<'a> = DFGWrapper<&'a mut Hugr, BuildHandle>>; -struct SimpleHugrConfig { +pub struct SimpleHugrConfig { ins: TypeRow, outs: TypeRow, extensions: ExtensionRegistry, @@ -75,9 +72,10 @@ impl SimpleHugrConfig { } } +#[macro_export] macro_rules! check_emission { ($hugr: ident, $test_ctx:ident) => { - let root = $hugr.fat_root::().unwrap(); + let root = crate::fat::FatExt::fat_root::(&$hugr).unwrap(); let (_, module) = $test_ctx.with_emit_context(|ec| ((), ec.emit_module(root).unwrap())); let mut settings = insta::Settings::clone_current(); @@ -85,17 +83,17 @@ macro_rules! check_emission { .snapshot_suffix() .map_or("pre-mem2reg".into(), |x| format!("pre-mem2reg@{x}")); settings.set_snapshot_suffix(new_suffix); - settings.bind(|| assert_snapshot!(module.to_string())); + settings.bind(|| insta::assert_snapshot!(module.to_string())); module .verify() .unwrap_or_else(|pp| panic!("Failed to verify module: {pp}")); - let pb = PassManager::create(()); + let pb = inkwell::passes::PassManager::create(()); pb.add_promote_memory_to_register_pass(); pb.run_on(&module); - assert_snapshot!(module.to_string()); + insta::assert_snapshot!(module.to_string()); }; } diff --git a/src/fat.rs b/src/fat.rs index b9971d1..fff3a35 100644 --- a/src/fat.rs +++ b/src/fat.rs @@ -5,10 +5,11 @@ use std::{cmp::Ordering, hash::Hash, marker::PhantomData, ops::Deref}; use hugr::{ - ops::{Input, NamedOp, OpType, Output}, + ops::{DataflowBlock, ExitBlock, Input, NamedOp, OpType, Output, CFG}, types::Type, Hugr, HugrView, IncomingPort, Node, NodeIndex, OutgoingPort, }; +use itertools::Itertools as _; /// A Fat Node is a [Node] along with a reference to the [HugrView] whence it /// originates. It carries a type `OT`, the [OpType] of that node. `OT` may be @@ -44,7 +45,7 @@ where /// /// Note that while we do check that the type of the node's `get_optype`, we do /// not verify that it is actually equal to `ot`. - pub fn new(hugr: &'c H, node: Node, #[allow(unused)] ot: &'c OT) -> Self { + pub fn new(hugr: &'c H, node: Node, #[allow(unused)] ot: &OT) -> Self { assert!(hugr.valid_node(node)); assert!(TryInto::<&'c OT>::try_into(hugr.get_optype(node)).is_ok()); // We don't actually check `ot == hugr.get_optype(node)` so as to not require OT: PartialEq` @@ -106,9 +107,9 @@ impl<'c, H: HugrView + ?Sized> FatNode<'c, OpType, H> { // Creates a specific `FatNode` from a general `FatNode`. // // Panics if the node's `get_optype` is not `OT`. - pub fn into_ot(self, ot: &'c OT) -> FatNode<'c, OT, H> + pub fn into_ot(self, ot: &OT) -> FatNode<'c, OT, H> where - &'c OpType: TryInto<&'c OT>, + for<'a> &'a OpType: TryInto<&'a OT>, { FatNode::new(self.hugr, self.node, ot) } @@ -155,10 +156,20 @@ impl<'c, OT, H: HugrView + ?Sized> FatNode<'c, OT, H> { )) } + pub fn node_outputs(&self) -> impl Iterator + '_ { + self.hugr.node_outputs(self.node) + } + + pub fn output_neighbours(&self) -> impl Iterator> + '_ { + self.hugr + .output_neighbours(self.node) + .map(|n| FatNode::new_optype(self.hugr, n)) + } + /// Create a general `FatNode` from a specific one. pub fn generalise(self) -> FatNode<'c, OpType, H> where - &'c OpType: TryInto<&'c OT>, + for<'a> &'a OpType: TryInto<&'a OT>, OT: 'c, { // guaranteed to be valid becasue self is valid @@ -170,6 +181,25 @@ impl<'c, OT, H: HugrView + ?Sized> FatNode<'c, OT, H> { } } +impl<'c, H: HugrView> FatNode<'c, CFG, H> { + /// TODO it would be reasonable to remove Option and panic on failure here + pub fn get_entry_exit( + &self, + ) -> Option<(FatNode<'c, DataflowBlock, H>, FatNode<'c, ExitBlock, H>)> { + let [i, o] = self + .hugr + .children(self.node) + .take(2) + .collect_vec() + .try_into() + .ok()?; + Some(( + FatNode::try_new(self.hugr, i)?, + FatNode::try_new(self.hugr, o)?, + )) + } +} + impl<'c, OT, H> PartialEq for FatNode<'c, OT, H> { fn eq(&self, other: &Node) -> bool { &self.node == other @@ -220,11 +250,9 @@ impl<'c, OT, H> Hash for FatNode<'c, OT, H> { } } -impl<'c, OT, H: HugrView + ?Sized> AsRef for FatNode<'c, OT, H> +impl<'c, OT: 'c, H: HugrView + ?Sized> AsRef for FatNode<'c, OT, H> where - &'c OpType: TryInto<&'c OT>, - <&'c OpType as TryInto<&'c OT>>::Error: std::fmt::Debug, - OT: 'c, + for<'a> &'a OpType: TryInto<&'a OT>, { fn as_ref(&self) -> &OT { self.get() @@ -271,6 +299,12 @@ impl<'c, OT, H> NodeIndex for FatNode<'c, OT, H> { } } +impl<'c, OT, H> NodeIndex for &FatNode<'c, OT, H> { + fn index(self) -> usize { + self.node.index() + } +} + /// An extension trait for [HugrView] which provides methods that delegate to /// [HugrView] and then return the result in [FatNode] form. See for example /// [FatExt::fat_io]. diff --git a/src/types.rs b/src/types.rs index df2fe71..e2829c8 100644 --- a/src/types.rs +++ b/src/types.rs @@ -8,7 +8,7 @@ use hugr::types::SumType; use hugr::{types::TypeRow, HugrView}; use inkwell::builder::Builder; use inkwell::types::{self as iw, AnyType, AsTypeRef, IntType}; -use inkwell::values::{BasicValue, BasicValueEnum, StructValue}; +use inkwell::values::{BasicValue, BasicValueEnum, IntValue, StructValue}; use inkwell::{ context::Context, types::{BasicMetadataTypeEnum, BasicType, BasicTypeEnum, StructType}, @@ -52,13 +52,11 @@ impl<'c, H: HugrView> TypingSession<'c, H> { use hugr::types::TypeEnum; match hugr_type.as_type_enum() { TypeEnum::Extension(ref custom_type) => self.extensions.llvm_type(self, custom_type), - TypeEnum::Alias(ref alias) => Err(anyhow!("Invalid type: {:?}", alias)), - + TypeEnum::Sum(sum) => self.llvm_sum_type(sum.clone()).map(Into::into), // TODO Function Types are fine TypeEnum::Function(ref func_ty) => Err(anyhow!("Invalid type: {:?}", func_ty)), - x @ TypeEnum::Variable(_, _) => Err(anyhow!("Invalid type: {:?}", x)), - TypeEnum::Sum(sum) => self.llvm_sum_type(sum.clone()).map(Into::into), + x => Err(anyhow!("Invalid type: {:?}", x)), } } @@ -199,15 +197,17 @@ impl<'c> LLVMSumType<'c> { &self, builder: &Builder<'c>, v: impl BasicValue<'c>, - ) -> Result> { + ) -> Result> { let struct_value: StructValue<'c> = v .as_basic_value_enum() .try_into() .map_err(|_| anyhow!("Not a struct type"))?; if self.has_tag_field() { - Ok(builder.build_extract_value(struct_value, 0, "")?) + Ok(builder + .build_extract_value(struct_value, 0, "")? + .into_int_value()) } else { - Ok(self.get_tag_type().const_int(0, false).into()) + Ok(self.get_tag_type().const_int(0, false)) } }