Skip to content

Commit 2d069c9

Browse files
committed
Break build_mask_padding up into two calls to build_mask_padding1
1 parent cab33fa commit 2d069c9

File tree

2 files changed

+87
-83
lines changed

2 files changed

+87
-83
lines changed

hugr-llvm/src/extension/collections/borrow_array.rs

Lines changed: 47 additions & 42 deletions
Original file line number | Diff line number | Diff line change
@@ -586,11 +586,25 @@ fn check_all_mask_eq<'c, H: HugrView<Node = Node>>(
586586
expected: bool,
587587
err: &ConstError,
588588
) -> Result<()> {
589-
build_mask_padding(ctx, mask_info, expected)?;
589+
let end_idx = ctx
590+
.builder()
591+
.build_int_add(mask_info.offset, mask_info.size, "")?;
592+
build_mask_padding1(
593+
ctx,
594+
mask_info.mask_ptr,
595+
mask_info.offset,
596+
expected,
597+
PaddingDirection::First,
598+
)?;
599+
build_mask_padding1(
600+
ctx,
601+
mask_info.mask_ptr,
602+
end_idx,
603+
expected,
604+
PaddingDirection::Last,
605+
)?;
590606

591607
let builder = ctx.builder();
592-
let end_idx = builder.build_int_add(mask_info.offset, mask_info.size, "")?;
593-
594608
let usize_t = usize_ty(&ctx.typing_session());
595609
let expected_val = if expected {
596610
usize_t.const_all_ones()
@@ -624,58 +638,49 @@ fn check_all_mask_eq<'c, H: HugrView<Node = Node>>(
624638
})
625639
}
626640

627-
/// Emits instructions to update the mask, overwriting unused bits with a value.
628-
fn build_mask_padding<'c, H: HugrView<Node = Node>>(
641+
#[derive(Copy, Clone, Debug)]
642+
enum PaddingDirection {
643+
First,
644+
Last,
645+
}
646+
647+
/// Emits instructions to destructively update the first or last block of the mask,
648+
/// given `idx` the first or last element within the mask, overwriting unused bits with `value`.
649+
fn build_mask_padding1<'c, H: HugrView<Node = Node>>(
629650
ctx: &mut EmitFuncContext<'c, '_, H>,
630-
info: &MaskInfo<'c>,
651+
mask_ptr: PointerValue<'c>,
652+
idx: IntValue<'c>,
631653
value: bool,
654+
direction: PaddingDirection,
632655
) -> Result<()> {
633-
let MaskInfo {
634-
mask_ptr,
635-
offset,
636-
size,
637-
} = *info;
638656
let builder = ctx.builder();
639657
let usize_t = usize_ty(&ctx.typing_session());
640658
let block_size = usize_t.const_int(usize_t.get_bit_width() as u64, false);
641659

642-
// Find the first and last blocks that contain some used bits
643-
let lst_idx = builder.build_int_add(offset, size, "")?;
644-
let fst_block_idx = builder.build_int_unsigned_div(offset, block_size, "")?;
645-
let lst_block_idx = builder.build_int_unsigned_div(lst_idx, block_size, "")?;
646-
let fst_block_addr = unsafe { builder.build_in_bounds_gep(mask_ptr, &[fst_block_idx], "")? };
647-
let lst_block_addr = unsafe { builder.build_in_bounds_gep(mask_ptr, &[lst_block_idx], "")? };
648-
let fst_block = builder.build_load(fst_block_addr, "")?.into_int_value();
649-
let lst_block = builder.build_load(lst_block_addr, "")?.into_int_value();
660+
let block_idx = builder.build_int_unsigned_div(idx, block_size, "")?;
661+
let block_addr = unsafe { builder.build_in_bounds_gep(mask_ptr, &[block_idx], "")? };
662+
let block = builder.build_load(block_addr, "")?.into_int_value();
650663

651-
// Pad out the unused bits in the first block
652-
let ones = usize_t.const_all_ones();
653-
let fst_block_unused = builder.build_int_unsigned_rem(offset, block_size, "")?;
654-
let fst_block_used = builder.build_int_sub(block_size, fst_block_unused, "")?;
655-
let new_fst_block = if value {
656-
// Pad with ones
657-
let pad = builder.build_right_shift(ones, fst_block_used, false, "")?;
658-
builder.build_or(fst_block, pad, "")?
659-
} else {
660-
// Pad with zeros
661-
let pad = builder.build_left_shift(ones, fst_block_unused, "")?;
662-
builder.build_and(fst_block, pad, "")?
663-
};
664-
builder.build_store(fst_block_addr, new_fst_block)?;
664+
let idx_in_block = builder.build_int_unsigned_rem(idx, block_size, "")?;
665+
let idx_from_end = builder.build_int_sub(block_size, idx_in_block, "")?;
665666

666-
// Pad out the unused bits in the last block
667-
let lst_block_used = builder.build_int_unsigned_rem(lst_idx, block_size, "")?;
668-
let lst_block_unused = builder.build_int_sub(block_size, lst_block_used, "")?;
669-
let new_lst_block = if value {
667+
let ones = usize_t.const_all_ones();
668+
let new_block = if value {
670669
// Pad with ones
671-
let pad = builder.build_left_shift(ones, lst_block_used, "")?;
672-
builder.build_or(lst_block, pad, "")?
670+
let pad = match direction {
671+
PaddingDirection::First => builder.build_right_shift(ones, idx_from_end, false, "")?,
672+
PaddingDirection::Last => builder.build_left_shift(ones, idx_in_block, "")?,
673+
};
674+
builder.build_or(block, pad, "")?
673675
} else {
674676
// Pad with zeros
675-
let pad = builder.build_right_shift(ones, lst_block_unused, false, "")?;
676-
builder.build_and(lst_block, pad, "")?
677+
let pad = match direction {
678+
PaddingDirection::First => builder.build_left_shift(ones, idx_in_block, "")?,
679+
PaddingDirection::Last => builder.build_right_shift(ones, idx_from_end, false, "")?,
680+
};
681+
builder.build_and(block, pad, "")?
677682
};
678-
builder.build_store(lst_block_addr, new_lst_block)?;
683+
builder.build_store(block_addr, new_block)?;
679684
Ok(())
680685
}
681686

hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__borrow_array__test__emit_clone@pre-mem2reg@llvm14.snap

Lines changed: 40 additions & 41 deletions
Original file line number | Diff line number | Diff line change
@@ -71,55 +71,54 @@ declare void @llvm.memset.p0i64.i64(i64* nocapture writeonly, i8, i64, i1 immarg
7171
define internal void @__barray_check_none_borrowed(i64* %0, i64 %1, i64 %2) {
7272
%4 = add i64 %1, %2
7373
%5 = udiv i64 %1, 64
74-
%6 = udiv i64 %4, 64
75-
%7 = getelementptr inbounds i64, i64* %0, i64 %5
76-
%8 = getelementptr inbounds i64, i64* %0, i64 %6
77-
%9 = load i64, i64* %7, align 4
78-
%10 = load i64, i64* %8, align 4
79-
%11 = urem i64 %1, 64
80-
%12 = sub i64 64, %11
81-
%13 = shl i64 -1, %11
82-
%14 = and i64 %9, %13
83-
store i64 %14, i64* %7, align 4
74+
%6 = getelementptr inbounds i64, i64* %0, i64 %5
75+
%7 = load i64, i64* %6, align 4
76+
%8 = urem i64 %1, 64
77+
%9 = sub i64 64, %8
78+
%10 = shl i64 -1, %8
79+
%11 = and i64 %7, %10
80+
store i64 %11, i64* %6, align 4
81+
%12 = udiv i64 %4, 64
82+
%13 = getelementptr inbounds i64, i64* %0, i64 %12
83+
%14 = load i64, i64* %13, align 4
8484
%15 = urem i64 %4, 64
8585
%16 = sub i64 64, %15
8686
%17 = lshr i64 -1, %16
87-
%18 = and i64 %10, %17
88-
store i64 %18, i64* %8, align 4
89-
%19 = add i64 %1, %2
90-
%20 = udiv i64 %1, 64
91-
%21 = udiv i64 %19, 64
92-
%22 = sub i64 %21, %20
93-
%23 = add i64 %22, 1
94-
%24 = alloca i64, align 8
95-
store i64 0, i64* %24, align 4
96-
br label %25
97-
98-
25: ; preds = %mask_block_ok, %3
99-
%26 = load i64, i64* %24, align 4
100-
%27 = icmp ult i64 %26, %23
101-
br i1 %27, label %28, label %34
102-
103-
28: ; preds = %25
104-
%29 = load i64, i64* %24, align 4
105-
%30 = add i64 %29, %20
106-
%31 = getelementptr inbounds i64, i64* %0, i64 %30
107-
%32 = load i64, i64* %31, align 4
108-
%33 = icmp eq i64 %32, 0
109-
br i1 %33, label %mask_block_ok, label %mask_block_err
110-
111-
34: ; preds = %25
87+
%18 = and i64 %14, %17
88+
store i64 %18, i64* %13, align 4
89+
%19 = udiv i64 %1, 64
90+
%20 = udiv i64 %4, 64
91+
%21 = sub i64 %20, %19
92+
%22 = add i64 %21, 1
93+
%23 = alloca i64, align 8
94+
store i64 0, i64* %23, align 4
95+
br label %24
96+
97+
24: ; preds = %mask_block_ok, %3
98+
%25 = load i64, i64* %23, align 4
99+
%26 = icmp ult i64 %25, %22
100+
br i1 %26, label %27, label %33
101+
102+
27: ; preds = %24
103+
%28 = load i64, i64* %23, align 4
104+
%29 = add i64 %28, %19
105+
%30 = getelementptr inbounds i64, i64* %0, i64 %29
106+
%31 = load i64, i64* %30, align 4
107+
%32 = icmp eq i64 %31, 0
108+
br i1 %32, label %mask_block_ok, label %mask_block_err
109+
110+
33: ; preds = %24
112111
ret void
113112

114-
mask_block_err: ; preds = %28
115-
%35 = call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([34 x i8], [34 x i8]* @prelude.panic_template, i32 0, i32 0), i32 2, i8* getelementptr inbounds ([39 x i8], [39 x i8]* @0, i32 0, i32 0))
113+
mask_block_err: ; preds = %27
114+
%34 = call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([34 x i8], [34 x i8]* @prelude.panic_template, i32 0, i32 0), i32 2, i8* getelementptr inbounds ([39 x i8], [39 x i8]* @0, i32 0, i32 0))
116115
call void @abort()
117116
unreachable
118117

119-
mask_block_ok: ; preds = %28
120-
%36 = add i64 %29, 1
121-
store i64 %36, i64* %24, align 4
122-
br label %25
118+
mask_block_ok: ; preds = %27
119+
%35 = add i64 %28, 1
120+
store i64 %35, i64* %23, align 4
121+
br label %24
123122
}
124123

125124
declare i32 @printf(i8*, ...)

0 commit comments

Comments
 (0)