@@ -586,11 +586,25 @@ fn check_all_mask_eq<'c, H: HugrView<Node = Node>>(
586
586
expected : bool ,
587
587
err : & ConstError ,
588
588
) -> Result < ( ) > {
589
- build_mask_padding ( ctx, mask_info, expected) ?;
589
+ let end_idx = ctx
590
+ . builder ( )
591
+ . build_int_add ( mask_info. offset , mask_info. size , "" ) ?;
592
+ build_mask_padding1 (
593
+ ctx,
594
+ mask_info. mask_ptr ,
595
+ mask_info. offset ,
596
+ expected,
597
+ PaddingDirection :: First ,
598
+ ) ?;
599
+ build_mask_padding1 (
600
+ ctx,
601
+ mask_info. mask_ptr ,
602
+ end_idx,
603
+ expected,
604
+ PaddingDirection :: Last ,
605
+ ) ?;
590
606
591
607
let builder = ctx. builder ( ) ;
592
- let end_idx = builder. build_int_add ( mask_info. offset , mask_info. size , "" ) ?;
593
-
594
608
let usize_t = usize_ty ( & ctx. typing_session ( ) ) ;
595
609
let expected_val = if expected {
596
610
usize_t. const_all_ones ( )
@@ -624,58 +638,49 @@ fn check_all_mask_eq<'c, H: HugrView<Node = Node>>(
624
638
} )
625
639
}
626
640
627
- /// Emits instructions to update the mask, overwriting unused bits with a value.
628
- fn build_mask_padding < ' c , H : HugrView < Node = Node > > (
641
+ #[ derive( Copy , Clone , Debug ) ]
642
+ enum PaddingDirection {
643
+ First ,
644
+ Last ,
645
+ }
646
+
647
+ /// Emits instructions to destructively update the first or last block of the mask,
648
+ /// given `idx` the first or last element within the mask, overwriting unused bits with `value`.
649
+ fn build_mask_padding1 < ' c , H : HugrView < Node = Node > > (
629
650
ctx : & mut EmitFuncContext < ' c , ' _ , H > ,
630
- info : & MaskInfo < ' c > ,
651
+ mask_ptr : PointerValue < ' c > ,
652
+ idx : IntValue < ' c > ,
631
653
value : bool ,
654
+ direction : PaddingDirection ,
632
655
) -> Result < ( ) > {
633
- let MaskInfo {
634
- mask_ptr,
635
- offset,
636
- size,
637
- } = * info;
638
656
let builder = ctx. builder ( ) ;
639
657
let usize_t = usize_ty ( & ctx. typing_session ( ) ) ;
640
658
let block_size = usize_t. const_int ( usize_t. get_bit_width ( ) as u64 , false ) ;
641
659
642
- // Find the first and last blocks that contain some used bits
643
- let lst_idx = builder. build_int_add ( offset, size, "" ) ?;
644
- let fst_block_idx = builder. build_int_unsigned_div ( offset, block_size, "" ) ?;
645
- let lst_block_idx = builder. build_int_unsigned_div ( lst_idx, block_size, "" ) ?;
646
- let fst_block_addr = unsafe { builder. build_in_bounds_gep ( mask_ptr, & [ fst_block_idx] , "" ) ? } ;
647
- let lst_block_addr = unsafe { builder. build_in_bounds_gep ( mask_ptr, & [ lst_block_idx] , "" ) ? } ;
648
- let fst_block = builder. build_load ( fst_block_addr, "" ) ?. into_int_value ( ) ;
649
- let lst_block = builder. build_load ( lst_block_addr, "" ) ?. into_int_value ( ) ;
660
+ let block_idx = builder. build_int_unsigned_div ( idx, block_size, "" ) ?;
661
+ let block_addr = unsafe { builder. build_in_bounds_gep ( mask_ptr, & [ block_idx] , "" ) ? } ;
662
+ let block = builder. build_load ( block_addr, "" ) ?. into_int_value ( ) ;
650
663
651
- // Pad out the unused bits in the first block
652
- let ones = usize_t. const_all_ones ( ) ;
653
- let fst_block_unused = builder. build_int_unsigned_rem ( offset, block_size, "" ) ?;
654
- let fst_block_used = builder. build_int_sub ( block_size, fst_block_unused, "" ) ?;
655
- let new_fst_block = if value {
656
- // Pad with ones
657
- let pad = builder. build_right_shift ( ones, fst_block_used, false , "" ) ?;
658
- builder. build_or ( fst_block, pad, "" ) ?
659
- } else {
660
- // Pad with zeros
661
- let pad = builder. build_left_shift ( ones, fst_block_unused, "" ) ?;
662
- builder. build_and ( fst_block, pad, "" ) ?
663
- } ;
664
- builder. build_store ( fst_block_addr, new_fst_block) ?;
664
+ let idx_in_block = builder. build_int_unsigned_rem ( idx, block_size, "" ) ?;
665
+ let idx_from_end = builder. build_int_sub ( block_size, idx_in_block, "" ) ?;
665
666
666
- // Pad out the unused bits in the last block
667
- let lst_block_used = builder. build_int_unsigned_rem ( lst_idx, block_size, "" ) ?;
668
- let lst_block_unused = builder. build_int_sub ( block_size, lst_block_used, "" ) ?;
669
- let new_lst_block = if value {
667
+ let ones = usize_t. const_all_ones ( ) ;
668
+ let new_block = if value {
670
669
// Pad with ones
671
- let pad = builder. build_left_shift ( ones, lst_block_used, "" ) ?;
672
- builder. build_or ( lst_block, pad, "" ) ?
670
+ let pad = match direction {
671
+ PaddingDirection :: First => builder. build_right_shift ( ones, idx_from_end, false , "" ) ?,
672
+ PaddingDirection :: Last => builder. build_left_shift ( ones, idx_in_block, "" ) ?,
673
+ } ;
674
+ builder. build_or ( block, pad, "" ) ?
673
675
} else {
674
676
// Pad with zeros
675
- let pad = builder. build_right_shift ( ones, lst_block_unused, false , "" ) ?;
676
- builder. build_and ( lst_block, pad, "" ) ?
677
+ let pad = match direction {
678
+ PaddingDirection :: First => builder. build_left_shift ( ones, idx_in_block, "" ) ?,
679
+ PaddingDirection :: Last => builder. build_right_shift ( ones, idx_from_end, false , "" ) ?,
680
+ } ;
681
+ builder. build_and ( block, pad, "" ) ?
677
682
} ;
678
- builder. build_store ( lst_block_addr , new_lst_block ) ?;
683
+ builder. build_store ( block_addr , new_block ) ?;
679
684
Ok ( ( ) )
680
685
}
681
686
0 commit comments