[xla:cpu] introduce FusionWrapper pass
This pass wraps scatter ops in a fusion so that the fusion emitter
can generate code for them.

PiperOrigin-RevId: 721954335
cota authored and Google-ML-Automation committed Feb 2, 2025
1 parent 492a921 commit 6960954
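
For context, here is a minimal before/after sketch of the rewrite such a wrapper pass performs. The HLO below is illustrative only: the computation names, scatter dimension numbers, and the kind=kLoop fusion kind are assumptions made for the example, not taken from this change.

// Before the pass (sketch): a bare scatter sits directly in the entry computation.
add {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT sum = f32[] add(lhs, rhs)
}

ENTRY main {
  operand = f32[10] parameter(0)
  indices = s32[3,1] parameter(1)
  updates = f32[3] parameter(2)
  ROOT scatter = f32[10] scatter(operand, indices, updates),
      update_window_dims={}, inserted_window_dims={0},
      scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=add
}

// After the pass (sketch): the scatter is moved into a fusion computation, so
// the fusion emitter, rather than a per-op emitter, handles it.
fused_scatter {
  p0 = f32[10] parameter(0)
  p1 = s32[3,1] parameter(1)
  p2 = f32[3] parameter(2)
  ROOT scatter.1 = f32[10] scatter(p0, p1, p2),
      update_window_dims={}, inserted_window_dims={0},
      scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=add
}

ENTRY main {
  operand = f32[10] parameter(0)
  indices = s32[3,1] parameter(1)
  updates = f32[3] parameter(2)
  ROOT wrapped_scatter = f32[10] fusion(operand, indices, updates),
      kind=kLoop, calls=fused_scatter
}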
Showing 48 changed files with 659 additions and 637 deletions.
6 changes: 6 additions & 0 deletions xla/backends/cpu/codegen/emitters/ir/BUILD
@@ -1,4 +1,6 @@
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")
load("//xla/tsl:tsl.default.bzl", "get_compatible_with_portable")
load("//xla/tsl/platform:rules_cc.bzl", "cc_library")

package(
# copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
@@ -16,6 +18,7 @@ package_group(
td_library(
name = "xla_cpu_td_files",
srcs = glob(["*.td"]),
compatible_with = get_compatible_with_portable(),
includes = ["."],
deps = [
"@llvm-project//mlir:BuiltinDialectTdFiles",
@@ -25,6 +28,7 @@ td_library(

gentbl_cc_library(
name = "xla_cpu_dialect_inc_gen",
compatible_with = get_compatible_with_portable(),
strip_include_prefix = ".",
tbl_outs = [
(
@@ -43,6 +47,7 @@ gentbl_cc_library(

gentbl_cc_library(
name = "xla_cpu_types_inc_gen",
compatible_with = get_compatible_with_portable(),
strip_include_prefix = ".",
tbl_outs = [
(
@@ -67,6 +72,7 @@ gentbl_cc_library(

gentbl_cc_library(
name = "xla_cpu_ops_inc_gen",
compatible_with = get_compatible_with_portable(),
strip_include_prefix = ".",
tbl_outs = [
(
3 changes: 3 additions & 0 deletions xla/backends/cpu/codegen/emitters/transforms/BUILD
@@ -1,4 +1,6 @@
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library")
load("//xla/tsl:tsl.default.bzl", "get_compatible_with_portable")
load("//xla/tsl/platform:rules_cc.bzl", "cc_library")

package(
# copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
@@ -15,6 +17,7 @@ package_group(

gentbl_cc_library(
name = "passes_inc_gen",
compatible_with = get_compatible_with_portable(),
tbl_outs = [
(
[
31 changes: 17 additions & 14 deletions xla/backends/cpu/codegen/kernel_api_ir_builder.cc
@@ -243,39 +243,39 @@ absl::StatusOr<BufferAllocation::Slice> GetUniqueSlice(
return buffer_assignment->GetUniqueSlice(instruction, index);
}

} // namespace

absl::StatusOr<std::vector<KernelApiIrBuilder::KernelParameter>>
GetKernelArgumentsParameters(const HloInstruction* instruction,
const BufferAssignment* buffer_assignment) {
std::vector<KernelApiIrBuilder::KernelParameter> arguments;
KernelApiIrBuilder::GetKernelArgumentsParameters(
const HloInstruction* instruction,
const BufferAssignment* buffer_assignment) {
std::vector<KernelParameter> arguments;

for (HloInstruction* operand : instruction->operands()) {
for (auto& indexed : ShapeUtil::GetLeafShapes(operand->shape())) {
TF_ASSIGN_OR_RETURN(
BufferAllocation::Slice slice,
GetUniqueSlice(buffer_assignment, operand, indexed.index));
arguments.push_back(
KernelApiIrBuilder::KernelParameter{indexed.shape, slice});
arguments.push_back(KernelParameter{indexed.shape, slice});
}
}
return arguments;
}

absl::StatusOr<std::vector<KernelApiIrBuilder::KernelParameter>>
GetKernelResultsParameters(const HloInstruction* instruction,
const BufferAssignment* buffer_assignment) {
std::vector<KernelApiIrBuilder::KernelParameter> results;
KernelApiIrBuilder::GetKernelResultsParameters(
const HloInstruction* instruction,
const BufferAssignment* buffer_assignment) {
std::vector<KernelParameter> results;
for (auto& indexed : ShapeUtil::GetLeafShapes(instruction->shape())) {
TF_ASSIGN_OR_RETURN(
BufferAllocation::Slice slice,
GetUniqueSlice(buffer_assignment, instruction, indexed.index));
results.push_back(
KernelApiIrBuilder::KernelParameter{indexed.shape, slice});
results.push_back(KernelParameter{indexed.shape, slice});
}
return results;
}

} // namespace

auto KernelApiIrBuilder::Options::FromHloModuleConfig(
const HloModuleConfig& config) -> Options {
return KernelApiIrBuilder::Options{
@@ -493,6 +493,11 @@ llvm::Function* KernelApiIrBuilder::EmitKernelFunction(llvm::Module& module,
llvm::Function* function = llvm::Function::Create(
kernel_function_ty_, llvm::GlobalValue::ExternalLinkage, name, module);

SetKernelFunctionAttributes(function);
return function;
}

void KernelApiIrBuilder::SetKernelFunctionAttributes(llvm::Function* function) {
// We use external linkage because we'll be resolving this function from the
// XLA runtime.
function->setCallingConv(llvm::CallingConv::C);
@@ -509,8 +514,6 @@ llvm::Function* KernelApiIrBuilder::EmitKernelFunction(llvm::Module& module,
// Always keep a frame pointer for the host kernel so we can see them in all
// performance profiling tools.
function->addFnAttr("frame-pointer", "all");

return function;
}

} // namespace xla::cpu
10 changes: 10 additions & 0 deletions xla/backends/cpu/codegen/kernel_api_ir_builder.h
@@ -113,6 +113,16 @@ class KernelApiIrBuilder {
static std::unique_ptr<llvm::Module> CreateModule(absl::string_view name,
llvm::LLVMContext& context);

static absl::StatusOr<std::vector<KernelParameter>>
GetKernelArgumentsParameters(const HloInstruction* instruction,
const BufferAssignment* buffer_assignment);

static absl::StatusOr<std::vector<KernelParameter>>
GetKernelResultsParameters(const HloInstruction* instruction,
const BufferAssignment* buffer_assignment);

void SetKernelFunctionAttributes(llvm::Function* function);

private:
ThreadDims EmitKernelThreadDims(llvm::IRBuilderBase& builder,
llvm::Value* call_frame);
48 changes: 48 additions & 0 deletions xla/backends/cpu/codegen/kernel_api_ir_builder_test.cc
@@ -19,11 +19,14 @@ limitations under the License.
#include <string>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Type.h"
@@ -75,6 +78,10 @@ class KernelApiIrBuilderTest : public HloTestBase {
[](LogicalBuffer::Color) { return /*alignment=*/1; });
}

void SetKernelFunctionAttributes(llvm::Function* function) {
kernel_api_ir_builder_.SetKernelFunctionAttributes(function);
}

llvm::LLVMContext& context() { return context_; }
std::string DumpToString() { return llvm_ir::DumpToString(&module_); }

@@ -294,5 +301,46 @@ TEST_F(KernelApiIrBuilderTest, MixedBuffers) {
EXPECT_TRUE(prototype.invariant_arguments.contains(0));
}

TEST_F(KernelApiIrBuilderTest, GetKernelParams) {
llvm::LLVMContext context;
auto module = std::make_unique<llvm::Module>("test", context);
constexpr absl::string_view hlo_text = R"(
HloModule m
ENTRY main {
p0 = f32[2,2] parameter(0)
p1 = f32[2,2] parameter(1)
ROOT add.0 = f32[2,2] add(p0, p1)
})";

TF_ASSERT_OK_AND_ASSIGN(auto hlo, ParseAndReturnUnverifiedModule(hlo_text));
TF_ASSERT_OK_AND_ASSIGN(auto buffer_assignment, RunBufferAssignment(*hlo));
const auto* root = hlo->entry_computation()->root_instruction();
TF_ASSERT_OK_AND_ASSIGN(auto args,
KernelApiIrBuilder::GetKernelArgumentsParameters(
root, buffer_assignment.get()));
EXPECT_EQ(args.size(), 2);
EXPECT_THAT(args[0].shape.dimensions(), ::testing::ElementsAre(2, 2));
EXPECT_THAT(args[1].shape.dimensions(), ::testing::ElementsAre(2, 2));
TF_ASSERT_OK_AND_ASSIGN(auto results,
KernelApiIrBuilder::GetKernelResultsParameters(
root, buffer_assignment.get()));
EXPECT_EQ(results.size(), 1);
EXPECT_THAT(results[0].shape.dimensions(), ::testing::ElementsAre(2, 2));
}

TEST_F(KernelApiIrBuilderTest, SetKernelFunctionAttributes) {
llvm::LLVMContext context;
auto module = std::make_unique<llvm::Module>("test", context);
llvm::FunctionType* function_ty =
llvm::FunctionType::get(llvm::PointerType::getUnqual(context),
llvm::PointerType::getUnqual(context),
/*isVarArg=*/false);
llvm::Function* function = llvm::Function::Create(
function_ty, llvm::GlobalValue::ExternalLinkage, "foo", *module);
EXPECT_FALSE(function->hasFnAttribute("prefer-vector-width"));
SetKernelFunctionAttributes(function);
EXPECT_TRUE(function->hasFnAttribute("prefer-vector-width"));
}

} // namespace
} // namespace xla::cpu
22 changes: 11 additions & 11 deletions xla/backends/gpu/codegen/emitters/emitter_base.cc
@@ -574,33 +574,33 @@ absl::Status EmitterBase::RunPassPipeline(
}

void AddXlaGpuOpsOptimizationPasses(mlir::OpPassManager& pm) {
pm.addNestedPass<FuncOp>(CreateSimplifyArithPass());
pm.addNestedPass<FuncOp>(emitters::CreateSimplifyArithPass());
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createCSEPass());
pm.addPass(CreateEraseDeadFunctionsPass());
pm.addPass(emitters::CreateEraseDeadFunctionsPass());
pm.addPass(mlir::createCSEPass());
}

void AddLoopTransformationPasses(mlir::OpPassManager& pm,
const se::DeviceDescription& device) {
pm.addNestedPass<FuncOp>(
CreateLowerXlaGpuToScfPass(device.threads_per_warp()));
emitters::CreateLowerXlaToScfPass(device.threads_per_warp()));
pm.addNestedPass<FuncOp>(CreateFuseLoopsPass());
pm.addPass(mlir::createInlinerPass({}, [&](mlir::OpPassManager& pm) {
// CSE after inlining because inlining can introduce duplicates.
pm.addPass(mlir::createCSEPass());
}));
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createCSEPass());
pm.addNestedPass<FuncOp>(CreatePeelLoopsPass());
pm.addNestedPass<FuncOp>(CreateLowerXlaGpuLoopsToScfPass());
pm.addNestedPass<FuncOp>(emitters::CreatePeelLoopsPass());
pm.addNestedPass<FuncOp>(emitters::CreateLowerXlaLoopsToScfPass());
pm.addPass(mlir::mhlo::createConvertToSignlessPass());
pm.addPass(CreatePropagateSliceIndicesPass());
pm.addPass(emitters::CreatePropagateSliceIndicesPass());
pm.addPass(emitters::CreateFlattenTensorsPass());
// We need LICM before unswitching loops, because our loop unswitcher only
// detects for loops with a single if inside them.
pm.addPass(mlir::createLoopInvariantCodeMotionPass());
pm.addNestedPass<FuncOp>(CreateUnswitchLoopsPass());
pm.addNestedPass<FuncOp>(emitters::CreateUnswitchLoopsPass());
// We need LICM again after unswitching, because that can introduce new
// opportunities for LICM. This would not be necessary if LICM also moved
// instructions over ifs.
@@ -613,17 +613,17 @@

void AddLoweringPasses(mlir::OpPassManager& pm,
const se::DeviceDescription& device) {
pm.addNestedPass<FuncOp>(CreateConvertPureCallOpsPass());
pm.addNestedPass<FuncOp>(emitters::CreateConvertPureCallOpsPass());
pm.addPass(emitters::CreateLowerTensorsPass(device));
pm.addPass(mlir::createConvertComplexToStandardPass());
pm.addPass(CreateMergePointersToSameSlicePass());
pm.addPass(emitters::CreateMergePointersToSameSlicePass());

// LowerTensors creates new affine.apply ops. Fold and CSE them so
// simplify-affine has maximally folded expressions to work with.
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createCSEPass());
pm.addNestedPass<FuncOp>(CreateSimplifyArithPass());
pm.addPass(CreateSimplifyAffinePass());
pm.addNestedPass<FuncOp>(emitters::CreateSimplifyArithPass());
pm.addPass(emitters::CreateSimplifyAffinePass());
pm.addPass(CreateConvertIndexTypePass());
// simplify-affine lowers most affine.apply ops, but if it can't prove a
// division or modulo is unsigned, affine.apply ops will remain.
7 changes: 7 additions & 0 deletions xla/backends/gpu/codegen/emitters/ir/BUILD
@@ -1,4 +1,6 @@
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")
load("//xla/tsl:tsl.default.bzl", "get_compatible_with_portable")
load("//xla/tsl/platform:rules_cc.bzl", "cc_library")

package(
# copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
@@ -16,6 +18,7 @@ package_group(
td_library(
name = "xla_gpu_td_files",
srcs = glob(["*.td"]),
compatible_with = get_compatible_with_portable(),
includes = ["."],
deps = [
"//xla/codegen/emitters/ir:xla_td_files",
@@ -30,6 +33,7 @@ td_library(

gentbl_cc_library(
name = "xla_gpu_dialect_inc_gen",
compatible_with = get_compatible_with_portable(),
strip_include_prefix = ".",
tbl_outs = [
(
@@ -48,6 +52,7 @@ gentbl_cc_library(

gentbl_cc_library(
name = "xla_gpu_ops_inc_gen",
compatible_with = get_compatible_with_portable(),
strip_include_prefix = ".",
tbl_outs = [
(
@@ -66,6 +71,7 @@ gentbl_cc_library(

gentbl_cc_library(
name = "xla_gpu_attrs_inc_gen",
compatible_with = get_compatible_with_portable(),
strip_include_prefix = ".",
tbl_outs = [
(
@@ -98,6 +104,7 @@ gentbl_cc_library(

gentbl_cc_library(
name = "xla_gpu_types_inc_gen",
compatible_with = get_compatible_with_portable(),
strip_include_prefix = ".",
tbl_outs = [
(
@@ -1,5 +1,5 @@
// RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce
// RUN: fusion_to_mlir %s | emitters_opt -cse -xla-gpu-simplify-arith -canonicalize | FileCheck %s --dump-input=always
// RUN: fusion_to_mlir %s | emitters_opt -cse -xla-simplify-arith -canonicalize | FileCheck %s --dump-input=always

add {
%p0 = f32[] parameter(0)
@@ -1,5 +1,5 @@

// RUN: fusion_to_mlir %s | emitters_opt -cse -xla-gpu-simplify-arith -canonicalize | FileCheck %s --dump-input=always
// RUN: fusion_to_mlir %s | emitters_opt -cse -xla-simplify-arith -canonicalize | FileCheck %s --dump-input=always

add {
%p0 = f32[] parameter(0)
@@ -1,4 +1,4 @@
// RUN: fusion_to_mlir %s | emitters_opt -cse -xla-gpu-simplify-arith -canonicalize | FileCheck %s --dump-input=always
// RUN: fusion_to_mlir %s | emitters_opt -cse -xla-simplify-arith -canonicalize | FileCheck %s --dump-input=always
// RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce

add {
@@ -1,4 +1,4 @@
// RUN: fusion_to_mlir %s | emitters_opt -cse -xla-gpu-simplify-arith -canonicalize | FileCheck %s --dump-input=always
// RUN: fusion_to_mlir %s | emitters_opt -cse -xla-simplify-arith -canonicalize | FileCheck %s --dump-input=always

add {
%p0 = f32[] parameter(0)
11 changes: 0 additions & 11 deletions xla/backends/gpu/codegen/emitters/transforms/BUILD
@@ -40,17 +40,8 @@ cc_library(
srcs = [
"convert_float_nvidia.cc",
"convert_index_type.cc",
"convert_xla_gpu_pure_call_ops.cc",
"erase_dead_functions.cc",
"fuse_loops.cc",
"lower_xla_gpu_to_scf.cc",
"merge_pointers_to_same_slice.cc",
"optimize_loops.cc",
"peel_loops.cc",
"propagate_slice_indices.cc",
"simplify_affine.cc",
"simplify_arith.cc",
"unswitch_loops.cc",
"vectorize_loads_stores.cc",
],
hdrs = ["passes.h"],
@@ -61,7 +52,6 @@ cc_library(
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/backends/gpu/codegen/emitters/ir:xla_gpu",
"//xla/codegen/emitters:elemental_hlo_to_mlir",
"//xla/codegen/emitters/ir:xla",
"//xla/codegen/emitters/transforms:atomic_rmw_utils",
"//xla/hlo/analysis:indexing_analysis",
@@ -74,7 +64,6 @@ cc_library(
"//xla/stream_executor:semantic_version",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/strings",