[xla:cpu] kernel_api_ir_builder: expose SetKernelFunctionAttributes #22195

Status: Open. Wants to merge 1 commit into main.
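This change exposes KernelApiIrBuilder::SetKernelFunctionAttributes as its own method (now called from EmitKernelFunction) and promotes GetKernelArgumentsParameters and GetKernelResultsParameters from an anonymous namespace to static members of KernelApiIrBuilder, with new unit tests for both. It also marks the CPU and GPU emitter tblgen BUILD targets as compatible with portable builds, and updates the GPU pass pipelines and lit tests for emitter passes that now live in the shared emitters namespace.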
6 changes: 6 additions & 0 deletions xla/backends/cpu/codegen/emitters/ir/BUILD
@@ -1,4 +1,6 @@
 load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")
+load("//tensorflow:tensorflow.google.bzl", "get_compatible_with_portable")
+load("//xla/tsl/platform:rules_cc.bzl", "cc_library")
 
 package(
     # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
@@ -16,6 +18,7 @@ package_group(
 td_library(
     name = "xla_cpu_td_files",
     srcs = glob(["*.td"]),
+    compatible_with = get_compatible_with_portable(),
     includes = ["."],
     deps = [
         "@llvm-project//mlir:BuiltinDialectTdFiles",
@@ -25,6 +28,7 @@ td_library(
 
 gentbl_cc_library(
     name = "xla_cpu_dialect_inc_gen",
+    compatible_with = get_compatible_with_portable(),
     strip_include_prefix = ".",
     tbl_outs = [
         (
@@ -43,6 +47,7 @@ gentbl_cc_library(
 
 gentbl_cc_library(
     name = "xla_cpu_types_inc_gen",
+    compatible_with = get_compatible_with_portable(),
    strip_include_prefix = ".",
    tbl_outs = [
        (
@@ -67,6 +72,7 @@ gentbl_cc_library(
 
 gentbl_cc_library(
     name = "xla_cpu_ops_inc_gen",
+    compatible_with = get_compatible_with_portable(),
     strip_include_prefix = ".",
     tbl_outs = [
         (
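The get_compatible_with_portable() markings follow the pattern used elsewhere in XLA/TensorFlow for tblgen-generated targets: the td_library and gentbl_cc_library rules gain a compatible_with setting so the generated dialect sources can also be built in portable (e.g. mobile) configurations.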
3 changes: 3 additions & 0 deletions xla/backends/cpu/codegen/emitters/transforms/BUILD
@@ -1,4 +1,6 @@
 load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library")
+load("//tensorflow:tensorflow.google.bzl", "get_compatible_with_portable")
+load("//xla/tsl/platform:rules_cc.bzl", "cc_library")
 
 package(
     # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
@@ -15,6 +17,7 @@ package_group(
 
 gentbl_cc_library(
     name = "passes_inc_gen",
+    compatible_with = get_compatible_with_portable(),
     tbl_outs = [
         (
             [
31 changes: 17 additions & 14 deletions xla/backends/cpu/codegen/kernel_api_ir_builder.cc
@@ -243,39 +243,39 @@ absl::StatusOr<BufferAllocation::Slice> GetUniqueSlice(
   return buffer_assignment->GetUniqueSlice(instruction, index);
 }
 
+} // namespace
+
 absl::StatusOr<std::vector<KernelApiIrBuilder::KernelParameter>>
-GetKernelArgumentsParameters(const HloInstruction* instruction,
-                             const BufferAssignment* buffer_assignment) {
-  std::vector<KernelApiIrBuilder::KernelParameter> arguments;
+KernelApiIrBuilder::GetKernelArgumentsParameters(
+    const HloInstruction* instruction,
+    const BufferAssignment* buffer_assignment) {
+  std::vector<KernelParameter> arguments;
 
   for (HloInstruction* operand : instruction->operands()) {
     for (auto& indexed : ShapeUtil::GetLeafShapes(operand->shape())) {
       TF_ASSIGN_OR_RETURN(
           BufferAllocation::Slice slice,
           GetUniqueSlice(buffer_assignment, operand, indexed.index));
-      arguments.push_back(
-          KernelApiIrBuilder::KernelParameter{indexed.shape, slice});
+      arguments.push_back(KernelParameter{indexed.shape, slice});
     }
   }
   return arguments;
 }
 
 absl::StatusOr<std::vector<KernelApiIrBuilder::KernelParameter>>
-GetKernelResultsParameters(const HloInstruction* instruction,
-                           const BufferAssignment* buffer_assignment) {
-  std::vector<KernelApiIrBuilder::KernelParameter> results;
+KernelApiIrBuilder::GetKernelResultsParameters(
+    const HloInstruction* instruction,
+    const BufferAssignment* buffer_assignment) {
+  std::vector<KernelParameter> results;
   for (auto& indexed : ShapeUtil::GetLeafShapes(instruction->shape())) {
     TF_ASSIGN_OR_RETURN(
         BufferAllocation::Slice slice,
         GetUniqueSlice(buffer_assignment, instruction, indexed.index));
-    results.push_back(
-        KernelApiIrBuilder::KernelParameter{indexed.shape, slice});
+    results.push_back(KernelParameter{indexed.shape, slice});
   }
   return results;
 }
 
-} // namespace
-
 auto KernelApiIrBuilder::Options::FromHloModuleConfig(
     const HloModuleConfig& config) -> Options {
   return KernelApiIrBuilder::Options{
@@ -493,6 +493,11 @@ llvm::Function* KernelApiIrBuilder::EmitKernelFunction(llvm::Module& module,
   llvm::Function* function = llvm::Function::Create(
       kernel_function_ty_, llvm::GlobalValue::ExternalLinkage, name, module);
 
+  SetKernelFunctionAttributes(function);
+  return function;
+}
+
+void KernelApiIrBuilder::SetKernelFunctionAttributes(llvm::Function* function) {
   // We use external linkage because we'll be resolving this function from the
   // XLA runtime.
   function->setCallingConv(llvm::CallingConv::C);
@@ -509,8 +514,6 @@ llvm::Function* KernelApiIrBuilder::EmitKernelFunction(llvm::Module& module,
   // Always keep a frame pointer for the host kernel so we can see them in all
   // performance profiling tools.
   function->addFnAttr("frame-pointer", "all");
-
-  return function;
 }
 
 } // namespace xla::cpu
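For context, here is a minimal sketch of how the newly public API can be driven from outside EmitKernelFunction, mirroring the new tests further down. The function name UseExposedApi and the wiring around it are illustrative, not part of this PR; only the two Get*Parameters calls and SetKernelFunctionAttributes come from this change:

// Sketch only (not part of this diff): drive the newly public API directly.
// Assumes the usual kernel_api_ir_builder.h and LLVM IR includes.
absl::Status UseExposedApi(xla::cpu::KernelApiIrBuilder& ir_builder,
                           const xla::HloInstruction* instr,
                           const xla::BufferAssignment* buffer_assignment) {
  // Apply the kernel ABI attributes to a function created outside of
  // EmitKernelFunction (mirrors the new SetKernelFunctionAttributes test).
  llvm::LLVMContext context;
  llvm::Module module("sketch", context);
  llvm::FunctionType* fn_ty =
      llvm::FunctionType::get(llvm::PointerType::getUnqual(context),
                              llvm::PointerType::getUnqual(context),
                              /*isVarArg=*/false);
  llvm::Function* fn = llvm::Function::Create(
      fn_ty, llvm::GlobalValue::ExternalLinkage, "my_kernel", module);
  ir_builder.SetKernelFunctionAttributes(fn);  // C calling conv, frame pointer, etc.

  // Compute kernel parameters for an instruction without emitting a full
  // kernel prototype (mirrors the new GetKernelParams test).
  TF_ASSIGN_OR_RETURN(auto args,
                      xla::cpu::KernelApiIrBuilder::GetKernelArgumentsParameters(
                          instr, buffer_assignment));
  TF_ASSIGN_OR_RETURN(auto results,
                      xla::cpu::KernelApiIrBuilder::GetKernelResultsParameters(
                          instr, buffer_assignment));
  VLOG(1) << "kernel args: " << args.size() << ", results: " << results.size();
  return absl::OkStatus();
}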
10 changes: 10 additions & 0 deletions xla/backends/cpu/codegen/kernel_api_ir_builder.h
@@ -113,6 +113,16 @@ class KernelApiIrBuilder {
   static std::unique_ptr<llvm::Module> CreateModule(absl::string_view name,
                                                     llvm::LLVMContext& context);
 
+  static absl::StatusOr<std::vector<KernelParameter>>
+  GetKernelArgumentsParameters(const HloInstruction* instruction,
+                               const BufferAssignment* buffer_assignment);
+
+  static absl::StatusOr<std::vector<KernelParameter>>
+  GetKernelResultsParameters(const HloInstruction* instruction,
+                             const BufferAssignment* buffer_assignment);
+
+  void SetKernelFunctionAttributes(llvm::Function* function);
+
  private:
   ThreadDims EmitKernelThreadDims(llvm::IRBuilderBase& builder,
                                   llvm::Value* call_frame);
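Note the asymmetry in the new declarations: the two Get*Parameters helpers are static (they depend only on the instruction and the buffer assignment), while SetKernelFunctionAttributes is a regular member, presumably because attributes such as prefer-vector-width derive from the builder's per-instance options.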
48 changes: 48 additions & 0 deletions xla/backends/cpu/codegen/kernel_api_ir_builder_test.cc
@@ -19,11 +19,14 @@ limitations under the License.
 #include <string>
 #include <vector>
 
+#include <gmock/gmock.h>
 #include <gtest/gtest.h>
 #include "absl/status/statusor.h"
 #include "absl/strings/string_view.h"
 #include "absl/types/span.h"
 #include "llvm/IR/Constants.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/GlobalValue.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Type.h"
@@ -75,6 +78,10 @@ class KernelApiIrBuilderTest : public HloTestBase {
         [](LogicalBuffer::Color) { return /*alignment=*/1; });
   }
 
+  void SetKernelFunctionAttributes(llvm::Function* function) {
+    kernel_api_ir_builder_.SetKernelFunctionAttributes(function);
+  }
+
   llvm::LLVMContext& context() { return context_; }
   std::string DumpToString() { return llvm_ir::DumpToString(&module_); }
 
@@ -294,5 +301,46 @@ TEST_F(KernelApiIrBuilderTest, MixedBuffers) {
   EXPECT_TRUE(prototype.invariant_arguments.contains(0));
 }
 
+TEST_F(KernelApiIrBuilderTest, GetKernelParams) {
+  llvm::LLVMContext context;
+  auto module = std::make_unique<llvm::Module>("test", context);
+  constexpr absl::string_view hlo_text = R"(
+    HloModule m
+    ENTRY main {
+      p0 = f32[2,2] parameter(0)
+      p1 = f32[2,2] parameter(1)
+      ROOT add.0 = f32[2,2] add(p0, p1)
+    })";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto hlo, ParseAndReturnUnverifiedModule(hlo_text));
+  TF_ASSERT_OK_AND_ASSIGN(auto buffer_assignment, RunBufferAssignment(*hlo));
+  const auto* root = hlo->entry_computation()->root_instruction();
+  TF_ASSERT_OK_AND_ASSIGN(auto args,
+                          KernelApiIrBuilder::GetKernelArgumentsParameters(
+                              root, buffer_assignment.get()));
+  EXPECT_EQ(args.size(), 2);
+  EXPECT_THAT(args[0].shape.dimensions(), ::testing::ElementsAre(2, 2));
+  EXPECT_THAT(args[1].shape.dimensions(), ::testing::ElementsAre(2, 2));
+  TF_ASSERT_OK_AND_ASSIGN(auto results,
+                          KernelApiIrBuilder::GetKernelResultsParameters(
+                              root, buffer_assignment.get()));
+  EXPECT_EQ(results.size(), 1);
+  EXPECT_THAT(results[0].shape.dimensions(), ::testing::ElementsAre(2, 2));
+}
+
+TEST_F(KernelApiIrBuilderTest, SetKernelFunctionAttributes) {
+  llvm::LLVMContext context;
+  auto module = std::make_unique<llvm::Module>("test", context);
+  llvm::FunctionType* function_ty =
+      llvm::FunctionType::get(llvm::PointerType::getUnqual(context),
+                              llvm::PointerType::getUnqual(context),
+                              /*isVarArg=*/false);
+  llvm::Function* function = llvm::Function::Create(
+      function_ty, llvm::GlobalValue::ExternalLinkage, "foo", *module);
+  EXPECT_FALSE(function->hasFnAttribute("prefer-vector-width"));
+  SetKernelFunctionAttributes(function);
+  EXPECT_TRUE(function->hasFnAttribute("prefer-vector-width"));
+}
+
 } // namespace
 } // namespace xla::cpu
22 changes: 11 additions & 11 deletions xla/backends/gpu/codegen/emitters/emitter_base.cc
@@ -574,33 +574,33 @@ absl::Status EmitterBase::RunPassPipeline(
 }
 
 void AddXlaGpuOpsOptimizationPasses(mlir::OpPassManager& pm) {
-  pm.addNestedPass<FuncOp>(CreateSimplifyArithPass());
+  pm.addNestedPass<FuncOp>(emitters::CreateSimplifyArithPass());
   pm.addPass(mlir::createCanonicalizerPass());
   pm.addPass(mlir::createCSEPass());
-  pm.addPass(CreateEraseDeadFunctionsPass());
+  pm.addPass(emitters::CreateEraseDeadFunctionsPass());
   pm.addPass(mlir::createCSEPass());
 }
 
 void AddLoopTransformationPasses(mlir::OpPassManager& pm,
                                  const se::DeviceDescription& device) {
   pm.addNestedPass<FuncOp>(
-      CreateLowerXlaGpuToScfPass(device.threads_per_warp()));
+      emitters::CreateLowerXlaToScfPass(device.threads_per_warp()));
   pm.addNestedPass<FuncOp>(CreateFuseLoopsPass());
   pm.addPass(mlir::createInlinerPass({}, [&](mlir::OpPassManager& pm) {
     // CSE after inlining because inlining can introduce duplicates.
     pm.addPass(mlir::createCSEPass());
   }));
   pm.addPass(mlir::createCanonicalizerPass());
   pm.addPass(mlir::createCSEPass());
-  pm.addNestedPass<FuncOp>(CreatePeelLoopsPass());
-  pm.addNestedPass<FuncOp>(CreateLowerXlaGpuLoopsToScfPass());
+  pm.addNestedPass<FuncOp>(emitters::CreatePeelLoopsPass());
+  pm.addNestedPass<FuncOp>(emitters::CreateLowerXlaLoopsToScfPass());
   pm.addPass(mlir::mhlo::createConvertToSignlessPass());
-  pm.addPass(CreatePropagateSliceIndicesPass());
+  pm.addPass(emitters::CreatePropagateSliceIndicesPass());
   pm.addPass(emitters::CreateFlattenTensorsPass());
   // We need LICM before unswitching loops, because our loop unswitcher only
   // detects for loops with a single if inside them.
   pm.addPass(mlir::createLoopInvariantCodeMotionPass());
-  pm.addNestedPass<FuncOp>(CreateUnswitchLoopsPass());
+  pm.addNestedPass<FuncOp>(emitters::CreateUnswitchLoopsPass());
   // We need LICM again after unswitching, because that can introduce new
   // opportunities for LICM. This would not be necessary if LICM also moved
   // instructions over ifs.
@@ -613,17 +613,17 @@
 
 void AddLoweringPasses(mlir::OpPassManager& pm,
                        const se::DeviceDescription& device) {
-  pm.addNestedPass<FuncOp>(CreateConvertPureCallOpsPass());
+  pm.addNestedPass<FuncOp>(emitters::CreateConvertPureCallOpsPass());
   pm.addPass(emitters::CreateLowerTensorsPass(device));
   pm.addPass(mlir::createConvertComplexToStandardPass());
-  pm.addPass(CreateMergePointersToSameSlicePass());
+  pm.addPass(emitters::CreateMergePointersToSameSlicePass());
 
   // LowerTensors creates new affine.apply ops. Fold and CSE them so
   // simplify-affine has maximally folded expressions to work with.
   pm.addPass(mlir::createCanonicalizerPass());
   pm.addPass(mlir::createCSEPass());
-  pm.addNestedPass<FuncOp>(CreateSimplifyArithPass());
-  pm.addPass(CreateSimplifyAffinePass());
+  pm.addNestedPass<FuncOp>(emitters::CreateSimplifyArithPass());
+  pm.addPass(emitters::CreateSimplifyAffinePass());
   pm.addPass(CreateConvertIndexTypePass());
   // simplify-affine lowers most affine.apply ops, but if it can't prove a
   // division or modulo is unsigned, affine.apply ops will remain.
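The emitter_base.cc changes are mechanical renames: pass factories that now live in the shared emitters namespace are referenced through the emitters:: prefix, with CreateLowerXlaGpuToScfPass and CreateLowerXlaGpuLoopsToScfPass becoming emitters::CreateLowerXlaToScfPass and emitters::CreateLowerXlaLoopsToScfPass. The corresponding sources are deleted from the GPU transforms BUILD target below.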
7 changes: 7 additions & 0 deletions xla/backends/gpu/codegen/emitters/ir/BUILD
@@ -1,4 +1,6 @@
 load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")
+load("//tensorflow:tensorflow.google.bzl", "get_compatible_with_portable")
+load("//xla/tsl/platform:rules_cc.bzl", "cc_library")
 
 package(
     # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
@@ -16,6 +18,7 @@ package_group(
 td_library(
     name = "xla_gpu_td_files",
     srcs = glob(["*.td"]),
+    compatible_with = get_compatible_with_portable(),
     includes = ["."],
     deps = [
         "//xla/codegen/emitters/ir:xla_td_files",
@@ -30,6 +33,7 @@ td_library(
 
 gentbl_cc_library(
     name = "xla_gpu_dialect_inc_gen",
+    compatible_with = get_compatible_with_portable(),
     strip_include_prefix = ".",
     tbl_outs = [
         (
@@ -48,6 +52,7 @@ gentbl_cc_library(
 
 gentbl_cc_library(
     name = "xla_gpu_ops_inc_gen",
+    compatible_with = get_compatible_with_portable(),
     strip_include_prefix = ".",
     tbl_outs = [
         (
@@ -66,6 +71,7 @@ gentbl_cc_library(
 
 gentbl_cc_library(
     name = "xla_gpu_attrs_inc_gen",
+    compatible_with = get_compatible_with_portable(),
     strip_include_prefix = ".",
     tbl_outs = [
         (
@@ -98,6 +104,7 @@
 
 gentbl_cc_library(
     name = "xla_gpu_types_inc_gen",
+    compatible_with = get_compatible_with_portable(),
     strip_include_prefix = ".",
     tbl_outs = [
         (
@@ -1,5 +1,5 @@
 // RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce
-// RUN: fusion_to_mlir %s | emitters_opt -cse -xla-gpu-simplify-arith -canonicalize | FileCheck %s --dump-input=always
+// RUN: fusion_to_mlir %s | emitters_opt -cse -xla-simplify-arith -canonicalize | FileCheck %s --dump-input=always
 
 add {
   %p0 = f32[] parameter(0)
@@ -1,5 +1,5 @@
 
-// RUN: fusion_to_mlir %s | emitters_opt -cse -xla-gpu-simplify-arith -canonicalize | FileCheck %s --dump-input=always
+// RUN: fusion_to_mlir %s | emitters_opt -cse -xla-simplify-arith -canonicalize | FileCheck %s --dump-input=always
 
 add {
   %p0 = f32[] parameter(0)
@@ -1,4 +1,4 @@
-// RUN: fusion_to_mlir %s | emitters_opt -cse -xla-gpu-simplify-arith -canonicalize | FileCheck %s --dump-input=always
+// RUN: fusion_to_mlir %s | emitters_opt -cse -xla-simplify-arith -canonicalize | FileCheck %s --dump-input=always
 // RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce
 
 add {
@@ -1,4 +1,4 @@
-// RUN: fusion_to_mlir %s | emitters_opt -cse -xla-gpu-simplify-arith -canonicalize | FileCheck %s --dump-input=always
+// RUN: fusion_to_mlir %s | emitters_opt -cse -xla-simplify-arith -canonicalize | FileCheck %s --dump-input=always
 
 add {
   %p0 = f32[] parameter(0)
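The RUN-line updates in these lit tests (file names not captured in this view) track the same migration: the arith simplification pass is now registered under the shared -xla-simplify-arith flag rather than the GPU-specific -xla-gpu-simplify-arith.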
11 changes: 0 additions & 11 deletions xla/backends/gpu/codegen/emitters/transforms/BUILD
@@ -40,17 +40,8 @@ cc_library(
     srcs = [
         "convert_float_nvidia.cc",
         "convert_index_type.cc",
-        "convert_xla_gpu_pure_call_ops.cc",
-        "erase_dead_functions.cc",
         "fuse_loops.cc",
-        "lower_xla_gpu_to_scf.cc",
-        "merge_pointers_to_same_slice.cc",
         "optimize_loops.cc",
-        "peel_loops.cc",
-        "propagate_slice_indices.cc",
-        "simplify_affine.cc",
-        "simplify_arith.cc",
-        "unswitch_loops.cc",
         "vectorize_loads_stores.cc",
     ],
     hdrs = ["passes.h"],
@@ -61,7 +52,6 @@ cc_library(
         "//xla:util",
         "//xla:xla_data_proto_cc",
         "//xla/backends/gpu/codegen/emitters/ir:xla_gpu",
-        "//xla/codegen/emitters:elemental_hlo_to_mlir",
        "//xla/codegen/emitters/ir:xla",
        "//xla/codegen/emitters/transforms:atomic_rmw_utils",
        "//xla/hlo/analysis:indexing_analysis",
@@ -74,7 +64,6 @@ cc_library(
        "//xla/stream_executor:semantic_version",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/base:core_headers",
-       "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/strings",