[xla:cpu] Optimize buffer allocations construction from se::DeviceMemoryBase #19487

Open · wants to merge 1 commit into main
1 change: 1 addition & 0 deletions opensource_only.files
@@ -1,3 +1,4 @@
compiler/xla/backends/cpu/nanort/package_groups.bzl:
compiler/xla/internal/package_groups.bzl:
compiler/xla/mlir_hlo/WORKSPACE:
compiler/xla/package_groups.bzl:
106 changes: 106 additions & 0 deletions xla/backends/cpu/nanort/BUILD
@@ -0,0 +1,106 @@
load("@tsl//tsl/platform:rules_cc.bzl", "cc_library")
load("//xla:xla.bzl", "xla_cc_test")
load("//xla/backends/cpu/nanort:package_groups.bzl", "xla_cpu_nanort_packages")
load("//xla/tsl:tsl.bzl", "internal_visibility")

package(
# copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
default_visibility = ["//visibility:private"],
licenses = ["notice"],
)

# Required to load package group definitions.
xla_cpu_nanort_packages()

cc_library(
name = "nanort_client",
srcs = ["nanort_client.cc"],
hdrs = ["nanort_client.h"],
visibility = internal_visibility([
"//xla/backends/cpu/nanort:nanort_users",
]),
deps = [
":nanort_executable",
"//xla:debug_options_flags",
"//xla:shape_util",
"//xla:util",
"//xla/hlo/builder:xla_computation",
"//xla/hlo/ir:hlo",
"//xla/pjrt:utils",
"//xla/service:compiler",
"//xla/service:dump",
"//xla/service:executable",
"//xla/service:hlo_module_config",
"//xla/service/cpu:cpu_compiler_pure",
"@com_google_absl//absl/status:statusor",
"@tsl//tsl/platform:env",
"@tsl//tsl/platform:logging",
"@tsl//tsl/platform:statusor",
"@tsl//tsl/profiler/lib:traceme",
"@tsl//tsl/profiler/lib:traceme_encode",
],
)

xla_cc_test(
name = "nanort_client_test",
srcs = ["nanort_client_test.cc"],
deps = [
":nanort_client",
":nanort_executable",
"//xla:shape_util",
"//xla:xla_data_proto_cc",
"//xla/hlo/builder:xla_builder",
"//xla/hlo/builder:xla_computation",
"//xla/hlo/ir:hlo",
"//xla/hlo/parser:hlo_parser",
"//xla/pjrt:pjrt_client",
"//xla/pjrt:pjrt_executable",
"//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client",
"//xla/tsl/concurrency:async_value",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/status:statusor",
"@tsl//tsl/platform:logging",
"@tsl//tsl/platform:statusor",
"@tsl//tsl/platform:test",
"@tsl//tsl/platform:test_benchmark",
"@tsl//tsl/platform:test_main",
],
)

cc_library(
name = "nanort_executable",
srcs = ["nanort_executable.cc"],
hdrs = ["nanort_executable.h"],
visibility = internal_visibility([
"//xla/backends/cpu/nanort:nanort_users",
]),
deps = [
"//xla:shape_util",
"//xla:util",
"//xla/backends/cpu/runtime:buffer_allocations",
"//xla/backends/cpu/runtime:thunk",
"//xla/hlo/ir:hlo",
"//xla/service:buffer_assignment",
"//xla/service:computation_layout",
"//xla/service:executable",
"//xla/service:hlo_value",
"//xla/service/cpu:cpu_executable",
"//xla/stream_executor:device_memory",
"//xla/tsl/concurrency:async_value",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:Support",
"@tsl//tsl/platform:casts",
"@tsl//tsl/platform:env",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:logging",
"@tsl//tsl/platform:statusor",
"@tsl//tsl/profiler/lib:traceme",
"@tsl//tsl/profiler/lib:traceme_encode",
],
)
21 changes: 21 additions & 0 deletions xla/backends/cpu/nanort/README.md
@@ -0,0 +1,21 @@
# Nano Client for XLA:CPU Ultra-Low-Latency Inference

Warning: **All** users must prefer the official PjRt APIs over NanoRt's.
NanoRt is meant only for the handful of users who cannot accept _any_ overhead.

Warning: **USE AT YOUR OWN RISK**. This API might be deleted at any time, and
the XLA:CPU team does not intend to provide any backward compatibility guarantees.

This is an XLA:CPU API that resembles the PjRt Client and Executable, but with a
laser focus on absolutely minimal run-time overhead.

Key differences from PjRt:

1. It is focused on ultra low latency inference where each nanosecond matters.
2. It is single-replica, single-partition, and does not support any collective
   operations.
3. Memory for parameters, results, and temp allocations is managed by the user:
   there is no type that corresponds to `PjRtBuffer`, and the executable uses
   destination-passing style to write results into user-provided memory buffers
   (see the sketch after this list).
4. The NanoRt API is unstable and does not provide any backward compatibility
   guarantees.
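
A minimal sketch of how a caller might build and compile a computation with this
client, based only on the `NanoRtClient::Compile` API added in this PR. The helper
name `CompileAddOne` is illustrative; the execution call is omitted because
`NanoRtExecutable`'s run-time interface is not among the files shown here.

```c++
// Sketch only: build f(p0) = p0 + 1 for a scalar f32 and compile it with
// NanoRtClient. Error handling is delegated to TF_ASSIGN_OR_RETURN.
#include <memory>

#include "absl/status/statusor.h"
#include "xla/backends/cpu/nanort/nanort_client.h"
#include "xla/backends/cpu/nanort/nanort_executable.h"
#include "xla/hlo/builder/xla_builder.h"
#include "xla/hlo/builder/xla_computation.h"
#include "xla/shape_util.h"
#include "tsl/platform/statusor.h"

absl::StatusOr<std::unique_ptr<xla::cpu::NanoRtExecutable>> CompileAddOne() {
  // Build the computation with XlaBuilder; the last op becomes the root.
  xla::XlaBuilder builder("add_one");
  auto p0 = xla::Parameter(&builder, 0,
                           xla::ShapeUtil::MakeShape(xla::F32, {}), "p0");
  xla::Add(p0, xla::ConstantR0<float>(&builder, 1.0f));
  TF_ASSIGN_OR_RETURN(xla::XlaComputation computation, builder.Build());

  // Compile to a NanoRtExecutable using the XLA:CPU backend.
  xla::cpu::NanoRtClient client;
  return client.Compile(computation);
  // Executing the result later writes outputs into caller-provided buffers
  // (destination-passing style); that API lives in nanort_executable.h.
}
```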
89 changes: 89 additions & 0 deletions xla/backends/cpu/nanort/nanort_client.cc
@@ -0,0 +1,89 @@
/* Copyright 2024 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/backends/cpu/nanort/nanort_client.h"

#include <memory>
#include <utility>

#include "absl/status/statusor.h"
#include "xla/backends/cpu/nanort/nanort_executable.h"
#include "xla/debug_options_flags.h"
#include "xla/hlo/builder/xla_computation.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/pjrt/utils.h"
#include "xla/service/compiler.h"
#include "xla/service/cpu/cpu_compiler.h"
#include "xla/service/dump.h"
#include "xla/service/executable.h"
#include "xla/service/hlo_module_config.h"
#include "xla/shape.h"
#include "xla/util.h"
#include "tsl/platform/env.h"
#include "tsl/platform/logging.h"
#include "tsl/platform/statusor.h"
#include "tsl/platform/threadpool.h"
#include "tsl/profiler/lib/traceme.h"
#include "tsl/profiler/lib/traceme_encode.h"

namespace xla::cpu {

using ::tsl::profiler::TraceMe;
using ::tsl::profiler::TraceMeEncode;

NanoRtClient::NanoRtClient()
: intra_op_thread_pool_(
new tsl::thread::ThreadPool(tsl::Env::Default(), tsl::ThreadOptions(),
"nanort", DefaultThreadPoolSize())) {}

absl::StatusOr<std::unique_ptr<NanoRtExecutable>> NanoRtClient::Compile(
const XlaComputation& computation) {
TraceMe trace([&] {
return TraceMeEncode("NanoRtClient::Compile",
{{"computation", computation.name()}});
});

TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
computation.GetProgramShape());

HloModuleConfig hlo_module_config(program_shape);
hlo_module_config.set_debug_options(GetDebugOptionsFromFlags());

TF_ASSIGN_OR_RETURN(
std::unique_ptr<HloModule> hlo_module,
HloModule::CreateFromProto(computation.proto(), hlo_module_config));

static constexpr char kBeforeOptimizationsDumpName[] = "before_optimizations";
DumpHloModuleIfEnabled(*hlo_module, kBeforeOptimizationsDumpName);

// Use default XLA compiler options.
Compiler::CompileOptions compile_options;

// Run high-level XLA CPU compiler passes.
cpu::CpuCompiler compiler;
TF_ASSIGN_OR_RETURN(hlo_module, compiler.RunHloPasses(std::move(hlo_module),
/*stream_exec=*/nullptr,
compile_options));

// Compile optimized HLO module to CPU executable.
TF_ASSIGN_OR_RETURN(
std::unique_ptr<Executable> executable,
compiler.RunBackend(std::move(hlo_module), /*stream_exec=*/nullptr,
compile_options));

return NanoRtExecutable::Create(std::move(executable), intra_op_thread_pool_);
}

} // namespace xla::cpu
45 changes: 45 additions & 0 deletions xla/backends/cpu/nanort/nanort_client.h
@@ -0,0 +1,45 @@
/* Copyright 2024 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_BACKENDS_CPU_NANORT_NANORT_CLIENT_H_
#define XLA_BACKENDS_CPU_NANORT_NANORT_CLIENT_H_

#include <memory>

#include "absl/status/statusor.h"
#include "xla/backends/cpu/nanort/nanort_executable.h"
#include "xla/hlo/builder/xla_computation.h"
#include "tsl/platform/threadpool.h"

namespace xla::cpu {

// A client for compiling XLA programs to executables using the XLA:CPU backend.
class NanoRtClient {
public:
NanoRtClient();

// Compiles the given XLA computation to a NanoRtExecutable using the XLA:CPU
// backend.
absl::StatusOr<std::unique_ptr<NanoRtExecutable>> Compile(
const XlaComputation& computation);

private:
// Thread pool for running XLA:CPU compute tasks.
std::shared_ptr<tsl::thread::ThreadPool> intra_op_thread_pool_;
};

} // namespace xla::cpu

#endif // XLA_BACKENDS_CPU_NANORT_NANORT_CLIENT_H_
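
As a complement to the builder-based sketch in the README, a caller can also start
from textual HLO, mirroring what a test might do with the deps listed in the BUILD
file above (`hlo_parser`, `xla_computation`, `nanort_client`). This is an
illustrative sketch, not part of this PR's files; `CompileFromHloText` is a
hypothetical helper name.

```c++
// Sketch only: parse HLO text, wrap it in an XlaComputation, and compile it
// with NanoRtClient.
#include <memory>
#include <string>

#include "absl/status/statusor.h"
#include "xla/backends/cpu/nanort/nanort_client.h"
#include "xla/backends/cpu/nanort/nanort_executable.h"
#include "xla/hlo/builder/xla_computation.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/parser/hlo_parser.h"
#include "tsl/platform/statusor.h"

absl::StatusOr<std::unique_ptr<xla::cpu::NanoRtExecutable>> CompileFromHloText(
    const std::string& hlo_text) {
  // Parse the textual HLO into an HloModule.
  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloModule> module,
                      xla::ParseAndReturnUnverifiedModule(hlo_text));

  // Wrap the module proto in an XlaComputation and hand it to the client.
  xla::XlaComputation computation(module->ToProto());
  xla::cpu::NanoRtClient client;
  return client.Compile(computation);
}
```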