Skip to content

Commit 439f09c

Browse files
ss2165acl-cqc
authored andcommitted
refactor: allow non-void get_or_make_function (#2609)
1 parent 56eee57 commit 439f09c

File tree

2 files changed

+51
-21
lines changed

2 files changed

+51
-21
lines changed

hugr-llvm/src/emit/func.rs

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -357,23 +357,32 @@ pub fn build_ok_or_else<'c, H: HugrView<Node = Node>>(
357357
let either = builder.build_select(is_ok, right, left, "")?;
358358
Ok(either)
359359
}
360-
361360
/// Helper to outline LLVM IR into a function call instead of inlining it every time.
362361
///
363362
/// The first time this helper is called with a given function name, a function is built
364363
/// using the provided closure. Future invocations with the same name will just emit calls
365364
/// to this function.
365+
///
366+
/// The return type is specified by `ret_type`, and the closure must return a value of that type.
367+
/// If `ret_type` is `None`, the function is assumed to return void.
366368
pub fn get_or_make_function<'c, H: HugrView<Node = Node>, const N: usize>(
367369
ctx: &mut EmitFuncContext<'c, '_, H>,
368370
func_name: &str,
369371
args: [BasicValueEnum<'c>; N],
370-
go: impl FnOnce(&mut EmitFuncContext<'c, '_, H>, [BasicValueEnum<'c>; N]) -> Result<()>,
371-
) -> Result<()> {
372+
ret_type: Option<BasicTypeEnum<'c>>,
373+
go: impl FnOnce(
374+
&mut EmitFuncContext<'c, '_, H>,
375+
[BasicValueEnum<'c>; N],
376+
) -> Result<Option<BasicValueEnum<'c>>>,
377+
) -> Result<Option<BasicValueEnum<'c>>> {
372378
let func = match ctx.get_current_module().get_function(func_name) {
373379
Some(func) => func,
374380
None => {
375381
let arg_tys = args.iter().map(|v| v.get_type().into()).collect_vec();
376-
let sig = ctx.iw_context().void_type().fn_type(&arg_tys, false);
382+
let sig = match ret_type {
383+
Some(ret_ty) => ret_ty.fn_type(&arg_tys, false),
384+
None => ctx.iw_context().void_type().fn_type(&arg_tys, false),
385+
};
377386
let func =
378387
ctx.get_current_module()
379388
.add_function(func_name, sig, Some(Linkage::Internal));
@@ -388,25 +397,30 @@ pub fn get_or_make_function<'c, H: HugrView<Node = Node>, const N: usize>(
388397

389398
ctx.builder().position_at_end(bb);
390399
ctx.func = func;
391-
go(ctx, args)?;
400+
let ret_val = go(ctx, args)?;
392401
if ctx
393402
.builder()
394403
.get_insert_block()
395404
.unwrap()
396405
.get_terminator()
397406
.is_none()
398407
{
399-
ctx.builder().build_return(None)?;
408+
match ret_val {
409+
Some(ref v) => ctx.builder().build_return(Some(v))?,
410+
None => ctx.builder().build_return(None)?,
411+
};
400412
}
401413

402414
ctx.builder().position_at_end(curr_bb);
403415
ctx.func = curr_func;
404416
func
405417
}
406418
};
407-
ctx.builder()
408-
.build_call(func, &args.iter().map(|&a| a.into()).collect_vec(), "")?;
409-
Ok(())
419+
let call_site =
420+
ctx.builder()
421+
.build_call(func, &args.iter().map(|&a| a.into()).collect_vec(), "")?;
422+
let result = call_site.try_as_basic_value().left();
423+
Ok(result)
410424
}
411425

412426
#[cfg(test)]

hugr-llvm/src/extension/collections/borrow_array.rs

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -679,6 +679,7 @@ fn build_mask_flip<'c, H: HugrView<Node = Node>>(
679679
ctx,
680680
FUNC_NAME,
681681
[mask_ptr.into(), idx.into()],
682+
None,
682683
|ctx, [mask_ptr, idx]| {
683684
let mask_ptr = mask_ptr.into_pointer_value();
684685
let idx = idx.into_int_value();
@@ -694,9 +695,10 @@ fn build_mask_flip<'c, H: HugrView<Node = Node>>(
694695
let update = builder.build_left_shift(usize_t.const_int(1, false), block_offset, "")?;
695696
let block = builder.build_xor(block, update, "")?;
696697
builder.build_store(block_addr, block)?;
697-
Ok(())
698+
Ok(None)
698699
},
699-
)
700+
)?;
701+
Ok(())
700702
}
701703

702704
/// Emits a check that a specific array element has not already been borrowed.
@@ -712,6 +714,7 @@ pub fn build_idx_not_borrowed_check<'c, H: HugrView<Node = Node>>(
712714
ctx,
713715
FUNC_NAME,
714716
[mask_ptr.into(), idx.into()],
717+
None,
715718
|ctx, [mask_ptr, idx]| {
716719
// Emit panic if borrow-bit is set
717720
inspect_mask_idx(
@@ -726,9 +729,11 @@ pub fn build_idx_not_borrowed_check<'c, H: HugrView<Node = Node>>(
726729
ctx.builder().build_unreachable()?;
727730
Ok(())
728731
},
729-
)
732+
)?;
733+
Ok(None)
730734
},
731-
)
735+
)?;
736+
Ok(())
732737
}
733738

734739
/// Emits a check that a specific array index is free.
@@ -744,6 +749,7 @@ pub fn build_idx_free_check<'c, H: HugrView<Node = Node>>(
744749
ctx,
745750
FUNC_NAME,
746751
[mask_ptr.into(), idx.into()],
752+
None,
747753
|ctx, [mask_ptr, idx]| {
748754
// Emit panic if borrow-bit is not set
749755
inspect_mask_idx(
@@ -758,9 +764,11 @@ pub fn build_idx_free_check<'c, H: HugrView<Node = Node>>(
758764
Ok(())
759765
},
760766
|_| Ok(()),
761-
)
767+
)?;
768+
Ok(None)
762769
},
763-
)
770+
)?;
771+
Ok(())
764772
}
765773

766774
/// Emits a check that no array elements have been borrowed.
@@ -782,6 +790,7 @@ pub fn build_none_borrowed_check<'c, H: HugrView<Node = Node>>(
782790
ctx,
783791
FUNC_NAME,
784792
[mask_ptr.into(), offset.into(), size.into()],
793+
None,
785794
|ctx, [mask_ptr, offset, size]| {
786795
let mask_ptr = mask_ptr.into_pointer_value();
787796
let offset = offset.into_int_value();
@@ -806,9 +815,11 @@ pub fn build_none_borrowed_check<'c, H: HugrView<Node = Node>>(
806815
builder.build_conditional_branch(cond, none_borrowed_bb, some_borrowed_bb)?;
807816
builder.position_at_end(none_borrowed_bb);
808817
Ok(())
809-
})
818+
})?;
819+
Ok(None)
810820
},
811-
)
821+
)?;
822+
Ok(())
812823
}
813824

814825
/// Emits a check that all array elements have been borrowed.
@@ -830,6 +841,7 @@ pub fn build_all_borrowed_check<'c, H: HugrView<Node = Node>>(
830841
ctx,
831842
FUNC_NAME,
832843
[mask_ptr.into(), offset.into(), size.into()],
844+
None,
833845
|ctx, [mask_ptr, offset, size]| {
834846
let mask_ptr = mask_ptr.into_pointer_value();
835847
let offset = offset.into_int_value();
@@ -858,9 +870,11 @@ pub fn build_all_borrowed_check<'c, H: HugrView<Node = Node>>(
858870
builder.build_conditional_branch(cond, all_borrowed_bb, some_not_borrowed)?;
859871
builder.position_at_end(all_borrowed_bb);
860872
Ok(())
861-
})
873+
})?;
874+
Ok(None)
862875
},
863-
)
876+
)?;
877+
Ok(())
864878
}
865879

866880
/// Emits a check that a specific array element has not already been borrowed.
@@ -877,6 +891,7 @@ pub fn build_bounds_check<'c, H: HugrView<Node = Node>>(
877891
ctx,
878892
FUNC_NAME,
879893
[size.into(), idx.into()],
894+
None,
880895
|ctx, [size, idx]| {
881896
let size = size.into_int_value();
882897
let idx = idx.into_int_value();
@@ -894,9 +909,10 @@ pub fn build_bounds_check<'c, H: HugrView<Node = Node>>(
894909
ctx.builder()
895910
.build_conditional_branch(in_bounds, ok_bb, err_bb)?;
896911
ctx.builder().position_at_end(ok_bb);
897-
Ok(())
912+
Ok(None)
898913
},
899-
)
914+
)?;
915+
Ok(())
900916
}
901917

902918
/// Helper function to build a loop that repeats for a given number of iterations.

0 commit comments

Comments
 (0)