Skip to content

Commit

Permalink
fmt
Browse files Browse the repository at this point in the history
  • Loading branch information
doug-q committed Dec 11, 2024
1 parent 275534b commit 326e043
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 27 deletions.
11 changes: 4 additions & 7 deletions hugr-llvm/src/emit/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ use hugr_core::{
types::{SumType, Type, TypeEnum},
HugrView, NodeIndex,
};
use inkwell::types::{BasicTypeEnum, IntType};
use inkwell::values::{BasicValueEnum, CallableValue, IntValue};
use inkwell::types::BasicTypeEnum;
use inkwell::values::{BasicValueEnum, CallableValue};
use itertools::{zip_eq, Itertools};
use petgraph::visit::Walker;

Expand All @@ -21,7 +21,7 @@ use crate::{

use super::{
deaggregate_call_result,
func::{EmitFuncContext, RowMailBox, RowPromise},
func::{EmitFuncContext, RowPromise},
EmitOpArgs,
};

Expand All @@ -31,7 +31,6 @@ struct DataflowParentEmitter<'c, 'hugr, OT, H> {
node: FatNode<'hugr, OT, H>,
inputs: Option<Vec<BasicValueEnum<'c>>>,
outputs: Option<RowPromise<'c>>,
output_vals: Option<Vec<BasicValueEnum<'c>>>,
}

impl<'c, 'hugr, OT: OpTrait, H: HugrView> DataflowParentEmitter<'c, 'hugr, OT, H>
Expand All @@ -43,7 +42,6 @@ where
node: args.node,
inputs: Some(args.inputs),
outputs: Some(args.outputs),
output_vals: None,
}
}

Expand Down Expand Up @@ -310,13 +308,12 @@ fn emit_cfg<'c, H: HugrView>(

fn emit_tail_loop<'c, H: HugrView>(
context: &mut EmitFuncContext<'c, '_, H>,
args: EmitOpArgs<'c, '_, TailLoop, H>
args: EmitOpArgs<'c, '_, TailLoop, H>,
) -> Result<()> {
// TODO: Switch on the tag in loop_body to see where to go next
// TODO: Handle "other" args
let node = args.node();


// Make a block to jump to when we `Break`
let out_bb = context.new_basic_block("loop_out", None);
// A block for the body of the loop
Expand Down
79 changes: 59 additions & 20 deletions hugr-llvm/src/emit/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@ use anyhow::{anyhow, Result};
use hugr_core::builder::{
BuildHandle, Container, DFGWrapper, HugrBuilder, ModuleBuilder, SubContainer,
};
use hugr_core::extension::prelude::{BOOL_T, PRELUDE, PRELUDE_ID, USIZE_T};
use hugr_core::extension::prelude::PRELUDE_ID;
use hugr_core::extension::{ExtensionRegistry, ExtensionSet, EMPTY_REG};
use hugr_core::ops::handle::FuncID;
use hugr_core::std_extensions::arithmetic::{
conversions, float_ops, float_types, int_ops, int_types,
};
use hugr_core::std_extensions::{collections, logic};
use hugr_core::types::{SumType, TypeRow};
use hugr_core::types::TypeRow;
use hugr_core::{Hugr, HugrView};
use inkwell::module::Module;
use inkwell::passes::PassManager;
Expand Down Expand Up @@ -246,7 +246,7 @@ mod test_fns {
use super::*;
use crate::custom::CodegenExtsBuilder;
use crate::extension::int::add_int_extensions;
use crate::types::HugrFuncType;
use crate::types::{HugrFuncType, HugrSumType};

use hugr_core::builder::DataflowSubContainer;
use hugr_core::builder::{Container, Dataflow, HugrBuilder, ModuleBuilder, SubContainer};
Expand Down Expand Up @@ -557,20 +557,24 @@ mod test_fns {
let hugr = {
let just_input = USIZE_T;
let just_output = Type::UNIT;
let sum_ty = SumType::new(vec![just_input.clone(), just_output.clone()]);
let input_v = TypeRow::from(vec![just_input.clone()]);
let output_v = TypeRow::from(vec![just_output.clone()]);

llvm_ctx.add_extensions(CodegenExtsBuilder::add_default_prelude_extensions);


SimpleHugrConfig::new()
.with_extensions(PRELUDE_REGISTRY.clone())
.with_ins(input_v)
.with_outs(output_v)
.finish(|mut builder: DFGW| {
let [just_in_w] = builder.input_wires_arr();
let mut tail_b = builder.tail_loop_builder([(just_input.clone(), just_in_w)], [], vec![just_output.clone()].into()).unwrap();
let mut tail_b = builder
.tail_loop_builder(
[(just_input.clone(), just_in_w)],
[],
vec![just_output.clone()].into(),
)
.unwrap();

let input = tail_b.input();
let [inp_w] = input.outputs_arr();
Expand All @@ -581,12 +585,14 @@ mod test_fns {

let sum_inp_w = tail_b.make_continue(loop_sig.clone(), [inp_w]).unwrap();

let outs@[_] = tail_b.finish_with_outputs(sum_inp_w, []).unwrap().outputs_arr();
let outs @ [_] = tail_b
.finish_with_outputs(sum_inp_w, [])
.unwrap()
.outputs_arr();
builder.finish_with_outputs(outs).unwrap()
})
};
check_emission!(hugr, llvm_ctx);

}

#[rstest]
Expand All @@ -597,7 +603,7 @@ mod test_fns {
let just_input = int_ty.clone();
let just_output = Type::UNIT;
let other_ty = int_ty.clone();
let sum_ty = SumType::new(vec![just_input.clone(), just_output.clone()]);
let sum_ty = HugrSumType::new(vec![just_input.clone(), just_output.clone()]);
let input_v = TypeRow::from(vec![just_input.clone(), other_ty.clone()]);
let output_v = TypeRow::from(vec![just_output.clone(), other_ty.clone()]);

Expand All @@ -613,12 +619,24 @@ mod test_fns {
.with_outs(output_v)
.finish(|mut builder: DFGW| {
let [just_in_w, other_w] = builder.input_wires_arr();
let mut tail_b = builder.tail_loop_builder([(just_input.clone(), just_in_w)], [(other_ty.clone(), other_w)], vec![just_output.clone()].into()).unwrap();
let [sum_inp_w, other_w] = tail_b.input_wires_arr();
let mut tail_b = builder
.tail_loop_builder(
[(just_input.clone(), just_in_w)],
[(other_ty.clone(), other_w)],
vec![just_output.clone()].into(),
)
.unwrap();
let [sum_inp_w, _other_w] = tail_b.input_wires_arr();

let zero = ConstInt::new_u(6, 0).unwrap();
let zero_w = tail_b.add_load_value(zero);
let [result] = tail_b.add_dataflow_op(int_ops::IntOpDef::ile_u.with_log_width(6), [sum_inp_w, zero_w]).unwrap().outputs_arr();
let [_result] = tail_b
.add_dataflow_op(
int_ops::IntOpDef::ile_u.with_log_width(6),
[sum_inp_w, zero_w],
)
.unwrap()
.outputs_arr();
let input = tail_b.input();
let [inp_w, other_w] = input.outputs_arr();

Expand All @@ -643,31 +661,52 @@ mod test_fns {

let two = ConstInt::new_u(6, 2).unwrap();
let two_w = false_case_b.add_load_value(two);
let [val] = false_case_b.add_dataflow_op(int_ops::IntOpDef::imul.with_log_width(6), [val, two_w]).unwrap().outputs_arr();
let [val] = false_case_b
.add_dataflow_op(
int_ops::IntOpDef::imul.with_log_width(6),
[val, two_w],
)
.unwrap()
.outputs_arr();

let [counter] = false_case_b.add_dataflow_op(int_ops::IntOpDef::isub.with_log_width(6), [counter, one_w]).unwrap().outputs_arr();
let tagged_counter = false_case_b.make_continue(loop_sig.clone(), [counter]).unwrap();
let [counter] = false_case_b
.add_dataflow_op(
int_ops::IntOpDef::isub.with_log_width(6),
[counter, one_w],
)
.unwrap()
.outputs_arr();
let tagged_counter = false_case_b
.make_continue(loop_sig.clone(), [counter])
.unwrap();

false_case_b.finish_with_outputs([tagged_counter, val]).unwrap();
false_case_b
.finish_with_outputs([tagged_counter, val])
.unwrap();

// In the true case, we break and output true along with the "other" input wire
let mut true_case_b = cond_b.case_builder(1).unwrap();

let [_, val_w] = true_case_b.input_wires_arr();
let unit_v = Value::unit_sum(0, 1).unwrap();
let unit_w = true_case_b.add_load_value(unit_v);
let tagged_output = true_case_b.make_break(loop_sig.clone(), [unit_w]).unwrap();
true_case_b.finish_with_outputs([tagged_output, val_w]).unwrap();
let tagged_output =
true_case_b.make_break(loop_sig.clone(), [unit_w]).unwrap();
true_case_b
.finish_with_outputs([tagged_output, val_w])
.unwrap();

cond_b.finish_sub_container().unwrap()
};
let [sum, rest] = cond.outputs_arr();
let outs@[_,_] = tail_b.finish_with_outputs(sum, [rest]).unwrap().outputs_arr();
let outs @ [_, _] = tail_b
.finish_with_outputs(sum, [rest])
.unwrap()
.outputs_arr();
builder.finish_with_outputs(outs).unwrap()
})
};
check_emission!(hugr, llvm_ctx);

}

#[rstest]
Expand Down

0 comments on commit 326e043

Please sign in to comment.