From dc29ee7d1bcfcec5a58d42e29125bbda937bbbbc Mon Sep 17 00:00:00 2001 From: Benoit Jacob Date: Sat, 14 Dec 2024 21:13:16 -0500 Subject: [PATCH] Move GPU ukernel selection to KernelConfig (#19440) This moves the logic deciding whether an op should be a ukernel out of the GPULowerToUKernels pass, into KernelConfig. So KernelConfig decides whether the op should be a ukernel, and encodes that into the resulting `lowering_config`, in a new parameter, that is a new attribute, UKernelSpecAttr. That attribute is directly modeled after the equivalent C++ data structure that we have had in LowerToUKernels passes, `FnNameAndDefAttrs`, which it replaces. If the attribute is present, it means that the op was selected for ukernel lowering, with the fields telling the ukernel name and some function definition attributes (to import any dependencies, such as the `rocm` module for runtime support symbols). All the details about supplying the ukernel bitcode in a `hal.executable.object` are also moved there, becoming a side effect of `KernelConfig`. The GPULowerToUKernels becomes much simpler, since all the decision-making was already done for it. It just looks at the `LoweringConfigAttr` and if it's there, it performs the requested lowering. The motivation for this split is that we need to know in KernelConfig whether it's going to be a ukernel, because ops that will get lowered to a ukernel require a different configuration. The important example for us is `multi_mma`, which in the ukernel case needs to avoid reduction-dimension tiling to 1 so that the ukernel gets to see the reduction loop. 
A few simplifications arise already in the current argmax ukernel logic, confirming that this was the right design choice: the old ukernel's matching logic was checking that the distribution tile sizes matched what the ukernel could handle; now that is turned upside down: the ukernel matching happens as a helper within KernelConfig where we know we are setting the appropriate tile sizes on purpose. Another nice improvement is that this puts just enough distance between ukernel selection (which creates the `hal.executable.object`) and ukernel lowering, that we are able to insert `HoistExecutableObjectsPass` in between, simplifying the ukernel lowering as it doesn't need to worry anymore about preserving the `hal.executable.object`. --------- Signed-off-by: Benoit Jacob --- compiler/plugins/target/ROCM/test/BUILD.bazel | 3 +- .../plugins/target/ROCM/test/CMakeLists.txt | 3 +- .../test/config_ukernel_argmax_gfx908.mlir | 30 +++ ...mlir => config_ukernel_argmax_gfx942.mlir} | 177 +++++------------- .../ROCM/test/ukernel_pipeline_transform.mlir | 4 +- .../Codegen/Common/GPU/GPULowerToUKernels.cpp | 154 ++------------- .../compiler/Codegen/Common/GPU/Passes.td | 2 +- .../Codegen/Common/GPU/test/BUILD.bazel | 1 + .../Codegen/Common/GPU/test/CMakeLists.txt | 1 + .../GPU/test/gpu_lower_to_ukernels.mlir | 72 +++++++ .../Dialect/GPU/IR/GPULoweringConfigUtils.cpp | 5 + .../Dialect/GPU/IR/GPULoweringConfigUtils.h | 2 + .../Codegen/Dialect/GPU/IR/IREEGPUAttrs.td | 19 ++ .../iree/compiler/Codegen/LLVMGPU/BUILD.bazel | 1 + .../compiler/Codegen/LLVMGPU/CMakeLists.txt | 1 + .../compiler/Codegen/LLVMGPU/KernelConfig.cpp | 41 ++-- .../iree/compiler/Codegen/LLVMGPU/Passes.cpp | 5 + .../Codegen/LLVMGPU/Utils/BUILD.bazel | 4 + .../Codegen/LLVMGPU/Utils/CMakeLists.txt | 4 + .../LLVMGPU/Utils/LLVMGPUSelectUKernels.cpp | 152 +++++++++++++++ .../LLVMGPU/Utils/LLVMGPUSelectUKernels.h | 15 ++ 21 files changed, 392 insertions(+), 304 deletions(-) create mode 100644 
compiler/plugins/target/ROCM/test/config_ukernel_argmax_gfx908.mlir rename compiler/plugins/target/ROCM/test/{gpu_lower_to_ukernels.mlir => config_ukernel_argmax_gfx942.mlir} (58%) create mode 100644 compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_lower_to_ukernels.mlir create mode 100644 compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUSelectUKernels.cpp create mode 100644 compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUSelectUKernels.h diff --git a/compiler/plugins/target/ROCM/test/BUILD.bazel b/compiler/plugins/target/ROCM/test/BUILD.bazel index f0521a0e8c50..2a71f590c6e3 100644 --- a/compiler/plugins/target/ROCM/test/BUILD.bazel +++ b/compiler/plugins/target/ROCM/test/BUILD.bazel @@ -15,8 +15,9 @@ package( iree_lit_test_suite( name = "lit", srcs = [ + "config_ukernel_argmax_gfx908.mlir", + "config_ukernel_argmax_gfx942.mlir", "default_tuning_specs_amdgpu.mlir", - "gpu_lower_to_ukernels.mlir", "lowering_strategy_from_tuning_spec.mlir", "ukernel_pipeline_transform.mlir", ], diff --git a/compiler/plugins/target/ROCM/test/CMakeLists.txt b/compiler/plugins/target/ROCM/test/CMakeLists.txt index 36d9ba6db31d..bab88582a8b0 100644 --- a/compiler/plugins/target/ROCM/test/CMakeLists.txt +++ b/compiler/plugins/target/ROCM/test/CMakeLists.txt @@ -14,8 +14,9 @@ iree_lit_test_suite( NAME lit SRCS + "config_ukernel_argmax_gfx908.mlir" + "config_ukernel_argmax_gfx942.mlir" "default_tuning_specs_amdgpu.mlir" - "gpu_lower_to_ukernels.mlir" "lowering_strategy_from_tuning_spec.mlir" "ukernel_pipeline_transform.mlir" TOOLS diff --git a/compiler/plugins/target/ROCM/test/config_ukernel_argmax_gfx908.mlir b/compiler/plugins/target/ROCM/test/config_ukernel_argmax_gfx908.mlir new file mode 100644 index 000000000000..ba12bf5e10f6 --- /dev/null +++ b/compiler/plugins/target/ROCM/test/config_ukernel_argmax_gfx908.mlir @@ -0,0 +1,30 @@ +// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx908 
--pass-pipeline='builtin.module(iree-llvmgpu-select-lowering-strategy)' %s | FileCheck %s + +// gfx908 a.k.a. CDNA1 is used here as an example of a GPU target that we don't have ukernels for. +// No need to add many ukernels here, just a quick check that we correctly do not select a ukernel. + +func.func @argmax_2d_f32i64(%arg0 : tensor<1x?xf32>) -> tensor<1xi64> attributes { + hal.executable.target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {ukernels = "all"}> +} { + %c0_i64 = arith.constant 0 : i64 + %cst = arith.constant 0xFF800000 : f32 + %0 = tensor.empty() : tensor<1xi64> + %1 = linalg.fill ins(%c0_i64 : i64) outs(%0 : tensor<1xi64>) -> tensor<1xi64> + %2 = tensor.empty() : tensor<1xf32> + %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<1xf32>) -> tensor<1xf32> + %4:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<1x?xf32>) outs(%3, %1 : tensor<1xf32>, tensor<1xi64>) { + ^bb0(%in: f32, %out: f32, %out_0: i64): + %5 = linalg.index 1 : index + %6 = arith.index_cast %5 : index to i64 + %7 = arith.maximumf %in, %out : f32 + %8 = arith.cmpf ogt, %in, %out : f32 + %9 = arith.select %8, %6, %out_0 : i64 + linalg.yield %7, %9 : f32, i64 + } -> (tensor<1xf32>, tensor<1xi64>) + return %4#1 : tensor<1xi64> +} + +// CHECK-NOT: lowering_config<{{.*}}ukernel +// CHECK-LABEL: func @argmax_2d_f32i64( +// CHECK: linalg.generic +// CHECK-NOT: hal.executable.objects diff --git a/compiler/plugins/target/ROCM/test/gpu_lower_to_ukernels.mlir b/compiler/plugins/target/ROCM/test/config_ukernel_argmax_gfx942.mlir similarity index 58% rename from compiler/plugins/target/ROCM/test/gpu_lower_to_ukernels.mlir rename to compiler/plugins/target/ROCM/test/config_ukernel_argmax_gfx942.mlir index 177bd0b36f7c..4a7da4befadd 100644 --- a/compiler/plugins/target/ROCM/test/gpu_lower_to_ukernels.mlir +++ 
b/compiler/plugins/target/ROCM/test/config_ukernel_argmax_gfx942.mlir @@ -1,5 +1,4 @@ -// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx942 --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-lower-to-ukernels,cse,canonicalize))" %s | FileCheck %s -// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx908 --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-lower-to-ukernels,cse,canonicalize))" %s | FileCheck %s --check-prefix=CDNA1 +// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx942 --pass-pipeline='builtin.module(iree-llvmgpu-select-lowering-strategy)' %s | FileCheck %s func.func @argmax_2d_f32i64(%arg0 : tensor<1x?xf32>) -> tensor<1xi64> attributes { hal.executable.target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {ukernels = "all"}> @@ -22,15 +21,11 @@ func.func @argmax_2d_f32i64(%arg0 : tensor<1x?xf32>) -> tensor<1xi64> attributes return %4#1 : tensor<1xi64> } -//CHECK-LABEL: func @argmax_2d_f32i64( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<1x?xf32> -// CHECK-DAG: %[[C1_index:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[C0_i64:.+]] = arith.constant 0 -// CHECK-DAG: %[[FILL:.+]] = linalg.fill ins(%[[C0_i64]] -// CHECK: %[[MICRO_KERNEL:.+]] = iree_codegen.ukernel.generic {hal.executable.objects = [{{.*}}]} "iree_uk_amdgpu_argmax_f32i64" -// CHECK-SAME: ins(%[[ARG0]] : -// CHECK-SAME: outs(%[[FILL]] : -// CHECK: return %[[MICRO_KERNEL]] +// CHECK-LABEL: func @argmax_2d_f32i64( +// CHECK: linalg.generic +// CHECK-SAME: hal.executable.objects = [ +// CHECK-SAME: #hal.executable.object<{path = "iree_uk_amdgpu_argmax_f32i64.gfx942.bc", data = dense_resource : vector<{{[0-9]+}}xi8>}>] +// CHECK-SAME: #iree_gpu.lowering_config<{{.*}}ukernel = #iree_gpu.ukernel_spec // ----- @@ -55,65 +50,11 @@ func.func @argmax_4d_unit_parallel_f32i64(%arg0 : tensor<1x1x1x?xf32>) -> tensor return %4#1 : tensor<1x1x1xi64> } -// CHECK-LABEL: func @argmax_4d_unit_parallel_f32i64( -// CHECK: iree_codegen.ukernel.generic -// 
CHECK-NOT: linalg.generic - -// ----- - -func.func @argmax_2d_non_unit_parallel_f32i64(%arg0 : tensor<4x?xf32>) -> tensor<4xi64> attributes { - hal.executable.target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {ukernels = "all"}> -} { - %c0_i64 = arith.constant 0 : i64 - %cst = arith.constant 0xFF800000 : f32 - %0 = tensor.empty() : tensor<4xi64> - %1 = linalg.fill ins(%c0_i64 : i64) outs(%0 : tensor<4xi64>) -> tensor<4xi64> - %2 = tensor.empty() : tensor<4xf32> - %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<4xf32>) -> tensor<4xf32> - %4:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<4x?xf32>) outs(%3, %1 : tensor<4xf32>, tensor<4xi64>) { - ^bb0(%in: f32, %out: f32, %out_0: i64): - %5 = linalg.index 1 : index - %6 = arith.index_cast %5 : index to i64 - %7 = arith.maximumf %in, %out : f32 - %8 = arith.cmpf ogt, %in, %out : f32 - %9 = arith.select %8, %6, %out_0 : i64 - linalg.yield %7, %9 : f32, i64 - } -> (tensor<4xf32>, tensor<4xi64>) - return %4#1 : tensor<4xi64> -} - -// CHECK-LABEL: func @argmax_2d_non_unit_parallel_f32i64( -// CHECK-NOT: iree_codegen.ukernel.generic -// CHECK: linalg.generic - -// ----- - -func.func @argmax_2d_dyn_parallel_f32i64(%arg0 : tensor) -> tensor attributes { - hal.executable.target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {ukernels = "all"}> -} { - %c0 = arith.constant 0 : index - %c0_i64 = arith.constant 0 : i64 - %cst = arith.constant 0xFF800000 : f32 - %dim = tensor.dim %arg0, %c0 : tensor - %0 = tensor.empty(%dim) : tensor - %1 = linalg.fill ins(%c0_i64 : i64) outs(%0 : tensor) -> tensor - %2 = tensor.empty(%dim) : tensor - %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor) -> tensor - %4:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} 
ins(%arg0 : tensor) outs(%3, %1 : tensor, tensor) { - ^bb0(%in: f32, %out: f32, %out_0: i64): - %5 = linalg.index 1 : index - %6 = arith.index_cast %5 : index to i64 - %7 = arith.maximumf %in, %out : f32 - %8 = arith.cmpf ogt, %in, %out : f32 - %9 = arith.select %8, %6, %out_0 : i64 - linalg.yield %7, %9 : f32, i64 - } -> (tensor, tensor) - return %4#1 : tensor -} - -// CHECK-LABEL: func @argmax_2d_dyn_parallel_f32i64( -// CHECK-NOT: iree_codegen.ukernel.generic -// CHECK: linalg.generic +// CHECK-LABEL: func @argmax_4d_unit_parallel_f32i64( +// CHECK: linalg.generic +// CHECK-SAME: hal.executable.objects = [ +// CHECK-SAME: #hal.executable.object<{path = "iree_uk_amdgpu_argmax_f32i64.gfx942.bc", data = dense_resource : vector<{{[0-9]+}}xi8>}>] +// CHECK-SAME: #iree_gpu.lowering_config<{{.*}}ukernel = #iree_gpu.ukernel_spec // ----- @@ -138,9 +79,10 @@ func.func @argmax_none_ukernel_enabled(%arg0 : tensor<1x?xf32>) -> tensor<1xi64> return %4#1 : tensor<1xi64> } -// CHECK-LABEL: func @argmax_none_ukernel_enabled( -// CHECK-NOT: iree_codegen.ukernel.generic -// CHECK: linalg.generic +// CHECK-LABEL: func @argmax_none_ukernel_enabled( +// CHECK: linalg.generic +// CHECK-NOT: hal.executable.objects +// CHECK-NOT: iree_gpu.ukernel_spec // ----- @@ -165,9 +107,11 @@ func.func @argmax_only_argmax_ukernel_enabled(%arg0 : tensor<1x?xf32>) -> tensor return %4#1 : tensor<1xi64> } -// CDNA2-LABEL: func @argmax_only_argmax_ukernel_enabled( -// CDNA2: iree_codegen.ukernel.generic -// CDNA2-NOT: linalg.generic +// CHECK-LABEL: func @argmax_only_argmax_ukernel_enabled( +// CHECK: linalg.generic +// CHECK-SAME: hal.executable.objects = [ +// CHECK-SAME: #hal.executable.object<{path = "iree_uk_amdgpu_argmax_f32i64.gfx942.bc", data = dense_resource : vector<{{[0-9]+}}xi8>}>] +// CHECK-SAME: #iree_gpu.lowering_config<{{.*}}ukernel = #iree_gpu.ukernel_spec // ----- @@ -192,11 +136,11 @@ func.func @argmax_only_foo_argmax_bar_ukernel_enabled(%arg0 : tensor<1x?xf32>) - return %4#1 : 
tensor<1xi64> } -// CHECK-LABEL: func @argmax_only_foo_argmax_bar_ukernel_enabled( -// CHECK: iree_codegen.ukernel.generic -// CHECK-NOT: linalg.generic - -// CDNA2-LABEL: func @argmax_only_foo_argmax_bar_ukernel_enabled( +// CHECK-LABEL: func @argmax_only_foo_argmax_bar_ukernel_enabled( +// CHECK: linalg.generic +// CHECK-SAME: hal.executable.objects = [ +// CHECK-SAME: #hal.executable.object<{path = "iree_uk_amdgpu_argmax_f32i64.gfx942.bc", data = dense_resource : vector<{{[0-9]+}}xi8>}>] +// CHECK-SAME: #iree_gpu.lowering_config<{{.*}}ukernel = #iree_gpu.ukernel_spec // ----- @@ -221,9 +165,10 @@ func.func @argmax_only_foo_ukernel_enabled(%arg0 : tensor<1x?xf32>) -> tensor<1x return %4#1 : tensor<1xi64> } -// CHECK-LABEL: func @argmax_only_foo_ukernel_enabled( -// CHECK-NOT: iree_codegen.ukernel.generic -// CHECK: linalg.generic +// CHECK-LABEL: func @argmax_only_foo_ukernel_enabled( +// CHECK: linalg.generic +// CHECK-NOT: hal.executable.objects +// CHECK-NOT: iree_gpu.ukernel_spec // ----- @@ -249,46 +194,16 @@ func.func @argmax_2d_f32i64_not_neg_inf_init(%arg0 : tensor<1x?xf32>) -> tensor< return %4#1 : tensor<1xi64> } -// CHECK-LABEL: func @argmax_2d_f32i64_not_neg_inf_init( -// CHECK-NOT: iree_codegen.ukernel.generic -// CHECK: linalg.generic - -// ----- - -// TODO: No technical reason this architecture is not supported. -// Currently just picking out popular chips to support, -// to minimize compile time and space. 
- -func.func @argmax_ukernel_unsupported_arch(%arg0 : tensor<1x?xf32>) -> tensor<1xi64> attributes { - hal.executable.target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {ukernels = "all"}> -} { - %c0_i64 = arith.constant 0 : i64 - %cst = arith.constant 0xFF800000 : f32 - %0 = tensor.empty() : tensor<1xi64> - %1 = linalg.fill ins(%c0_i64 : i64) outs(%0 : tensor<1xi64>) -> tensor<1xi64> - %2 = tensor.empty() : tensor<1xf32> - %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<1xf32>) -> tensor<1xf32> - %4:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<1x?xf32>) outs(%3, %1 : tensor<1xf32>, tensor<1xi64>) { - ^bb0(%in: f32, %out: f32, %out_0: i64): - %5 = linalg.index 1 : index - %6 = arith.index_cast %5 : index to i64 - %7 = arith.maximumf %in, %out : f32 - %8 = arith.cmpf ogt, %in, %out : f32 - %9 = arith.select %8, %6, %out_0 : i64 - linalg.yield %7, %9 : f32, i64 - } -> (tensor<1xf32>, tensor<1xi64>) - return %4#1 : tensor<1xi64> -} - -// CDNA1-LABEL: func @argmax_ukernel_unsupported_arch( -// CDNA1-NOT: iree_codegen.ukernel.generic -// CDNA1: linalg.generic +// CHECK-NOT: lowering_config<{{.*}}ukernel +// CHECK-LABEL: func @argmax_2d_f32i64_not_neg_inf_init( +// CHECK: linalg.generic +// CHECK-NOT: hal.executable.objects // ----- // Test user-provided bitcode in the source IR. -func.func @argmax_2d_f32i64(%arg0 : tensor<1x?xf32>) -> tensor<1xi64> attributes { +func.func @argmax_2d_f32i64_custom_bitcode(%arg0 : tensor<1x?xf32>) -> tensor<1xi64> attributes { hal.executable.target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {ukernels = "all"}>, // Dummy bitcode with an unusual length of 12. The first 4 bytes are the .bc file format signature. 
hal.executable.objects = [ @@ -316,18 +231,12 @@ func.func @argmax_2d_f32i64(%arg0 : tensor<1x?xf32>) -> tensor<1xi64> attributes return %4#1 : tensor<1xi64> } -//CHECK-LABEL: func @argmax_2d_f32i64( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<1x?xf32> -// CHECK-DAG: %[[C1_index:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[C0_i64:.+]] = arith.constant 0 -// CHECK-DAG: %[[FILL:.+]] = linalg.fill ins(%[[C0_i64]] -// CHECK: %[[MICRO_KERNEL:.+]] = iree_codegen.ukernel.generic { -// CHECK-SAME: hal.executable.objects = [ -// CHECK-SAME: #hal.executable.object<{ -// CHECK-SAME: path = "iree_uk_amdgpu_argmax_f32i64.gfx942.bc", -// CHECK-SAME: data = dense<[66, 67, -64, -34, 1, 35, 69, 103, -119, -85, -51, -17]> : tensor<12xi8> -// CHECK-SAME: }> -// CHECK-SAME: ]} "iree_uk_amdgpu_argmax_f32i64" -// CHECK-SAME: ins(%[[ARG0]] : -// CHECK-SAME: outs(%[[FILL]] : -// CHECK: return %[[MICRO_KERNEL]] +// CHECK-LABEL: func @argmax_2d_f32i64_custom_bitcode( +// CHECK: linalg.generic +// CHECK-SAME: hal.executable.objects = [ +// CHECK-SAME: #hal.executable.object<{ +// CHECK-SAME: path = "iree_uk_amdgpu_argmax_f32i64.gfx942.bc", +// CHECK-SAME: data = dense<[66, 67, -64, -34, 1, 35, 69, 103, -119, -85, -51, -17]> : tensor<12xi8> +// CHECK-SAME: }> +// CHECK-SAME: ] +// CHECK-SAME: #iree_gpu.lowering_config<{{.*}}ukernel = #iree_gpu.ukernel_spec diff --git a/compiler/plugins/target/ROCM/test/ukernel_pipeline_transform.mlir b/compiler/plugins/target/ROCM/test/ukernel_pipeline_transform.mlir index 26ce4c8959f4..15e5169e37b5 100644 --- a/compiler/plugins/target/ROCM/test/ukernel_pipeline_transform.mlir +++ b/compiler/plugins/target/ROCM/test/ukernel_pipeline_transform.mlir @@ -44,7 +44,7 @@ func.func @argmax_1d_f16i64() attributes { // CHECK: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info // CHECK: func.func @argmax_1d_f16i64() // CHECK-SAME: translation_info = #[[$TRANSLATION]] -// CHECK: iree_codegen.ukernel.generic {hal.executable.objects = [{{.*}}]} 
"iree_uk_amdgpu_argmax_f16i64" +// CHECK: iree_codegen.ukernel.generic "iree_uk_amdgpu_argmax_f16i64" // ----- @@ -94,7 +94,7 @@ func.func @argmax_2d_f32i64() attributes { // CHECK-SAME: translation_info = #[[$TRANSLATION]] // CHECK: %[[SUBVIEW:.*]] = memref.subview{{.*}} memref<16x?xf32 // CHECK-SAME: to memref<1x?xf32 -// CHECK: iree_codegen.ukernel.generic {hal.executable.objects = [{{.*}}]} "iree_uk_amdgpu_argmax_f32i64" ins(%[[SUBVIEW]] +// CHECK: iree_codegen.ukernel.generic "iree_uk_amdgpu_argmax_f32i64" ins(%[[SUBVIEW]] // ----- diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPULowerToUKernels.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPULowerToUKernels.cpp index c9ff4b8ed96c..796138d55e3f 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPULowerToUKernels.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPULowerToUKernels.cpp @@ -5,12 +5,11 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception #include "iree/compiler/Codegen/Common/GPU/Passes.h" +#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h" #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h" #include "iree/compiler/Codegen/Dialect/Codegen/IR/UKernelOps.h" -#include "iree/compiler/Codegen/Utils/GPUUtils.h" +#include "iree/compiler/Codegen/Dialect/GPU/IR/GPULoweringConfigUtils.h" #include "iree/compiler/Codegen/Utils/Utils.h" -#include "iree/compiler/Utils/EmbeddedDataDirectory.h" -#include "llvm/Support/FormatVariadic.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -27,114 +26,12 @@ namespace mlir::iree_compiler { namespace { -// Returns a ExecutableObjectAttr carrying the bitcode for the given ukernel. 
-// -// First tries finding the bitcode in the input `sourceExecutableObjects`, which -// must be an array of ExecutableObjectAttr's and is typically coming from a -// hal.executable.objects array attribute in the source IR, which is the -// mechanism by which source programs may provide their own ukernel bitcode. -// -// If no matching bitcode was found in `sourceExecutableObjects`, this function -// will then search in bitcode files that we have embedded as static data. -static IREE::HAL::ExecutableObjectAttr -getUKernelBitcode(OpBuilder &builder, - IREE::HAL::ExecutableTargetAttr execTarget, - ArrayAttr sourceExecutableObjects, StringRef ukernelName) { - IREE::GPU::TargetAttr gpuTarget = getGPUTargetAttr(execTarget); - if (!gpuTarget) { - return {}; - } - StringRef gpuArch = gpuTarget.getArch(); - std::string bitcodeFilename = llvm::formatv("{}.{}.bc", ukernelName, gpuArch); - - // Early-return if the source executable.objects already contain an object - // with the expected file name. This happens with user-provided bitcode in the - // source IR. - if (sourceExecutableObjects) { - for (Attribute a : sourceExecutableObjects) { - if (auto object = dyn_cast(a)) { - if (object.getPath() == bitcodeFilename) { - return object; - } - } - } - } - - // No user-provided bitcode, so we search our embedded bitcode files in the - // EmbeddedDataDirectory singleton. 
- std::optional bitcode; - EmbeddedDataDirectory::withGlobal([&](EmbeddedDataDirectory &dir) { - bitcode = dir.getFile(bitcodeFilename); - }); - if (!bitcode) { - return {}; - } - MLIRContext *context = builder.getContext(); - auto blob = HeapAsmResourceBlob::allocateAndCopyInferAlign( - ArrayRef(bitcode->data(), bitcode->size())); - auto bitcodeDenseAttr = DenseI8ResourceElementsAttr::get( - VectorType::get({static_cast(bitcode->size())}, - builder.getI8Type()), - bitcodeFilename, std::move(blob)); - return IREE::HAL::ExecutableObjectAttr::get( - context, StringAttr::get(context, bitcodeFilename), - cast(bitcodeDenseAttr)); -} - -// Walks parents ops from `op` to return the nearest hal.executable.objects -// array attribute. If the parent hal.executable.variant is reached, its objects -// attribute is returned. -// Adapted from ExecutableTargetAttr::lookup. -static ArrayAttr lookUpExecutableObjects(Operation *op) { - MLIRContext *context = op->getContext(); - auto attrId = StringAttr::get(context, "hal.executable.objects"); - while (op) { - // Take directly from the enclosing variant. - if (auto variantOp = dyn_cast(op)) { - if (std::optional objects = variantOp.getObjects()) { - return *objects; - } - } - // Take from op attributes. - if (auto attr = op->getAttrOfType(attrId)) { - return attr; - } - // Continue walk. - op = op->getParentOp(); - } - return {}; -} - -/// Holds a function name and attributes. -struct FnNameAndDefAttrs { - std::string name; - SmallVector defAttrs; - explicit operator bool() const { return !name.empty(); } -}; - -/// Returns the function name and attributes to use for a ukernel with given -/// `name` and `suffix` on the target described by `targetAttr`. 
-static FnNameAndDefAttrs -getFnNameAndDefAttrs(const char *name, std::string &suffix, - RewriterBase &rewriter, - IREE::HAL::ExecutableTargetAttr targetAttr) { - FnNameAndDefAttrs result; - if (isROCMBackend(targetAttr)) { - result.name = llvm::formatv("iree_uk_amdgpu_{}_{}", name, suffix); - result.defAttrs.emplace_back(rewriter.getStringAttr("vm.import.module"), - rewriter.getStringAttr("rocm")); - } - return result; -} - /// Matches generic that represent argmax and check if /// we have the ukernel that matches it shape constraint, and types. /// If we do, then we convert into iree_codegen.ukernel.argmax operation, /// that is later lowered into a call to the microkernel. static FailureOr matchArgmaxDAGForUKernel(RewriterBase &rewriter, linalg::GenericOp op) { - auto targetAttr = IREE::HAL::ExecutableTargetAttr::lookup(op); - const char ukernelName[] = "argmax"; Value input = op.getDpsInputOperand(0)->get(); auto inputType = cast(input.getType()); Value index = op.getDpsInitOperand(1)->get(); @@ -142,41 +39,16 @@ matchArgmaxDAGForUKernel(RewriterBase &rewriter, linalg::GenericOp op) { std::string suffix; llvm::raw_string_ostream(suffix) << inputType.getElementType() << indexType.getElementType(); - FnNameAndDefAttrs fn = - getFnNameAndDefAttrs(ukernelName, suffix, rewriter, targetAttr); - if (!fn) { - return rewriter.notifyMatchFailure(op, "no ukernels on this backend"); + auto loweringConfig = getLoweringConfig(op); + if (!loweringConfig) { + return rewriter.notifyMatchFailure(op, "no lowering_config on this op"); } - - if (!hasUkernel(targetAttr, ukernelName)) { - return rewriter.notifyMatchFailure(op, "ukernel not enabled"); + IREE::GPU::UKernelSpecAttr ukernelAttr = + IREE::GPU::getUkernelSpec(loweringConfig); + if (!ukernelAttr) { + return rewriter.notifyMatchFailure(op, "no ukernel selected for this op"); } - // Currently only support argmax where parallel dims are 1. 
- // Tiling pipeline is also set to tile all parallel dims to 1, and - // reduction dim to be size of whole reduction problem. Which allow - // this constraint to be true for a lot of argmax variances. - // TODO: Support multi-row or grid-strided argmax ukernel. - SmallVector bounds = op.getStaticLoopRanges(); - SmallVector parallelDims; - op.getParallelDims(parallelDims); - int64_t parallelSize = 1; - for (int64_t dim : parallelDims) { - if (ShapedType::isDynamic(bounds[dim])) { - return failure(); - } - parallelSize *= bounds[dim]; - } - if (parallelSize != 1) { - return failure(); - } - auto execTarget = IREE::HAL::ExecutableTargetAttr::lookup(op); - ArrayAttr sourceExecutableObjects = lookUpExecutableObjects(op); - IREE::HAL::ExecutableObjectAttr bitcodeObject = - getUKernelBitcode(rewriter, execTarget, sourceExecutableObjects, fn.name); - if (!bitcodeObject) { - return rewriter.notifyMatchFailure(op, "no ukernel bitcode for this op"); - } Location loc = op.getLoc(); // Currently only support 1D reduction, where reduc is on fastest dim. // Tiling argmax ukernel is also set to enforce this structure. 
@@ -184,13 +56,9 @@ matchArgmaxDAGForUKernel(RewriterBase &rewriter, linalg::GenericOp op) { Value reductionDimSize = rewriter.create(loc, input, kReductionDim); auto genericMicroKernelOp = rewriter.create( - loc, indexType, fn.name, ValueRange{input}, index, - ValueRange{reductionDimSize}, - /*fn_def_attrs=*/rewriter.getDictionaryAttr(fn.defAttrs), + loc, indexType, ukernelAttr.getName(), ValueRange{input}, index, + ValueRange{reductionDimSize}, ukernelAttr.getDefAttrs(), /*strided_outer_dims=*/rewriter.getIndexAttr(0)); - genericMicroKernelOp->setAttr( - "hal.executable.objects", - ArrayAttr::get(rewriter.getContext(), bitcodeObject)); return cast( genericMicroKernelOp.getOperation()); } diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td b/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td index b3fdd50d4d46..2c25e02852f4 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td @@ -107,7 +107,7 @@ def GPUInferMemorySpacePass : def GPULowerToUKernelsPass : Pass<"iree-codegen-gpu-lower-to-ukernels", ""> { - let summary = "Lower suitable ops to microkernels."; + let summary = "Lower suitable ops to previously-selected microkernels"; let dependentDialects = [ "::mlir::iree_compiler::IREE::Codegen::IREECodegenDialect", "::mlir::iree_compiler::IREE::GPU::IREEGPUDialect", diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel index dc8e6a181ccf..030e6f4de497 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel @@ -31,6 +31,7 @@ iree_lit_test_suite( "gpu_greedily_distribute_to_threads.mlir", "gpu_infer_memory_space.mlir", "gpu_combine_value_barriers.mlir", + "gpu_lower_to_ukernels.mlir", "gpu_materialize_encoding_gfx908.mlir", "gpu_materialize_encoding_gfx90a.mlir", 
"gpu_materialize_encoding_gfx942.mlir", diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt index 4dc0f289d3d5..6d1f540f420a 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt @@ -26,6 +26,7 @@ iree_lit_test_suite( "gpu_generalize_named_ops.mlir" "gpu_greedily_distribute_to_threads.mlir" "gpu_infer_memory_space.mlir" + "gpu_lower_to_ukernels.mlir" "gpu_materialize_encoding_gfx1100.mlir" "gpu_materialize_encoding_gfx908.mlir" "gpu_materialize_encoding_gfx90a.mlir" diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_lower_to_ukernels.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_lower_to_ukernels.mlir new file mode 100644 index 000000000000..6a13468a1d29 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_lower_to_ukernels.mlir @@ -0,0 +1,72 @@ +// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-lower-to-ukernels,cse,canonicalize))" %s | FileCheck %s + +#config = #iree_gpu.lowering_config<{ukernel = #iree_gpu.ukernel_spec}> +func.func @argmax_f32i64_with_selected_ukernel(%arg0 : tensor<1x?xf32>) -> tensor<1xi64> attributes { + hal.executable.target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {ukernels = "all"}> +} { + %c0_i64 = arith.constant 0 : i64 + %cst = arith.constant 0xFF800000 : f32 + %0 = tensor.empty() : tensor<1xi64> + %1 = linalg.fill ins(%c0_i64 : i64) outs(%0 : tensor<1xi64>) -> tensor<1xi64> + %2 = tensor.empty() : tensor<1xf32> + %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<1xf32>) -> tensor<1xf32> + %4:2 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>], + iterator_types = ["parallel", "reduction"] + } + ins(%arg0 : tensor<1x?xf32>) outs(%3, %1 : tensor<1xf32>, 
tensor<1xi64>) + attrs = { + // The lowering_config.ukernel is what is essential to the lowering. + lowering_config = #config} { + ^bb0(%in: f32, %out: f32, %out_0: i64): + %5 = linalg.index 1 : index + %6 = arith.index_cast %5 : index to i64 + %7 = arith.maximumf %in, %out : f32 + %8 = arith.cmpf ogt, %in, %out : f32 + %9 = arith.select %8, %6, %out_0 : i64 + linalg.yield %7, %9 : f32, i64 + } -> (tensor<1xf32>, tensor<1xi64>) + return %4#1 : tensor<1xi64> +} + +//CHECK-LABEL: func @argmax_f32i64_with_selected_ukernel( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<1x?xf32> +// CHECK-DAG: %[[C1_index:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[C0_i64:.+]] = arith.constant 0 +// CHECK-DAG: %[[FILL:.+]] = linalg.fill ins(%[[C0_i64]] +// CHECK: %[[MICRO_KERNEL:.+]] = iree_codegen.ukernel.generic +// CHECK-SAME: "some_ukernel" +// CHECK-SAME: ins(%[[ARG0]] : +// CHECK-SAME: outs(%[[FILL]] : +// CHECK: return %[[MICRO_KERNEL]] + +// ----- + +func.func @argmax_f32i64_without_selected_ukernel(%arg0 : tensor<1x?xf32>) -> tensor<1xi64> attributes { + hal.executable.target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {ukernels = "all"}> +} { + %c0_i64 = arith.constant 0 : i64 + %cst = arith.constant 0xFF800000 : f32 + %0 = tensor.empty() : tensor<1xi64> + %1 = linalg.fill ins(%c0_i64 : i64) outs(%0 : tensor<1xi64>) -> tensor<1xi64> + %2 = tensor.empty() : tensor<1xf32> + %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<1xf32>) -> tensor<1xf32> + %4:2 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>], + iterator_types = ["parallel", "reduction"] + } + ins(%arg0 : tensor<1x?xf32>) outs(%3, %1 : tensor<1xf32>, tensor<1xi64>) { + ^bb0(%in: f32, %out: f32, %out_0: i64): + %5 = linalg.index 1 : index + %6 = arith.index_cast %5 : index to i64 + %7 = arith.maximumf %in, %out : f32 + %8 = arith.cmpf ogt, %in, %out : f32 + %9 = arith.select %8, %6, %out_0 : i64 + linalg.yield %7, %9 : f32, 
i64 + } -> (tensor<1xf32>, tensor<1xi64>) + return %4#1 : tensor<1xi64> +} + +//CHECK-LABEL: func @argmax_f32i64_without_selected_ukernel( +// CHECK-NOT: iree_codegen.ukernel.generic +// CHECK: linalg.generic diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPULoweringConfigUtils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPULoweringConfigUtils.cpp index 6957caf981fc..8ebfba912442 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPULoweringConfigUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPULoweringConfigUtils.cpp @@ -145,4 +145,9 @@ std::optional> getPaddingList(LoweringConfigAttr config) { return getIntegerVector(array); } +IREE::GPU::UKernelSpecAttr +getUkernelSpec(IREE::GPU::LoweringConfigAttr config) { + return config.getAttributes().getAs("ukernel"); +} + } // namespace mlir::iree_compiler::IREE::GPU diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPULoweringConfigUtils.h b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPULoweringConfigUtils.h index c1188b75c1eb..5bebb64a1b05 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPULoweringConfigUtils.h +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPULoweringConfigUtils.h @@ -59,6 +59,8 @@ void setPromotedOperandList(MLIRContext *context, /// Helper to retrieve list of operand to pad. 
std::optional> getPaddingList(LoweringConfigAttr config); +IREE::GPU::UKernelSpecAttr getUkernelSpec(IREE::GPU::LoweringConfigAttr config); + } // namespace mlir::iree_compiler::IREE::GPU #endif // IREE_COMPILER_CODEGEN_DIALECT_GPU_IR_GPULOWERINGCONFIGUTILS_H_ diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td index a239af395d29..0b1e32fdc362 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td @@ -520,6 +520,25 @@ def IREEGPU_LaneIdAttr : AttrDef { + let mnemonic = "ukernel_spec"; + let summary = "An attribute specifying a ukernel that an op can lower to."; + let description = [{ + An attribute that can be applied to any operation to specify that it has + been matched with a ukernel that is a legal lowering for it. + }]; + let assemblyFormat = "`<` struct(params) `>`"; + let parameters = (ins + "StringAttr":$name, + "DictionaryAttr":$def_attrs + ); +} + //===----------------------------------------------------------------------===// // GPU Pipeline Options //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel index 73e039798deb..a5c1bce4beda 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel @@ -147,6 +147,7 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Dialect/Flow/IR", "//compiler/src/iree/compiler/Dialect/Flow/Transforms", "//compiler/src/iree/compiler/Dialect/HAL/IR", + "//compiler/src/iree/compiler/Dialect/HAL/Transforms", "//compiler/src/iree/compiler/Dialect/LinalgExt/IR", "//compiler/src/iree/compiler/Dialect/LinalgExt/Transforms", "//compiler/src/iree/compiler/Dialect/LinalgExt/Utils", diff --git 
a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt index b33641bda92e..5c206210ab30 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt @@ -190,6 +190,7 @@ iree_cc_library( iree::compiler::Dialect::Flow::IR iree::compiler::Dialect::Flow::Transforms iree::compiler::Dialect::HAL::IR + iree::compiler::Dialect::HAL::Transforms iree::compiler::Dialect::LinalgExt::IR iree::compiler::Dialect::LinalgExt::Transforms iree::compiler::Dialect::LinalgExt::Utils diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp index cb22b598a94b..ee4614d7bb05 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp @@ -10,6 +10,7 @@ #include #include +#include "compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUSelectUKernels.h" #include "iree/compiler/Codegen/Common/GPU/GPUHeuristics.h" #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h" #include "iree/compiler/Codegen/Dialect/GPU/IR/GPULoweringConfigUtils.h" @@ -2042,28 +2043,15 @@ static LogicalResult setTransposeConfig(mlir::FunctionOpInterface entryPoint, /// Set the configuration for argmax when ukernels are enabled. /// Distribute all parallel dim across different workgroups, and only use single /// subgroup per workgroup. -/// -/// TODO(bjacob): This is fragile, as we can't know yet if this argmax will be -/// lowered to a ukernel. We need instead a config that works regardless of -/// ukernels. For now, we use the looser condition that the argmax ukernel is -/// enabled, a necessary but not sufficient condition for this particular op to -/// lower to the ukernel. This is good enough for now for a couple of reasons: -/// 1. 
Even if a argmax does not actually lower to a ukernel, this config should -/// still work. -/// 2. Ukernels are not enabled by default. static LogicalResult setArgmaxUkernelConfig(IREE::GPU::TargetAttr target, mlir::FunctionOpInterface entryPoint, linalg::GenericOp op) { // Checks if UKernels are enabled. - if (auto target = IREE::HAL::ExecutableTargetAttr::lookup(entryPoint)) { - if (!hasUkernel(target, "argmax")) { - return failure(); - } - } - - if (!target.supportsSubgroupShuffle()) + IREE::GPU::UKernelSpecAttr ukernelSpec = selectUKernelForArgmax(op); + if (!ukernelSpec) { return failure(); + } if (failed(isArgmaxOp(op))) return failure(); @@ -2094,26 +2082,35 @@ setArgmaxUkernelConfig(IREE::GPU::TargetAttr target, return failure(); } - // Tile all the parallel dimension to 1. + // Tile all the parallel dimension to 1. This is a requirement of the ukernel. SmallVector partitionedLoops = cast(op.getOperation()) .getPartitionableLoops(kNumMaxParallelDims); size_t numLoops = partitionedLoops.empty() ? 0 : partitionedLoops.back() + 1; SmallVector workgroupTileSizes(numLoops, 1); - // Currently Argmax Ukernel let's every thread reduce reductionDim/WarpSize + // Currently Argmax Ukernel lets every thread reduce reductionDim/WarpSize // number of elements, and then it does a single step butterfly warp reduce. // Hence it expects workgroupSize to be warpSize(subgroupSize), and // reductionTileSize to be size of the reduction dim. 
SmallVector reductionTileSizes(op.getNumLoops(), 0); int64_t preferredSubgroupSize = target.getPreferredSubgroupSize(); reductionTileSizes[reductionDims[0]] = preferredSubgroupSize; - TileSizesListType tileSizes; - tileSizes.emplace_back(std::move(workgroupTileSizes)); // Workgroup level - tileSizes.emplace_back(std::move(reductionTileSizes)); // Reduction level std::array workgroupSize = {preferredSubgroupSize, 1, 1}; + + MLIRContext *context = op->getContext(); + Builder b(context); + SmallVector attrs; + attrs.emplace_back(StringAttr::get(context, "workgroup"), + b.getI64ArrayAttr(workgroupTileSizes)); + attrs.emplace_back(StringAttr::get(context, "reduction"), + b.getI64ArrayAttr(reductionTileSizes)); + attrs.emplace_back(StringAttr::get(context, "ukernel"), ukernelSpec); + IREE::GPU::setPromotedOperandList(context, attrs, {0, 1}); + auto configDict = DictionaryAttr::get(context, attrs); + auto loweringConfig = IREE::GPU::LoweringConfigAttr::get(context, configDict); if (failed(setOpConfigAndEntryPointFnTranslation( - entryPoint, op, tileSizes, CodeGenPipeline::LLVMGPUDefault, + entryPoint, op, loweringConfig, CodeGenPipeline::LLVMGPUDefault, workgroupSize))) { return failure(); } diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp index f8ebe1cc0069..b6414e1b6a47 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp @@ -21,6 +21,7 @@ #include "iree/compiler/Codegen/Utils/Utils.h" #include "iree/compiler/Dialect/Flow/IR/FlowOps.h" #include "iree/compiler/Dialect/HAL/IR/HALTypes.h" +#include "iree/compiler/Dialect/HAL/Transforms/Passes.h" #include "iree/compiler/Dialect/Util/Transforms/Passes.h" #include "iree/compiler/Utils/PassUtils.h" #include "llvm/ADT/STLForwardCompat.h" @@ -1197,6 +1198,10 @@ void buildLLVMGPUCodegenConfigurationPassPipeline( void buildLLVMGPUCodegenPassPipeline(OpPassManager &variantPassManager, 
bool useROCM) { + // LLVMGPUSelectLoweringStrategyPass may have created ExecutableObjectAttr. + // Hoisting them now deduplicates them and ensures that rewrite patterns don't + // need to think about explicitly copying them over to new ops. + variantPassManager.addPass(IREE::HAL::createHoistExecutableObjectsPass()); { OpPassManager &modulePassManager = variantPassManager.nest(); modulePassManager.addPass(createLowerExecutableUsingTransformDialectPass()); diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/BUILD.bazel index 113c6d56598f..66bd982ffa89 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/BUILD.bazel @@ -17,10 +17,12 @@ package( iree_compiler_cc_library( name = "Utils", srcs = [ + "LLVMGPUSelectUKernels.cpp", "LLVMGPUUtils.cpp", "PrefetchSharedMemoryCopy.cpp", ], hdrs = [ + "LLVMGPUSelectUKernels.h", "LLVMGPUUtils.h", ], deps = [ @@ -34,6 +36,7 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Codegen/Utils:VectorOpUtils", "//compiler/src/iree/compiler/Dialect/HAL/IR", "//compiler/src/iree/compiler/Dialect/LinalgExt/Utils", + "//compiler/src/iree/compiler/Utils", "@llvm-project//llvm:Support", "@llvm-project//mlir:AMDGPUDialect", "@llvm-project//mlir:AffineDialect", @@ -42,6 +45,7 @@ iree_compiler_cc_library( "@llvm-project//mlir:FunctionInterfaces", "@llvm-project//mlir:GPUDialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:LinalgDialect", "@llvm-project//mlir:MathDialect", "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:NVGPUDialect", diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/CMakeLists.txt index 6b66e96ded1f..98ee9404ff61 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/CMakeLists.txt @@ -14,8 +14,10 @@ 
iree_cc_library( NAME Utils HDRS + "LLVMGPUSelectUKernels.h" "LLVMGPUUtils.h" SRCS + "LLVMGPUSelectUKernels.cpp" "LLVMGPUUtils.cpp" "PrefetchSharedMemoryCopy.cpp" DEPS @@ -27,6 +29,7 @@ iree_cc_library( MLIRFunctionInterfaces MLIRGPUDialect MLIRIR + MLIRLinalgDialect MLIRMathDialect MLIRMemRefDialect MLIRNVGPUDialect @@ -45,6 +48,7 @@ iree_cc_library( iree::compiler::Codegen::Utils::VectorOpUtils iree::compiler::Dialect::HAL::IR iree::compiler::Dialect::LinalgExt::Utils + iree::compiler::Utils PUBLIC ) diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUSelectUKernels.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUSelectUKernels.cpp new file mode 100644 index 000000000000..1940e8f0b102 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUSelectUKernels.cpp @@ -0,0 +1,152 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUSelectUKernels.h" +#include "iree/compiler/Codegen/Utils/GPUUtils.h" +#include "iree/compiler/Codegen/Utils/Utils.h" +#include "iree/compiler/Utils/EmbeddedDataDirectory.h" +#include "llvm/Support/FormatVariadic.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/IR/AsmState.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" + +namespace mlir::iree_compiler { + +namespace { + +constexpr StringLiteral executableObjectsAttrName = "hal.executable.objects"; + +// Returns a ExecutableObjectAttr carrying the bitcode for the given ukernel. 
+// +// First tries finding the bitcode in the input `sourceExecutableObjects`, which +// must be an array of ExecutableObjectAttr's and is typically coming from a +// hal.executable.objects array attribute in the source IR, which is the +// mechanism by which source programs may provide their own ukernel bitcode. +// +// If no matching bitcode was found in `sourceExecutableObjects`, this function +// will then search in bitcode files that we have embedded as static data. +static IREE::HAL::ExecutableObjectAttr +getUKernelBitcode(MLIRContext *context, + IREE::HAL::ExecutableTargetAttr execTarget, + ArrayAttr sourceExecutableObjects, StringRef ukernelName) { + IREE::GPU::TargetAttr gpuTarget = getGPUTargetAttr(execTarget); + if (!gpuTarget) { + return {}; + } + StringRef gpuArch = gpuTarget.getArch(); + std::string bitcodeFilename = llvm::formatv("{}.{}.bc", ukernelName, gpuArch); + + // Early-return if the source executable.objects already contain an object + // with the expected file name. This happens with user-provided bitcode in the + // source IR. + if (sourceExecutableObjects) { + for (Attribute a : sourceExecutableObjects) { + if (auto object = dyn_cast(a)) { + if (object.getPath() == bitcodeFilename) { + return object; + } + } + } + } + + // No user-provided bitcode, so we search our embedded bitcode files in the + // EmbeddedDataDirectory singleton. 
+ std::optional bitcode; + EmbeddedDataDirectory::withGlobal([&](EmbeddedDataDirectory &dir) { + bitcode = dir.getFile(bitcodeFilename); + }); + if (!bitcode) { + return {}; + } + auto blob = HeapAsmResourceBlob::allocateAndCopyInferAlign( + ArrayRef(bitcode->data(), bitcode->size())); + auto bitcodeDenseAttr = DenseI8ResourceElementsAttr::get( + VectorType::get({static_cast(bitcode->size())}, + IntegerType::get(context, 8)), + bitcodeFilename, std::move(blob)); + return IREE::HAL::ExecutableObjectAttr::get( + context, StringAttr::get(context, bitcodeFilename), + cast(bitcodeDenseAttr)); +} + +// Walks parent ops from `op` to return the nearest hal.executable.objects +// array attribute. If the parent hal.executable.variant is reached, its objects +// attribute is returned. +// Adapted from ExecutableTargetAttr::lookup. +static ArrayAttr lookUpExecutableObjects(Operation *op) { + MLIRContext *context = op->getContext(); + auto attrId = StringAttr::get(context, executableObjectsAttrName); + while (op) { + // Take directly from the enclosing variant. + if (auto variantOp = dyn_cast(op)) { + if (std::optional objects = variantOp.getObjects()) { + return *objects; + } + } + // Take from op attributes. + if (auto attr = op->getAttrOfType(attrId)) { + return attr; + } + // Continue walk. + op = op->getParentOp(); + } + return {}; +} + +/// Returns the function name and attributes to use for a ukernel with given +/// `name` and `suffix` on the target described by `targetAttr`. 
+static IREE::GPU::UKernelSpecAttr +getUKernelSpec(StringRef name, StringRef suffix, MLIRContext *context, + IREE::HAL::ExecutableTargetAttr targetAttr) { + if (isROCMBackend(targetAttr)) { + auto nameAttr = StringAttr::get( + context, llvm::formatv("iree_uk_amdgpu_{}_{}", name, suffix)); + auto defsAttr = DictionaryAttr::get( + context, {{StringAttr::get(context, "vm.import.module"), + StringAttr::get(context, "rocm")}}); + return IREE::GPU::UKernelSpecAttr::get(context, nameAttr, defsAttr); + } + return {}; +} + +} // namespace + +IREE::GPU::UKernelSpecAttr selectUKernelForArgmax(linalg::GenericOp op) { + if (failed(isArgmaxOp(op))) { + return {}; + } + auto targetAttr = IREE::HAL::ExecutableTargetAttr::lookup(op); + const char ukernelName[] = "argmax"; + if (!hasUkernel(targetAttr, ukernelName)) { + return {}; + } + Value input = op.getDpsInputOperand(0)->get(); + auto inputType = cast(input.getType()); + Value index = op.getDpsInitOperand(1)->get(); + auto indexType = cast(index.getType()); + std::string suffix; + llvm::raw_string_ostream(suffix) + << inputType.getElementType() << indexType.getElementType(); + MLIRContext *context = op->getContext(); + IREE::GPU::UKernelSpecAttr ukernelSpec = + getUKernelSpec(ukernelName, suffix, context, targetAttr); + if (!ukernelSpec) { + return {}; + } + auto execTarget = IREE::HAL::ExecutableTargetAttr::lookup(op); + ArrayAttr sourceExecutableObjects = lookUpExecutableObjects(op); + IREE::HAL::ExecutableObjectAttr bitcodeObject = getUKernelBitcode( + context, execTarget, sourceExecutableObjects, ukernelSpec.getName()); + if (!bitcodeObject) { + return {}; + } + op->setAttr(executableObjectsAttrName, + ArrayAttr::get(context, bitcodeObject)); + return ukernelSpec; +} + +} // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUSelectUKernels.h b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUSelectUKernels.h new file mode 100644 index 000000000000..4ed251b36070 --- 
/dev/null +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUSelectUKernels.h @@ -0,0 +1,15 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h" +#include "iree/compiler/Dialect/HAL/IR/HALTypes.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" + +namespace mlir::iree_compiler { + +IREE::GPU::UKernelSpecAttr selectUKernelForArgmax(linalg::GenericOp op); + +} // namespace mlir::iree_compiler