Skip to content

Commit

Permalink
Add arith-expand pass to lower ceildiv, floordiv ops (#19200)
Browse files Browse the repository at this point in the history
This PR adds the arith expand pass that decomposes ops like ceildiv and
floordiv into primitive arith ops that can be lowered to LLVM.

---------

Signed-off-by: Harsh Menon <[email protected]>
  • Loading branch information
harsh-nod authored Nov 20, 2024
1 parent 26ef79a commit 8fd3e0d
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ struct ConvertToNVVMPass final
populateFinalizeMemRefToLLVMConversionPatterns(converter, llvmPatterns);
populateFuncToLLVMConversionPatterns(converter, llvmPatterns);
cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns);
arith::populateCeilFloorDivExpandOpsPatterns(llvmPatterns);
arith::populateArithToLLVMConversionPatterns(converter, llvmPatterns);
populateVectorToLLVMConversionPatterns(converter, llvmPatterns);
populateGpuToNVVMConversionPatterns(converter, llvmPatterns);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ struct ConvertToROCDLPass final
arith::populateArithToAMDGPUConversionPatterns(
patterns, /*convertFP8Arithmetic=*/true, /*saturateFP8Truncf=*/false,
/*allowPackedF16Rtz=*/false, /*chipset=*/*maybeChipset);
arith::populateCeilFloorDivExpandOpsPatterns(patterns);
populateConvertGPUToAMDGPUPatterns(patterns);
populateConvertSharedMemoryAllocOps(patterns);
populateDropSharedMemoryDeallocOpPatterns(patterns);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,51 @@ hal.executable @ext_fp8_dispatch {
// CDNA3-COUNT-4: rocdl.cvt.f32.bf8 %{{.*}} : f32
// CDNA3: %[[ADD:.+]] = llvm.fadd %{{.*}}, %{{.*}} : vector<4xf32>
// CDNA3: llvm.store %[[ADD]], %{{.*}} : vector<4xf32>, !llvm.ptr<1>

// -----

// Verify that the ceildivsi op gets expanded and lowered successfully all the way to
// the llvm dialect.

#pipeline_layout = #hal.pipeline.layout<bindings = [
#hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>
]>
hal.executable @ceildiv_expand_dispatch {
hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) {
hal.executable.export @ceildiv_expand layout(#pipeline_layout) {
^bb0(%arg0: !hal.device, %arg1: index):
%x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1
hal.return %x, %y, %z : index, index, index
}
builtin.module {
func.func @ceildiv_expand() {
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) : !flow.dispatch.tensor<readonly:tensor<16xi32>>
%1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) : !flow.dispatch.tensor<readonly:tensor<16xi32>>
%2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) : !flow.dispatch.tensor<writeonly:tensor<16xi32>>
%3 = tensor.empty() : tensor<16xi32>
%4 = flow.dispatch.tensor.load %0, offsets=[0], sizes=[16], strides=[1] : !flow.dispatch.tensor<readonly:tensor<16xi32>> -> tensor<16xi32>
%5 = flow.dispatch.tensor.load %1, offsets=[0], sizes=[16], strides=[1] : !flow.dispatch.tensor<readonly:tensor<16xi32>> -> tensor<16xi32>
%6 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%4, %5 : tensor<16xi32>, tensor<16xi32>) outs(%3 : tensor<16xi32>) {
^bb0(%arg0: i32, %arg1: i32, %arg2: i32): // no predecessors
%7 = arith.ceildivsi %arg0, %arg1 : i32
linalg.yield %7 : i32
} -> tensor<16xi32>
flow.dispatch.tensor.store %6, %2, offsets=[0], sizes=[16], strides=[1] : tensor<16xi32> -> !flow.dispatch.tensor<writeonly:tensor<16xi32>>
return
}
}
}
}

// CDNA3-LABEL: hal.executable public @ceildiv_expand_dispatch
// CDNA3: hal.executable.variant public @rocm
// CDNA3-NOT: arith.ceildivsi
// CDNA3-COUNT-1: llvm.select {{.*}} : i1, i32
// CDNA3-COUNT-2: llvm.sdiv {{.*}} : i32
// CDNA3-COUNT-4: llvm.icmp {{.*}} : i32
// CDNA3-COUNT-2: llvm.and {{.*}} : i1
// CDNA3-COUNT-1: llvm.or {{.*}} : i1
// CDNA3-COUNT-1: llvm.select {{.*}} : i1, i32
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,6 @@
"onnx/node/generated/test_castlike_FLOAT_to_FLOAT8E5M2FNUZ",
"onnx/node/generated/test_castlike_FLOAT_to_FLOAT8E5M2FNUZ_expanded",
"onnx/node/generated/test_castlike_FLOAT_to_FLOAT8E5M2_expanded",
"onnx/node/generated/test_center_crop_pad_crop",
"onnx/node/generated/test_center_crop_pad_crop_and_pad",
"onnx/node/generated/test_center_crop_pad_crop_axes_chw",
"onnx/node/generated/test_center_crop_pad_crop_axes_hwc",
"onnx/node/generated/test_center_crop_pad_crop_negative_axes_hwc",
"onnx/node/generated/test_center_crop_pad_pad",
"onnx/node/generated/test_col2im",
"onnx/node/generated/test_col2im_5d",
"onnx/node/generated/test_col2im_dilations",
Expand Down Expand Up @@ -446,11 +440,11 @@
"onnx/node/generated/test_reduce_log_sum_exp_default_axes_keepdims_example",
"onnx/node/generated/test_reduce_log_sum_exp_default_axes_keepdims_random",
"onnx/node/generated/test_reduce_log_sum_negative_axes",
"onnx/node/generated/test_reduce_max_bool_inputs",
"onnx/node/generated/test_reduce_mean_default_axes_keepdims_example",
"onnx/node/generated/test_reduce_mean_default_axes_keepdims_random",
"onnx/node/generated/test_reduce_max_bool_inputs",
"onnx/node/generated/test_reduce_min_empty_set",
"onnx/node/generated/test_reduce_min_bool_inputs",
"onnx/node/generated/test_reduce_min_empty_set",
"onnx/node/generated/test_reduce_sum_default_axes_keepdims_example",
"onnx/node/generated/test_reduce_sum_default_axes_keepdims_random",
"onnx/node/generated/test_reduce_sum_empty_axes_input_noop_example",
Expand All @@ -460,8 +454,8 @@
"onnx/node/generated/test_reduce_sum_square_default_axes_keepdims_example_expanded",
"onnx/node/generated/test_reduce_sum_square_default_axes_keepdims_random",
"onnx/node/generated/test_reduce_sum_square_default_axes_keepdims_random_expanded",
"onnx/node/generated/test_resize_downsample_scales_linear_align_corners",
"onnx/node/generated/test_resize_downsample_scales_cubic_align_corners",
"onnx/node/generated/test_resize_downsample_scales_linear_align_corners",
"onnx/node/generated/test_sce_mean_weight",
"onnx/node/generated/test_sce_mean_weight_ii",
"onnx/node/generated/test_sce_mean_weight_ii_log_prob",
Expand Down

0 comments on commit 8fd3e0d

Please sign in to comment.