Skip to content

Commit

Permalink
[LLVMGPU] Deprecate the matmul simt pipeline (iree-org#19335)
Browse files Browse the repository at this point in the history
This patch deprecates the matmul simt pipeline and replaces all its uses
with the newer TileAndFuse pipeline.
  • Loading branch information
pashu123 authored Dec 16, 2024
1 parent fdf4ae6 commit 6ff00a8
Show file tree
Hide file tree
Showing 17 changed files with 108 additions and 259 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
]>
hal.executable private @main_dispatch_0 {
hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb">) {
hal.executable.export public @main_dispatch_0_matmul_transpose_b_32000x32000x4096_f16 ordinal(0) layout(#pipeline_layout) attributes {subgroup_size = 64 : index, translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUMatmulSimt, {pipeline_depth = 0 : i64, store_stage = 1 : i64}>, workgroup_size = [64 : index, 16 : index, 1 : index]} {
hal.executable.export public @main_dispatch_0_matmul_transpose_b_32000x32000x4096_f16 ordinal(0) layout(#pipeline_layout) attributes {subgroup_size = 64 : index, translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse, {pipeline_depth = 0 : i64, store_stage = 1 : i64}>, workgroup_size = [64 : index, 16 : index, 1 : index]} {
^bb0(%arg0: !hal.device):
%c250 = arith.constant 250 : index
%c500 = arith.constant 500 : index
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,26 +40,24 @@ def LLVMGPU_SimpleDistribute
: I32EnumAttrCase<"LLVMGPUDistribute", 102>;
def LLVMGPU_Vectorize
: I32EnumAttrCase<"LLVMGPUVectorize", 103>;
def LLVMGPU_MatmulSimt
: I32EnumAttrCase<"LLVMGPUMatmulSimt", 104>;
def LLVMGPU_MatmulTensorCore
: I32EnumAttrCase<"LLVMGPUMatmulTensorCore", 105>;
: I32EnumAttrCase<"LLVMGPUMatmulTensorCore", 104>;
def LLVMGPU_TransposeSharedMem
: I32EnumAttrCase<"LLVMGPUTransposeSharedMem", 106>;
: I32EnumAttrCase<"LLVMGPUTransposeSharedMem", 105>;
def LLVMGPU_WarpReduction
: I32EnumAttrCase<"LLVMGPUWarpReduction", 107>;
: I32EnumAttrCase<"LLVMGPUWarpReduction", 106>;
def LLVMGPU_PackUnPack
: I32EnumAttrCase<"LLVMGPUPackUnPack", 108>;
: I32EnumAttrCase<"LLVMGPUPackUnPack", 107>;
def LLVMGPU_MatmulTensorCoreMmaSync
: I32EnumAttrCase<"LLVMGPUMatmulTensorCoreMmaSync", 109>;
: I32EnumAttrCase<"LLVMGPUMatmulTensorCoreMmaSync", 108>;
def LLVMGPU_VectorDistribute
: I32EnumAttrCase<"LLVMGPUVectorDistribute", 110>;
: I32EnumAttrCase<"LLVMGPUVectorDistribute", 109>;
def LLVMGPU_PadAndVectorDistribute
: I32EnumAttrCase<"LLVMGPUPadAndVectorDistribute", 111>;
: I32EnumAttrCase<"LLVMGPUPadAndVectorDistribute", 110>;
def LLVMGPU_WinogradVectorize
: I32EnumAttrCase<"LLVMGPUWinogradVectorize", 112>;
: I32EnumAttrCase<"LLVMGPUWinogradVectorize", 111>;
def LLVMGPU_TileAndFuse
: I32EnumAttrCase<"LLVMGPUTileAndFuse", 113>;
: I32EnumAttrCase<"LLVMGPUTileAndFuse", 112>;

def SPIRV_BaseLowering
: I32EnumAttrCase<"SPIRVBaseLowering", 200>;
Expand Down Expand Up @@ -98,7 +96,7 @@ def DispatchLoweringPassPipelineEnum : I32EnumAttr<

// LLVMGPU CodeGen pipelines
LLVMGPU_Default, LLVMGPU_BaseLowering, LLVMGPU_SimpleDistribute,
LLVMGPU_Vectorize, LLVMGPU_MatmulSimt, LLVMGPU_MatmulTensorCore,
LLVMGPU_Vectorize, LLVMGPU_MatmulTensorCore,
LLVMGPU_TransposeSharedMem, LLVMGPU_WarpReduction, LLVMGPU_PackUnPack,
LLVMGPU_MatmulTensorCoreMmaSync, LLVMGPU_VectorDistribute,
LLVMGPU_PadAndVectorDistribute, LLVMGPU_WinogradVectorize,
Expand Down
70 changes: 63 additions & 7 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1295,9 +1295,11 @@ static LogicalResult setContractConfig(IREE::GPU::TargetAttr target,
CodeGenPipeline pipeline) {
TileSizesListType tileSizes;
unsigned numParallelLoops = op.getNumParallelLoops();
SmallVector<int64_t> workgroupTileSizes(numParallelLoops - 2, 1);
workgroupTileSizes.append({tileX, tileY});
workgroupTileSizes.append(op.getNumReductionLoops(), tileK);
unsigned numReductionLoops = op.getNumReductionLoops();
SmallVector<int64_t> workgroupTileSizes(
numParallelLoops + numReductionLoops, 1);
workgroupTileSizes[numParallelLoops - 2] = tileX;
workgroupTileSizes[numParallelLoops - 1] = tileY;

SmallVector<unsigned> partitionedLoops =
cast<PartitionableLoopsInterface>(op.getOperation())
Expand All @@ -1311,11 +1313,65 @@ static LogicalResult setContractConfig(IREE::GPU::TargetAttr target,
}
}

tileSizes.emplace_back(std::move(workgroupTileSizes)); // Workgroup level.
std::optional<int64_t> subgroupSize = std::nullopt;
if (!subgroupSizes.empty())
subgroupSize = subgroupSizes.front();

// For the LLVMGPUTileAndFuse pipeline, we need to split tile sizes
// for workgroup, thread, and reduction.
if (pipeline == CodeGenPipeline::LLVMGPUTileAndFuse) {

auto context = op.getContext();
Builder b(context);
SmallVector<NamedAttribute, 1> attrs;

SmallVector<int64_t> threadTileSizes(numParallelLoops + numReductionLoops,
0);
std::fill(threadTileSizes.begin(),
threadTileSizes.begin() + numParallelLoops, 1);

threadTileSizes[numParallelLoops - 2] =
(tileX / workgroupSize[0]) < 1 ? 1 : (tileX / workgroupSize[0]);
threadTileSizes[numParallelLoops - 1] =
(tileY / workgroupSize[1]) < 1 ? 1 : (tileY / workgroupSize[1]);

SmallVector<int64_t> reductionTileSizes(
numParallelLoops + numReductionLoops, 0);
reductionTileSizes[numParallelLoops + numReductionLoops - 1] = tileK;

attrs.emplace_back(b.getStringAttr("workgroup"),
b.getI64ArrayAttr(workgroupTileSizes));
attrs.emplace_back(b.getStringAttr("thread"),
b.getI64ArrayAttr(threadTileSizes));
attrs.emplace_back(b.getStringAttr("reduction"),
b.getI64ArrayAttr(reductionTileSizes));

// Promote operands to use shared memory for LHS and RHS.
IREE::GPU::setPromotedOperandList(context, attrs, {0, 1});
auto configDict = b.getDictionaryAttr(attrs);
auto loweringConfig =
IREE::GPU::LoweringConfigAttr::get(context, configDict);
SmallVector<NamedAttribute, 1> pipelineAttrs;
auto pipelineOptions = IREE::GPU::GPUPipelineOptionsAttr::get(
context, /*prefetchSharedMemory=*/false,
/*no_reduce_shared_memory_bank_conflicts=*/true,
/*use_igemm_convolution=*/false,
/*reorder_workgroups_strategy=*/std::nullopt);
pipelineAttrs.emplace_back(
b.getStringAttr(IREE::GPU::GPUPipelineOptionsAttr::getDictKeyName()),
pipelineOptions);
auto pipelineConfig = b.getDictionaryAttr(pipelineAttrs);

return setOpConfigAndEntryPointFnTranslation(
entryPoint, op, loweringConfig, pipeline, workgroupSize, subgroupSize,
pipelineConfig);
}

// Other pipeline (MatmulTensorCore) expect the reduction tile size to be in
// the same list.
workgroupTileSizes[numParallelLoops + numReductionLoops - 1] = tileK;
tileSizes.emplace_back(std::move(workgroupTileSizes));

return setOpConfigAndEntryPointFnTranslation(
entryPoint, op, tileSizes, pipeline, workgroupSize, subgroupSize,
getSoftwarePipeliningAttrDict(op->getContext(), softwarePipelineDepth,
Expand Down Expand Up @@ -1390,7 +1446,7 @@ static LogicalResult setContractConfig(IREE::GPU::TargetAttr target,
return setMatmulConfig(
sizeN, sizeM, 4, {sizeM, sizeN, 1},
target.getWgp().getSubgroupSizeChoices().asArrayRef(),
softwarePipelineDepthSimt, CodeGenPipeline::LLVMGPUMatmulSimt);
softwarePipelineDepthSimt, CodeGenPipeline::LLVMGPUTileAndFuse);
}

// SIMT matmul case. Query the best configuration.
Expand All @@ -1404,7 +1460,7 @@ static LogicalResult setContractConfig(IREE::GPU::TargetAttr target,
config.tileSize[0], config.tileSize[1], config.tileSize[2],
config.workgroupSize,
target.getWgp().getSubgroupSizeChoices().asArrayRef(),
softwarePipelineDepthSimt, CodeGenPipeline::LLVMGPUMatmulSimt);
softwarePipelineDepthSimt, CodeGenPipeline::LLVMGPUTileAndFuse);
}
}
}
Expand All @@ -1429,7 +1485,7 @@ static LogicalResult setContractConfig(IREE::GPU::TargetAttr target,
return setMatmulConfig(tileX, tileY, tileK, workgroupSize,
target.getWgp().getSubgroupSizeChoices().asArrayRef(),
softwarePipelineDepthSimt,
CodeGenPipeline::LLVMGPUMatmulSimt);
CodeGenPipeline::LLVMGPUTileAndFuse);
}

//====---------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,6 @@ void LLVMGPULowerExecutableTargetPass::runOnOperation() {
case IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUWinogradVectorize:
addGPUWinogradVectorizePassPipeline(pipeline);
break;
case IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUMatmulSimt:
addGPUMatmulSimtPassPipeline(pipeline, pipelineOptions);
break;
case IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUMatmulTensorCore: {
FailureOr<int64_t> maybeDepth =
getSoftwarePipelineDepth(translationInfo.getConfiguration());
Expand Down
66 changes: 0 additions & 66 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -526,72 +526,6 @@ void addGPUWinogradVectorizePassPipeline(OpPassManager &funcPassManager) {
funcPassManager.addPass(createOptimizeTensorInsertExtractSlicesPass());
}

//===---------------------------------------------------------------------===//
// MatmulSIMT
//===---------------------------------------------------------------------===//

void addGPUMatmulSimtPassPipeline(OpPassManager &funcPassManager,
const GPUPipelineOptions &options) {
tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false);

funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
funcPassManager.addPass(createCSEPass());

funcPassManager.addPass(createGPUTensorTileToSerialLoopsPass());
funcPassManager.addPass(createGPUTensorAlloc());
funcPassManager.addPass(createGPUTensorTilePass());

// Linalg -> vector
addGPUVectorizationPasses(funcPassManager);

// tensor to memref
addBufferizePasses(funcPassManager);

// distribute foreach threads
funcPassManager.addPass(createGPUDistributePass());

funcPassManager.addPass(createMemrefCopyToLinalgPass());
funcPassManager.addPass(createGPUDistributeSharedMemoryCopyPass());
funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createCSEPass());

if (options.enableReduceSharedMemoryBankConflicts) {
funcPassManager.addPass(createGPUReduceBankConflictsPass());
}

ReorderWorkgroupsStrategy reorderStrategy =
getReorderWorkgroupsStrategy(options.reorderStrategy);
funcPassManager.addPass(
createReorderWorkgroups(reorderStrategy, canReorderWorkgroups));

funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createCSEPass());

funcPassManager.addPass(memref::createFoldMemRefAliasOpsPass());
funcPassManager.addPass(createCSEPass());
funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createCSEPass());

// Even though we vectorize before bufferization we are not able to hoist
// accumulator load/store out of the K loop until distribution. This is
// because we materialize the fill and the matmul in two different scf.forall
// regions, when they should be in the same scf.forall. Newer pipelines
// like TileAndFuse don't have this problem, because they coalesce these
// scf.forall regions into a single scf.forall.
//
// Therefore we still rely on buffer level transformations for transfer ops
// hoisting and store to load forwarding. This relies on shacky alias
// analysis and we need to move this to tensor level once we have better
// abstractions.
funcPassManager.addPass(createOptimizeVectorTransferPass());

// Hoist loop invariant code to avoid pipelining it.
funcPassManager.addPass(createIREELoopInvariantCodeMotionPass());
// Pipeline memory operations.
funcPassManager.addPass(createGPUPipeliningPass());
}

//===---------------------------------------------------------------------===//
// Matmul Tensor Core
//===---------------------------------------------------------------------===//
Expand Down
4 changes: 0 additions & 4 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,6 @@ using IREE::GPU::GPUPipelineOptions;
// LLVMGPU Backend Pass Pipelines
//----------------------------------------------------------------------------//

/// Lowering using SIMT CUDA core operations.
void addGPUMatmulSimtPassPipeline(OpPassManager &funcPassManager,
const GPUPipelineOptions &options);

/// Lowering using mma.sync Tensor Core operations.
void addGPUMatmulTensorCoreMmaSyncPassPipeline(
OpPassManager &funcPassManager, const GPUPipelineOptions &options,
Expand Down
11 changes: 1 addition & 10 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/Verifiers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,6 @@ getInstructionShape(Operation *op, CodeGenPipeline pipeline,
Type inputElementType,
SmallVector<int64_t> &instructionShape) {
switch (pipeline) {
case CodeGenPipeline::LLVMGPUMatmulSimt:
// SIMT Pipeline / CUDA Cores
instructionShape = {1, 1, 1};
break;
case CodeGenPipeline::LLVMGPUMatmulTensorCore:
// Tensor Core Pipeline / WMMA API
if (inputElementType.isF16() || inputElementType.isBF16()) {
Expand Down Expand Up @@ -81,8 +77,7 @@ verifyGPUMatmulPipeline(Operation *op,
ArrayRef<int64_t> workgroupSize) {
// This verifier only applies to matmul.
CodeGenPipeline pipeline = translationInfo.getDispatchLoweringPassPipeline();
if (pipeline != CodeGenPipeline::LLVMGPUMatmulSimt &&
pipeline != CodeGenPipeline::LLVMGPUMatmulTensorCore &&
if (pipeline != CodeGenPipeline::LLVMGPUMatmulTensorCore &&
pipeline != CodeGenPipeline::LLVMGPUMatmulTensorCoreMmaSync) {
return success();
}
Expand Down Expand Up @@ -180,10 +175,6 @@ verifyGPUMatmulPipeline(Operation *op,
<< pipelineName;
}

// Return success for SIMT/CUDA cores.
if (pipeline == CodeGenPipeline::LLVMGPUMatmulSimt)
return success();

//
// Additional verification Tensor Core pipelines.
//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -267,12 +267,11 @@ func.func @not_vmt() {
return
}

// CHECK-DAG: #[[$CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 128, 8]{{\]}}>
// CHECK: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = LLVMGPUMatmulSimt workgroup_size = [32, 1, 1] subgroup_size = 64, {pipeline_depth = 0 : i64, store_stage = 1 : i64}>
// CHECK-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [32, 1, 1] subgroup_size = 64, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = false, no_reduce_shared_memory_bank_conflicts = true, use_igemm_convolution = false>}>
// CHECK: func.func @not_vmt()
// CHECK-SAME: translation_info = #[[$TRANSLATION]]
// CHECK: linalg.generic
// CHECK-SAME: lowering_config = #[[$CONFIG]]
// CHECK-SAME: lowering_config = #iree_gpu.lowering_config<{promote_operands = [0, 1], reduction = [0, 0, 8], thread = [1, 128, 0], workgroup = [1, 128, 1]}>

// -----

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ func.func @matmul(%lhs: tensor<4x4xf32>, %rhs: tensor<4x4xf32>) -> tensor<4x4xf3
return %result : tensor<4x4xf32>
}

// CHECK: %2 = linalg.matmul {lowering_config = #config, root_op} ins(%arg0, %arg1 : tensor<4x4xf32>, tensor<4x4xf32>) outs(%1 : tensor<4x4xf32>) -> tensor<4x4xf32>
// CHECK: %2 = linalg.matmul {lowering_config = #{{.*}}, root_op} ins(%arg0, %arg1 : tensor<4x4xf32>, tensor<4x4xf32>) outs(%1 : tensor<4x4xf32>) -> tensor<4x4xf32>
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#map = affine_map<()[s0] -> (s0 * 2)>
#map1 = affine_map<()[s0] -> (s0 * 256)>
#map2 = affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>
#translation = #iree_codegen.translation_info<pipeline = LLVMGPUMatmulSimt workgroup_size = [64, 1, 1], {pipeline_depth = 0 : i64, store_stage = 1 : i64}>
#translation = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [64, 1, 1], {pipeline_depth = 0 : i64, store_stage = 1 : i64}>
func.func @dot_dispatch_0() attributes {translation_info = #translation} {
%cst = arith.constant 0.000000e+00 : f32
%c0 = arith.constant 0 : index
Expand Down Expand Up @@ -79,7 +79,7 @@ func.func @dot_dispatch_0() attributes {translation_info = #translation} {
#map2 = affine_map<(d0, d1, d2)[s0] -> (d0 * 32768 + s0 + d1 * 1024 + d2)>
#map3 = affine_map<(d0, d1, d2)[s0] -> (d0 * 65536 + s0 + d1 * 64 + d2)>
#map4 = affine_map<(d0, d1, d2)[s0] -> (d0 * 2048 + s0 + d1 * 64 + d2)>
#translation = #iree_codegen.translation_info<pipeline = LLVMGPUMatmulSimt workgroup_size = [8, 8, 1], {pipeline_depth = 0 : i64, store_stage = 1 : i64}>
#translation = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [8, 8, 1], {pipeline_depth = 0 : i64, store_stage = 1 : i64}>
func.func @batch_matmul_func() attributes {translation_info = #translation} {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
Expand Down Expand Up @@ -148,7 +148,7 @@ func.func @batch_matmul_func() attributes {translation_info = #translation} {
#map = affine_map<()[s0] -> (s0 * 2)>
#map1 = affine_map<()[s0] -> (s0 * 32)>
#map2 = affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>
#translation = #iree_codegen.translation_info<pipeline = LLVMGPUMatmulSimt workgroup_size = [64, 8, 1], {pipeline_depth = 0 : i64, store_stage = 1 : i64}>
#translation = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [64, 8, 1], {pipeline_depth = 0 : i64, store_stage = 1 : i64}>
func.func @dot_dispatch_0() attributes {translation_info = #translation} {
%cst = arith.constant 0.000000e+00 : f32
%c0 = arith.constant 0 : index
Expand Down Expand Up @@ -312,7 +312,7 @@ module {
#hal.pipeline.binding<storage_buffer>
]>
#config = #iree_codegen.lowering_config<tile_sizes = [[0, 1, 2, 256, 4]]>
#translation = #iree_codegen.translation_info<pipeline = LLVMGPUMatmulSimt workgroup_size = [64, 8, 1], {pipeline_depth = 0 : i64, store_stage = 1 : i64}>
#translation = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [64, 8, 1], {pipeline_depth = 0 : i64, store_stage = 1 : i64}>
#map = affine_map<()[s0] -> (s0 * 2)>
#map1 = affine_map<()[s0] -> (s0 * 256)>
#map2 = affine_map<(d0)[s0] -> (-d0 + s0, 2)>
Expand Down
Loading

0 comments on commit 6ff00a8

Please sign in to comment.