[xla:cpu] kernel_api_ir_builder: expose SetKernelFunctionAttributes
So that fusion emitters will be able to set these same
attributes. The fusion emitters are landing soon.

PiperOrigin-RevId: 721954351
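
Illustrative use (hypothetical caller, not part of this commit): a fusion emitter that creates its own llvm::Function can apply the same attribute set through the now-public hook:

    // `ir_builder` is a KernelApiIrBuilder instance and `kernel_fn` is an
    // llvm::Function the fusion emitter created itself (both names are
    // placeholders, not code from this commit).
    ir_builder.SetKernelFunctionAttributes(kernel_fn);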
cota authored and Google-ML-Automation committed Feb 1, 2025
1 parent cbfada0 commit 547596f
Showing 41 changed files with 437 additions and 634 deletions.
6 changes: 6 additions & 0 deletions xla/backends/cpu/codegen/emitters/ir/BUILD
@@ -1,4 +1,6 @@
 load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")
+load("//tensorflow:tensorflow.google.bzl", "get_compatible_with_portable")
+load("//xla/tsl/platform:rules_cc.bzl", "cc_library")
 
 package(
     # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
@@ -16,6 +18,7 @@ package_group(
 td_library(
     name = "xla_cpu_td_files",
     srcs = glob(["*.td"]),
+    compatible_with = get_compatible_with_portable(),
     includes = ["."],
     deps = [
         "@llvm-project//mlir:BuiltinDialectTdFiles",
@@ -25,6 +28,7 @@ td_library(
 
 gentbl_cc_library(
     name = "xla_cpu_dialect_inc_gen",
+    compatible_with = get_compatible_with_portable(),
     strip_include_prefix = ".",
     tbl_outs = [
         (
@@ -43,6 +47,7 @@ gentbl_cc_library(
 
 gentbl_cc_library(
     name = "xla_cpu_types_inc_gen",
+    compatible_with = get_compatible_with_portable(),
     strip_include_prefix = ".",
     tbl_outs = [
         (
@@ -67,6 +72,7 @@ gentbl_cc_library(
 
 gentbl_cc_library(
     name = "xla_cpu_ops_inc_gen",
+    compatible_with = get_compatible_with_portable(),
     strip_include_prefix = ".",
     tbl_outs = [
         (
3 changes: 3 additions & 0 deletions xla/backends/cpu/codegen/emitters/transforms/BUILD
@@ -1,4 +1,6 @@
 load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library")
+load("//tensorflow:tensorflow.google.bzl", "get_compatible_with_portable")
+load("//xla/tsl/platform:rules_cc.bzl", "cc_library")
 
 package(
     # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
@@ -15,6 +17,7 @@ package_group(
 
 gentbl_cc_library(
     name = "passes_inc_gen",
+    compatible_with = get_compatible_with_portable(),
     tbl_outs = [
         (
             [
31 changes: 17 additions & 14 deletions xla/backends/cpu/codegen/kernel_api_ir_builder.cc
@@ -243,39 +243,39 @@ absl::StatusOr<BufferAllocation::Slice> GetUniqueSlice(
   return buffer_assignment->GetUniqueSlice(instruction, index);
 }
 
+}  // namespace
+
 absl::StatusOr<std::vector<KernelApiIrBuilder::KernelParameter>>
-GetKernelArgumentsParameters(const HloInstruction* instruction,
-                             const BufferAssignment* buffer_assignment) {
-  std::vector<KernelApiIrBuilder::KernelParameter> arguments;
+KernelApiIrBuilder::GetKernelArgumentsParameters(
+    const HloInstruction* instruction,
+    const BufferAssignment* buffer_assignment) {
+  std::vector<KernelParameter> arguments;
 
   for (HloInstruction* operand : instruction->operands()) {
     for (auto& indexed : ShapeUtil::GetLeafShapes(operand->shape())) {
       TF_ASSIGN_OR_RETURN(
           BufferAllocation::Slice slice,
           GetUniqueSlice(buffer_assignment, operand, indexed.index));
-      arguments.push_back(
-          KernelApiIrBuilder::KernelParameter{indexed.shape, slice});
+      arguments.push_back(KernelParameter{indexed.shape, slice});
     }
   }
   return arguments;
 }
 
 absl::StatusOr<std::vector<KernelApiIrBuilder::KernelParameter>>
-GetKernelResultsParameters(const HloInstruction* instruction,
-                           const BufferAssignment* buffer_assignment) {
-  std::vector<KernelApiIrBuilder::KernelParameter> results;
+KernelApiIrBuilder::GetKernelResultsParameters(
+    const HloInstruction* instruction,
+    const BufferAssignment* buffer_assignment) {
+  std::vector<KernelParameter> results;
   for (auto& indexed : ShapeUtil::GetLeafShapes(instruction->shape())) {
     TF_ASSIGN_OR_RETURN(
         BufferAllocation::Slice slice,
         GetUniqueSlice(buffer_assignment, instruction, indexed.index));
-    results.push_back(
-        KernelApiIrBuilder::KernelParameter{indexed.shape, slice});
+    results.push_back(KernelParameter{indexed.shape, slice});
   }
   return results;
 }
 
-}  // namespace
-
 auto KernelApiIrBuilder::Options::FromHloModuleConfig(
     const HloModuleConfig& config) -> Options {
   return KernelApiIrBuilder::Options{
@@ -493,6 +511,11 @@ llvm::Function* KernelApiIrBuilder::EmitKernelFunction(llvm::Module& module,
   llvm::Function* function = llvm::Function::Create(
       kernel_function_ty_, llvm::GlobalValue::ExternalLinkage, name, module);
 
+  SetKernelFunctionAttributes(function);
+  return function;
+}
+
+void KernelApiIrBuilder::SetKernelFunctionAttributes(llvm::Function* function) {
   // We use external linkage because we'll be resolving this function from the
   // XLA runtime.
   function->setCallingConv(llvm::CallingConv::C);
@@ -509,8 +514,6 @@ llvm::Function* KernelApiIrBuilder::EmitKernelFunction(llvm::Module& module,
   // Always keep a frame pointer for the host kernel so we can see them in all
   // performance profiling tools.
   function->addFnAttr("frame-pointer", "all");
-
-  return function;
 }
 
 }  // namespace xla::cpu
10 changes: 10 additions & 0 deletions xla/backends/cpu/codegen/kernel_api_ir_builder.h
@@ -113,6 +113,16 @@ class KernelApiIrBuilder {
   static std::unique_ptr<llvm::Module> CreateModule(absl::string_view name,
                                                     llvm::LLVMContext& context);
 
+  static absl::StatusOr<std::vector<KernelParameter>>
+  GetKernelArgumentsParameters(const HloInstruction* instruction,
+                               const BufferAssignment* buffer_assignment);
+
+  static absl::StatusOr<std::vector<KernelParameter>>
+  GetKernelResultsParameters(const HloInstruction* instruction,
+                             const BufferAssignment* buffer_assignment);
+
+  void SetKernelFunctionAttributes(llvm::Function* function);
+
  private:
   ThreadDims EmitKernelThreadDims(llvm::IRBuilderBase& builder,
                                   llvm::Value* call_frame);
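
As a sketch of how a fusion emitter might combine the newly exposed static helpers with the attribute hook (hypothetical code; `PrepareFusionKernel`, `fusion`, and `fusion_kernel` are illustrative names, not part of this commit):

    // Assumes the usual XLA:CPU includes (kernel_api_ir_builder.h,
    // absl/status/status.h, and tsl/platform/statusor.h for
    // TF_ASSIGN_OR_RETURN).
    absl::Status PrepareFusionKernel(KernelApiIrBuilder& builder,
                                     const HloInstruction* fusion,
                                     const BufferAssignment* buffer_assignment,
                                     llvm::Function* fusion_kernel) {
      // Gather the kernel's argument and result buffer slices.
      TF_ASSIGN_OR_RETURN(
          std::vector<KernelApiIrBuilder::KernelParameter> arguments,
          KernelApiIrBuilder::GetKernelArgumentsParameters(fusion,
                                                           buffer_assignment));
      TF_ASSIGN_OR_RETURN(
          std::vector<KernelApiIrBuilder::KernelParameter> results,
          KernelApiIrBuilder::GetKernelResultsParameters(fusion,
                                                         buffer_assignment));
      // ... emit the kernel body from `arguments` and `results` ...
      // Stamp the same attributes EmitKernelFunction would set.
      builder.SetKernelFunctionAttributes(fusion_kernel);
      return absl::OkStatus();
    }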
48 changes: 48 additions & 0 deletions xla/backends/cpu/codegen/kernel_api_ir_builder_test.cc
@@ -19,11 +19,14 @@ limitations under the License.
 #include <string>
 #include <vector>
 
+#include <gmock/gmock.h>
 #include <gtest/gtest.h>
 #include "absl/status/statusor.h"
 #include "absl/strings/string_view.h"
 #include "absl/types/span.h"
 #include "llvm/IR/Constants.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/GlobalValue.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Type.h"
@@ -75,6 +78,10 @@ class KernelApiIrBuilderTest : public HloTestBase {
         [](LogicalBuffer::Color) { return /*alignment=*/1; });
   }
 
+  void SetKernelFunctionAttributes(llvm::Function* function) {
+    kernel_api_ir_builder_.SetKernelFunctionAttributes(function);
+  }
+
   llvm::LLVMContext& context() { return context_; }
   std::string DumpToString() { return llvm_ir::DumpToString(&module_); }
 
@@ -294,5 +301,46 @@ TEST_F(KernelApiIrBuilderTest, MixedBuffers) {
   EXPECT_TRUE(prototype.invariant_arguments.contains(0));
 }
 
+TEST_F(KernelApiIrBuilderTest, GetKernelParams) {
+  llvm::LLVMContext context;
+  auto module = std::make_unique<llvm::Module>("test", context);
+  constexpr absl::string_view hlo_text = R"(
+    HloModule m
+    ENTRY main {
+      p0 = f32[2,2] parameter(0)
+      p1 = f32[2,2] parameter(1)
+      ROOT add.0 = f32[2,2] add(p0, p1)
+    })";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto hlo, ParseAndReturnUnverifiedModule(hlo_text));
+  TF_ASSERT_OK_AND_ASSIGN(auto buffer_assignment, RunBufferAssignment(*hlo));
+  const auto* root = hlo->entry_computation()->root_instruction();
+  TF_ASSERT_OK_AND_ASSIGN(auto args,
+                          KernelApiIrBuilder::GetKernelArgumentsParameters(
+                              root, buffer_assignment.get()));
+  EXPECT_EQ(args.size(), 2);
+  EXPECT_THAT(args[0].shape.dimensions(), ::testing::ElementsAre(2, 2));
+  EXPECT_THAT(args[1].shape.dimensions(), ::testing::ElementsAre(2, 2));
+  TF_ASSERT_OK_AND_ASSIGN(auto results,
+                          KernelApiIrBuilder::GetKernelResultsParameters(
+                              root, buffer_assignment.get()));
+  EXPECT_EQ(results.size(), 1);
+  EXPECT_THAT(results[0].shape.dimensions(), ::testing::ElementsAre(2, 2));
+}
+
+TEST_F(KernelApiIrBuilderTest, SetKernelFunctionAttributes) {
+  llvm::LLVMContext context;
+  auto module = std::make_unique<llvm::Module>("test", context);
+  llvm::FunctionType* function_ty =
+      llvm::FunctionType::get(llvm::PointerType::getUnqual(context),
+                              llvm::PointerType::getUnqual(context),
+                              /*isVarArg=*/false);
+  llvm::Function* function = llvm::Function::Create(
+      function_ty, llvm::GlobalValue::ExternalLinkage, "foo", *module);
+  EXPECT_FALSE(function->hasFnAttribute("prefer-vector-width"));
+  SetKernelFunctionAttributes(function);
+  EXPECT_TRUE(function->hasFnAttribute("prefer-vector-width"));
+}
+
 }  // namespace
 }  // namespace xla::cpu
22 changes: 11 additions & 11 deletions xla/backends/gpu/codegen/emitters/emitter_base.cc
@@ -574,33 +574,33 @@ absl::Status EmitterBase::RunPassPipeline(
 }
 
 void AddXlaGpuOpsOptimizationPasses(mlir::OpPassManager& pm) {
-  pm.addNestedPass<FuncOp>(CreateSimplifyArithPass());
+  pm.addNestedPass<FuncOp>(emitters::CreateSimplifyArithPass());
   pm.addPass(mlir::createCanonicalizerPass());
   pm.addPass(mlir::createCSEPass());
-  pm.addPass(CreateEraseDeadFunctionsPass());
+  pm.addPass(emitters::CreateEraseDeadFunctionsPass());
   pm.addPass(mlir::createCSEPass());
 }
 
 void AddLoopTransformationPasses(mlir::OpPassManager& pm,
                                  const se::DeviceDescription& device) {
   pm.addNestedPass<FuncOp>(
-      CreateLowerXlaGpuToScfPass(device.threads_per_warp()));
+      emitters::CreateLowerXlaToScfPass(device.threads_per_warp()));
   pm.addNestedPass<FuncOp>(CreateFuseLoopsPass());
   pm.addPass(mlir::createInlinerPass({}, [&](mlir::OpPassManager& pm) {
     // CSE after inlining because inlining can introduce duplicates.
     pm.addPass(mlir::createCSEPass());
   }));
   pm.addPass(mlir::createCanonicalizerPass());
   pm.addPass(mlir::createCSEPass());
-  pm.addNestedPass<FuncOp>(CreatePeelLoopsPass());
-  pm.addNestedPass<FuncOp>(CreateLowerXlaGpuLoopsToScfPass());
+  pm.addNestedPass<FuncOp>(emitters::CreatePeelLoopsPass());
+  pm.addNestedPass<FuncOp>(emitters::CreateLowerXlaLoopsToScfPass());
   pm.addPass(mlir::mhlo::createConvertToSignlessPass());
-  pm.addPass(CreatePropagateSliceIndicesPass());
+  pm.addPass(emitters::CreatePropagateSliceIndicesPass());
   pm.addPass(emitters::CreateFlattenTensorsPass());
   // We need LICM before unswitching loops, because our loop unswitcher only
   // detects for loops with a single if inside them.
   pm.addPass(mlir::createLoopInvariantCodeMotionPass());
-  pm.addNestedPass<FuncOp>(CreateUnswitchLoopsPass());
+  pm.addNestedPass<FuncOp>(emitters::CreateUnswitchLoopsPass());
   // We need LICM again after unswitching, because that can introduce new
   // opportunities for LICM. This would not be necessary if LICM also moved
   // instructions over ifs.
@@ -613,17 +613,17 @@
 
 void AddLoweringPasses(mlir::OpPassManager& pm,
                        const se::DeviceDescription& device) {
-  pm.addNestedPass<FuncOp>(CreateConvertPureCallOpsPass());
+  pm.addNestedPass<FuncOp>(emitters::CreateConvertPureCallOpsPass());
   pm.addPass(emitters::CreateLowerTensorsPass(device));
   pm.addPass(mlir::createConvertComplexToStandardPass());
-  pm.addPass(CreateMergePointersToSameSlicePass());
+  pm.addPass(emitters::CreateMergePointersToSameSlicePass());
 
   // LowerTensors creates new affine.apply ops. Fold and CSE them so
   // simplify-affine has maximally folded expressions to work with.
   pm.addPass(mlir::createCanonicalizerPass());
   pm.addPass(mlir::createCSEPass());
-  pm.addNestedPass<FuncOp>(CreateSimplifyArithPass());
-  pm.addPass(CreateSimplifyAffinePass());
+  pm.addNestedPass<FuncOp>(emitters::CreateSimplifyArithPass());
+  pm.addPass(emitters::CreateSimplifyAffinePass());
   pm.addPass(CreateConvertIndexTypePass());
   // simplify-affine lowers most affine.apply ops, but if it can't prove a
   // division or modulo is unsigned, affine.apply ops will remain.
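
For readers less familiar with the MLIR side, a minimal sketch of how a pipeline like the ones above is driven (generic MLIR API usage, not code from this commit; `RunCleanupPipeline` is an illustrative name):

    #include "mlir/IR/BuiltinOps.h"
    #include "mlir/Pass/PassManager.h"
    #include "mlir/Transforms/Passes.h"

    // Build and run a small cleanup pipeline on a module; the AddXla*
    // functions above populate an mlir::OpPassManager in the same way.
    mlir::LogicalResult RunCleanupPipeline(mlir::ModuleOp module) {
      mlir::PassManager pm(module.getContext());
      pm.addPass(mlir::createCanonicalizerPass());
      pm.addPass(mlir::createCSEPass());
      return pm.run(module);
    }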
7 changes: 7 additions & 0 deletions xla/backends/gpu/codegen/emitters/ir/BUILD
@@ -1,4 +1,6 @@
 load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")
+load("//tensorflow:tensorflow.google.bzl", "get_compatible_with_portable")
+load("//xla/tsl/platform:rules_cc.bzl", "cc_library")
 
 package(
     # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
@@ -16,6 +18,7 @@ package_group(
 td_library(
     name = "xla_gpu_td_files",
     srcs = glob(["*.td"]),
+    compatible_with = get_compatible_with_portable(),
     includes = ["."],
     deps = [
         "//xla/codegen/emitters/ir:xla_td_files",
@@ -30,6 +33,7 @@ td_library(
 
 gentbl_cc_library(
     name = "xla_gpu_dialect_inc_gen",
+    compatible_with = get_compatible_with_portable(),
     strip_include_prefix = ".",
     tbl_outs = [
         (
@@ -48,6 +52,7 @@ gentbl_cc_library(
 
 gentbl_cc_library(
     name = "xla_gpu_ops_inc_gen",
+    compatible_with = get_compatible_with_portable(),
     strip_include_prefix = ".",
     tbl_outs = [
         (
@@ -66,6 +71,7 @@ gentbl_cc_library(
 
 gentbl_cc_library(
     name = "xla_gpu_attrs_inc_gen",
+    compatible_with = get_compatible_with_portable(),
     strip_include_prefix = ".",
     tbl_outs = [
         (
@@ -98,6 +104,7 @@ gentbl_cc_library(
 
 gentbl_cc_library(
     name = "xla_gpu_types_inc_gen",
+    compatible_with = get_compatible_with_portable(),
     strip_include_prefix = ".",
     tbl_outs = [
         (
@@ -1,5 +1,5 @@
 // RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce
-// RUN: fusion_to_mlir %s | emitters_opt -cse -xla-gpu-simplify-arith -canonicalize | FileCheck %s --dump-input=always
+// RUN: fusion_to_mlir %s | emitters_opt -cse -xla-simplify-arith -canonicalize | FileCheck %s --dump-input=always
 
 add {
   %p0 = f32[] parameter(0)
@@ -1,5 +1,5 @@
 
-// RUN: fusion_to_mlir %s | emitters_opt -cse -xla-gpu-simplify-arith -canonicalize | FileCheck %s --dump-input=always
+// RUN: fusion_to_mlir %s | emitters_opt -cse -xla-simplify-arith -canonicalize | FileCheck %s --dump-input=always
 
 add {
   %p0 = f32[] parameter(0)
@@ -1,4 +1,4 @@
-// RUN: fusion_to_mlir %s | emitters_opt -cse -xla-gpu-simplify-arith -canonicalize | FileCheck %s --dump-input=always
+// RUN: fusion_to_mlir %s | emitters_opt -cse -xla-simplify-arith -canonicalize | FileCheck %s --dump-input=always
 // RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce
 
 add {
@@ -1,4 +1,4 @@
-// RUN: fusion_to_mlir %s | emitters_opt -cse -xla-gpu-simplify-arith -canonicalize | FileCheck %s --dump-input=always
+// RUN: fusion_to_mlir %s | emitters_opt -cse -xla-simplify-arith -canonicalize | FileCheck %s --dump-input=always
 
 add {
   %p0 = f32[] parameter(0)
11 changes: 0 additions & 11 deletions xla/backends/gpu/codegen/emitters/transforms/BUILD
@@ -40,17 +40,8 @@ cc_library(
     srcs = [
         "convert_float_nvidia.cc",
        "convert_index_type.cc",
-        "convert_xla_gpu_pure_call_ops.cc",
-        "erase_dead_functions.cc",
         "fuse_loops.cc",
-        "lower_xla_gpu_to_scf.cc",
-        "merge_pointers_to_same_slice.cc",
         "optimize_loops.cc",
-        "peel_loops.cc",
-        "propagate_slice_indices.cc",
-        "simplify_affine.cc",
-        "simplify_arith.cc",
-        "unswitch_loops.cc",
         "vectorize_loads_stores.cc",
     ],
     hdrs = ["passes.h"],
@@ -61,7 +52,6 @@ cc_library(
         "//xla:util",
         "//xla:xla_data_proto_cc",
         "//xla/backends/gpu/codegen/emitters/ir:xla_gpu",
-        "//xla/codegen/emitters:elemental_hlo_to_mlir",
         "//xla/codegen/emitters/ir:xla",
         "//xla/codegen/emitters/transforms:atomic_rmw_utils",
         "//xla/hlo/analysis:indexing_analysis",
@@ -74,7 +64,6 @@ cc_library(
         "//xla/stream_executor:semantic_version",
         "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/base:core_headers",
-        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/strings",