From d88d0a7397e843665b645e7d0092349ed0cc3fa6 Mon Sep 17 00:00:00 2001 From: Twice Date: Sat, 7 Dec 2024 06:51:13 +0800 Subject: [PATCH] Fix device list of loaded executable in PJRT plugin for multiple GPUs (#19369) It closes #19366, and blocks #19279. After this PR, `ClientOptions::Compile` will first check the device assignment in the compile options, and then return the corresponding device list with the loaded executable. To achieve this, we introduce protobuf via `FetchContent` in this PR, which is scoped to the PJRT plugin. Compile options will be passed by the PJRT client encoded in protobuf, and in this plugin we decode it first and then retrieve some interesting fields. ci-exactly: build_packages, test_pjrt --------- Signed-off-by: PragmaTwice Co-authored-by: Scott Todd --- integrations/pjrt/CMakeLists.txt | 4 + .../pjrt/cmake/protobuf_cc_library.cmake | 91 ++ .../pjrt/src/iree_pjrt/common/CMakeLists.txt | 1 + .../pjrt/src/iree_pjrt/common/api_impl.cc | 48 +- .../pjrt/src/iree_pjrt/common/api_impl.h | 7 +- .../third_party/pjrt_c_api/CMakeLists.txt | 9 + .../pjrt/third_party/pjrt_c_api/README.md | 3 +- .../pjrt_c_api/xla/pjrt/compile_options.proto | 169 +++ .../third_party/pjrt_c_api/xla/xla_data.proto | 1153 +++++++++++++++++ 9 files changed, 1462 insertions(+), 23 deletions(-) create mode 100644 integrations/pjrt/cmake/protobuf_cc_library.cmake create mode 100644 integrations/pjrt/third_party/pjrt_c_api/xla/pjrt/compile_options.proto create mode 100644 integrations/pjrt/third_party/pjrt_c_api/xla/xla_data.proto diff --git a/integrations/pjrt/CMakeLists.txt b/integrations/pjrt/CMakeLists.txt index 9d904300ba99..8f54ed62dc7d 100644 --- a/integrations/pjrt/CMakeLists.txt +++ b/integrations/pjrt/CMakeLists.txt @@ -52,6 +52,10 @@ add_subdirectory("${IREE_ROOT_DIR}" "iree_core" EXCLUDE_FROM_ALL) # toolchain level. iree_setup_toolchain() +# Setup protoc and protobuf library +list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_LIST_DIR}/cmake") +include(protobuf_cc_library) + add_subdirectory(src) add_subdirectory(third_party/pjrt_c_api) diff --git a/integrations/pjrt/cmake/protobuf_cc_library.cmake b/integrations/pjrt/cmake/protobuf_cc_library.cmake new file mode 100644 index 000000000000..febd0f516828 --- /dev/null +++ b/integrations/pjrt/cmake/protobuf_cc_library.cmake @@ -0,0 +1,91 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +include(CMakeParseArguments) +include(FetchContent) + +# disable some targets we don't use +set(protobuf_INSTALL OFF) +set(protobuf_BUILD_TESTS OFF) + +# to prevent protobuf itself from using `find_package` +set(protobuf_FORCE_FETCH_DEPENDENCIES ON) + +# pin the version of protobuf +set(protobuf_VERSION 29.1) + +FetchContent_Declare( + protobuf + GIT_REPOSITORY https://github.com/protocolbuffers/protobuf + GIT_TAG v${protobuf_VERSION} + GIT_SHALLOW ON +) + +FetchContent_MakeAvailable(protobuf) + +# make protobuf_generate() function available +include(${protobuf_SOURCE_DIR}/cmake/protobuf-generate.cmake) + +# iree_pjrt_protobuf_cc_library() +# +# CMake function to invoke the protoc compiler. +# +# Parameters: +# NAME: name of target +# SRCS: List of source files for the library +# PROTOC_ARGS: List of protoc arguments. +# PUBLIC: Add this so that this library will be exported under iree:: +# Also in IDE, target will appear in IREE folder while non PUBLIC will be in IREE/internal. +# TESTONLY: When added, this target will only be built if user passes -DIREE_BUILD_TESTS=ON to CMake. +# +# iree_pjrt_protobuf_cc_library( +# NAME +# some_def +# SRC +# some_def.proto +# PUBLIC +# ) +function(iree_pjrt_protobuf_cc_library) + cmake_parse_arguments(_RULE + "PUBLIC;TESTONLY" + "NAME" + "SRCS;PROTOC_ARGS" + ${ARGN} + ) + + if(_RULE_TESTONLY AND NOT IREE_BUILD_TESTS) + return() + endif() + + # Prefix the library with the package name, so we get: iree_package_name + iree_package_name(_PACKAGE_NAME) + set(_NAME "${_PACKAGE_NAME}_${_RULE_NAME}") + + add_library(${_NAME} ${_RULE_SRCS}) + protobuf_generate( + TARGET ${_NAME} + LANGUAGE cpp + PROTOC_OPTIONS ${_RULE_PROTOC_ARGS} + ) + target_include_directories(${_NAME} + PUBLIC + $ + ) + target_link_libraries(${_NAME} + PUBLIC + protobuf::libprotobuf + ${IREE_DEFAULT_LINKOPTS} + ) + iree_install_targets( + TARGETS ${_NAME} + ) + + # Alias the iree_package_name library to iree::package::name. + # This lets us more clearly map to Bazel and makes it possible to + # disambiguate the underscores in paths vs. the separators. + iree_package_ns(_PACKAGE_NS) + iree_add_alias_library(${_PACKAGE_NS}::${_RULE_NAME} ${_NAME}) +endfunction() diff --git a/integrations/pjrt/src/iree_pjrt/common/CMakeLists.txt b/integrations/pjrt/src/iree_pjrt/common/CMakeLists.txt index 46828f14ebba..4371cd3db965 100644 --- a/integrations/pjrt/src/iree_pjrt/common/CMakeLists.txt +++ b/integrations/pjrt/src/iree_pjrt/common/CMakeLists.txt @@ -27,6 +27,7 @@ iree_cc_library( iree::vm iree::vm::bytecode::module iree_pjrt_deps::headers + iree_pjrt_deps::protos PUBLIC ) diff --git a/integrations/pjrt/src/iree_pjrt/common/api_impl.cc b/integrations/pjrt/src/iree_pjrt/common/api_impl.cc index a577cb0b82ad..68b45ecbbb07 100644 --- a/integrations/pjrt/src/iree_pjrt/common/api_impl.cc +++ b/integrations/pjrt/src/iree_pjrt/common/api_impl.cc @@ -6,6 +6,7 @@ #include "iree_pjrt/common/api_impl.h" +#include #include #include #include @@ -1002,7 +1003,7 @@ iree_status_t DeviceInstance::TransposeBroadcastDeviceBuffer( // Compile program and check for errors: LoadedExecutableInstance* executable; - auto* error = this->client().Compile(&program, &executable); + auto* error = this->client().Compile(&program, {}, &executable); if (error) { auto errinst = ErrorInstance::FromError(error); auto ret = iree_make_status(IREE_STATUS_INVALID_ARGUMENT, @@ -1351,22 +1352,14 @@ void ClientInstance::BindApi(PJRT_Api* api) { LoadedExecutableInstance* executable; // Read compilation options. - // TODO: Port CompileOptionsProto into the project or leave ommitted. - // xla::CompileOptionsProto options_proto; - // if (!options_proto.ParseFromArray(args->compile_options, - // args->compile_options_size)) { - // return MakeError(iree_make_status(IREE_STATUS_INTERNAL, - // "could not parse compilation - // options")); - // } - // auto options = xla::CompileOptions::FromProto(options_proto); - // if (!options.ok()) { - // return MakeError( - // iree_make_status(IREE_STATUS_INTERNAL, - // std::string(options.status().message()).c_str())); - // } - - auto* error = client->Compile(args->program, /**options,*/ &executable); + xla::CompileOptionsProto options_proto; + if (!options_proto.ParseFromArray(args->compile_options, + args->compile_options_size)) { + return MakeError(iree_make_status(IREE_STATUS_INTERNAL, + "could not parse compilation options")); + } + + auto* error = client->Compile(args->program, options_proto, &executable); if (error) return error; args->executable = *executable; return nullptr; @@ -1451,7 +1444,7 @@ iree_status_t ClientInstance::PopulateDevices() { } PJRT_Error* ClientInstance::Compile(const PJRT_Program* program, - /*xla::CompileOptions options,*/ + xla::CompileOptionsProto options, LoadedExecutableInstance** out_executable) { std::unique_ptr artifact_tx; if (platform().artifact_dumper().enabled()) { @@ -1570,11 +1563,28 @@ PJRT_Error* ClientInstance::Compile(const PJRT_Program* program, output->GetDataSize())); } + // calculate devices for this computation from device assignment + std::vector devices; + + const auto& build_options = options.executable_build_options(); + if (build_options.has_device_assignment()) { + const auto& device_assignment = build_options.device_assignment(); + for (auto id : + device_assignment.computation_devices(0).replica_device_ids()) { + if (id < addressable_devices_.size()) + devices.push_back(addressable_devices_[id]); + } + } + + if (devices.empty()) { + devices = addressable_devices_; + } + auto executable = std::make_unique( *this, new ExecutableImage(std::move(output), std::string(program->code, program->code_size)), - addressable_devices_); + devices); status = executable->LoadAll(); if (!iree_status_is_ok(status)) { return MakeError(status); diff --git a/integrations/pjrt/src/iree_pjrt/common/api_impl.h b/integrations/pjrt/src/iree_pjrt/common/api_impl.h index c6debcae8bf7..8304c8816bc4 100644 --- a/integrations/pjrt/src/iree_pjrt/common/api_impl.h +++ b/integrations/pjrt/src/iree_pjrt/common/api_impl.h @@ -25,6 +25,7 @@ #include "iree_pjrt/common/layout_utils.h" #include "iree_pjrt/common/platform.h" #include "xla/pjrt/c/pjrt_c_api.h" +#include "xla/pjrt/compile_options.pb.h" namespace iree::pjrt { @@ -451,9 +452,9 @@ class ClientInstance { // Compiles. // See TODOs in PJRT_Client_Compile. - PJRT_Error* Compile( - const PJRT_Program* program, /*xla::CompileOptions options, */ - LoadedExecutableInstance** executable); + PJRT_Error* Compile(const PJRT_Program* program, + xla::CompileOptionsProto options, + LoadedExecutableInstance** executable); // --------------------------------------------------------------------------- // Subclass hooks. diff --git a/integrations/pjrt/third_party/pjrt_c_api/CMakeLists.txt b/integrations/pjrt/third_party/pjrt_c_api/CMakeLists.txt index 52e7fae256ec..94509be43f7c 100644 --- a/integrations/pjrt/third_party/pjrt_c_api/CMakeLists.txt +++ b/integrations/pjrt/third_party/pjrt_c_api/CMakeLists.txt @@ -26,3 +26,12 @@ iree_cc_library( "xla/pjrt/c/pjrt_c_api.h" PUBLIC ) + +iree_pjrt_protobuf_cc_library( + NAME + protos + SRCS + "xla/pjrt/compile_options.proto" + "xla/xla_data.proto" + PUBLIC +) diff --git a/integrations/pjrt/third_party/pjrt_c_api/README.md b/integrations/pjrt/third_party/pjrt_c_api/README.md index 404428d63634..c55093a30aa2 100644 --- a/integrations/pjrt/third_party/pjrt_c_api/README.md +++ b/integrations/pjrt/third_party/pjrt_c_api/README.md @@ -1,6 +1,7 @@ # pjrt_c_api -This directory contains a fork of C headers needed to build a PJRT plugin. +This directory contains a fork of C headers and .proto files +needed to build a PJRT plugin. It is intended to be sync'd with upstream for major/breaking changes and releases. diff --git a/integrations/pjrt/third_party/pjrt_c_api/xla/pjrt/compile_options.proto b/integrations/pjrt/third_party/pjrt_c_api/xla/pjrt/compile_options.proto new file mode 100644 index 000000000000..b3022a04e30c --- /dev/null +++ b/integrations/pjrt/third_party/pjrt_c_api/xla/pjrt/compile_options.proto @@ -0,0 +1,169 @@ +syntax = "proto3"; + +package xla; + +// TODO: to avoid introducing too many source files in XLA to IREE PJRT, +// currently we remove some fields in this file which is not in use +// so that their message definitions are not required. +// If we want to uncomment these removed fields, we should also +// add the corresponding schema files, like below. +// +// Currently, these removed fields include: +// - comp_envs, debug_options in ExecutableBuildOptionsProto +// - target_config in CompileOptionsProto + +// import "xla/stream_executor/device_description.proto"; +// import "xla/xla.proto"; +import "xla/xla_data.proto"; + +// A serialization of xla::ExecutableBuildOptions. +// Next id: 24. +message ExecutableBuildOptionsProto { + // If set, this is the device to build the computation for. Valid + // device_ordinal values are: 0 to # of devices - 1. These values are + // identical to the device ordinal values used by StreamExecutor. The built + // executable will be executable on any device equivalent to the specified + // device as determined by Backend::devices_equivalent(). A value of -1 + // indicates this option has not been set. + int64 device_ordinal = 1; + + // If set, this specifies the layout of the result of the computation. If not + // set, the service will chose the layout of the result. A Shape is used to + // store the layout to accommodate tuple result shapes. A value of nullptr + // indicates the option has not been set. + xla.ShapeProto result_layout = 2; + + // Expose access to the XLA compilation environments, which will be passed to + // the compilation process. + // xla.CompilationEnvironmentsProto comp_envs = 13; + + // Expose access to the XLA debug options which will be passed to the + // compilation process. + // xla.DebugOptions debug_options = 3; + + // The number of replicas of this computation that are to be executed. + // Defaults to 1. + int64 num_replicas = 4; + + // The number of partitions in this computation. Defaults to 1. + int64 num_partitions = 5; + + // Indicates whether to use SPMD (true) or MPMD (false) partitioning when + // num_partitions > 1 and XLA is requested to partition the input program. + bool use_spmd_partitioning = 6; + + // Whether to automatically generate XLA shardings for SPMD partitioner. + bool use_auto_spmd_partitioning = 7; + + // The amount of effort to spend on optimizing for minimizing program + // execution time, as a value in [-1.0, +1.0]. The baseline is 0.0, which + // strongly prioritizes execution time at the cost of longer compile times, + // suitable for production workloads. A value of -0.5 would be appropriate for + // research use cases that prefer faster compilations to iterate more quickly. + // Positive values, on the other hand, might enable costly optimizations that + // are off by default. + float exec_time_optimization_effort = 20; + + // The amount of effort to spend on making the program fit in memory (where + // "fit in memory" here has a backend-dependent meaning), as a value in + // [-1.0,+1.0]. The baseline is 0.0, which expends significant effort on + // attempting to make the program fit. A value of -1.0 would be appropriate + // for use cases that wish to spend minimal effort here and fail as quickly as + // possible instead. Positive values, on the other hand, might enable costly + // algorithms to reduce memory usage that are off by default. + float memory_fitting_effort = 21; + + // Whether HLOs should be deduplicated. + bool deduplicate_hlo = 8; + + // If set, this specifies a static device assignment for the computation. + // Otherwise, the computation will be compiled generically and can be run with + // any device assignment compatible with the computation's replica and + // partition counts. + xla.DeviceAssignmentProto device_assignment = 9; + + // Whether input and output buffers are aliased if the associated parameter is + // passed-through XLA modules without being changed. + bool alias_passthrough_params = 10; + + // By default, XLA builds an executable by invoking standard compilation, i.e. + // running Compiler::Compile, or both Compiler::RunHloPasses and + // Compiler::RunBackend. When run_backend_only is set to true, XLA builds an + // executable by invoking only RunBackend and skip invoking RunHloPasses, + // which can be used to compile post-optimizations HLO modules. + bool run_backend_only = 11; + + // Allows sharding propagation to propagate to the parameters. This changes + // the input shape of the computation (which is undesirable), but it can be + // used to allow to run partial compilation to determine what would be the + // input sharding of a computation if XLA would be allowed to propagate the + // sharding which can be used by higher level framework as a way to query + // intermediate sharding of operations when multiple computation would be + // chained and merged together. + // This is a vector of bool, because the user can control which parameters can + // have the sharding substituted. If only one boolean value is passed in the + // vector that is interpreted as the value to be applied for every parameter. + repeated bool allow_spmd_sharding_propagation_to_parameters = 18; + + // Allows sharding propagation to propagate to the outputs. This changes the + // output shape of the computation (which is undesirable), but it can be used + // to allow to run partial compilation to determine what would be the output + // sharding of a computation if XLA would be allowed to propagate the sharding + // which can be used by higher level framework as a way to query intermediate + // sharding of operations when multiple computation would be chained and + // merged together. + // This is a vector of bool, because the user can control (if the output of + // the computation is a tuple) which elements of the tuple can have the + // sharding substituted and which don't. If only one boolean value is passed + // in the vector that's interpreted as the value to be applied for every + // single element of the output tuple. One value per element of the tuple + // means that each value is attached to one of the output elements. + repeated bool allow_spmd_sharding_propagation_to_output = 12; + + // Opaque profile data for any feedback directed optimizations. + bytes fdo_profile = 14; + + int64 device_memory_size = 15; + + // Mesh shape in auto sharding options. + repeated int64 auto_spmd_partitioning_mesh_shape = 16; + + // Mesh ids in auto sharding options. + repeated int64 auto_spmd_partitioning_mesh_ids = 17; + + // Use Shardy, a new partitioner, to replace the existing + // ShardingPropagation and SpmdPartitioner. See go/xla-sdy-pipeline for + // details. + bool use_shardy_partitioner = 19; + + int64 process_index = 22; + int64 process_count = 23; +} + +message OptionOverrideProto { + oneof value { + string string_field = 1; + bool bool_field = 2; + int64 int_field = 3; + double double_field = 4; + } +} + +message CompileOptionsProto { + // Refer CompileOptions for documentation of fields. + repeated ShapeProto argument_layouts = 1; + bool parameter_is_tupled_arguments = 2; + ExecutableBuildOptionsProto executable_build_options = 3; + bool compile_portable_executable = 4; + int64 profile_version = 5; + bytes serialized_multi_slice_config = 6; + map env_option_overrides = 7; + + // stream_executor.GpuTargetConfigProto target_config = 8; +} + +// Helper for serializing opaque executables alongside CompileOptions. +message ExecutableAndOptionsProto { + bytes serialized_executable = 1; + CompileOptionsProto compile_options = 2; +} diff --git a/integrations/pjrt/third_party/pjrt_c_api/xla/xla_data.proto b/integrations/pjrt/third_party/pjrt_c_api/xla/xla_data.proto new file mode 100644 index 000000000000..7d9563b11ab7 --- /dev/null +++ b/integrations/pjrt/third_party/pjrt_c_api/xla/xla_data.proto @@ -0,0 +1,1153 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package xla; + +option cc_enable_arenas = true; + +// Primitive types are the individual values that can be held in rectangular +// multidimensional arrays. A description of the rectangular multidimensional +// array dimensions / primitive type is given by Shape, below. +// +// LINT.IfChange +enum PrimitiveType { + // Invalid primitive type to serve as default. + PRIMITIVE_TYPE_INVALID = 0; + + // Predicates are two-state booleans. + PRED = 1; + + // Signed integral values of fixed width. + S2 = 26; + S4 = 21; + S8 = 2; + S16 = 3; + S32 = 4; + S64 = 5; + + // Unsigned integral values of fixed width. + U2 = 27; + U4 = 22; + U8 = 6; + U16 = 7; + U32 = 8; + U64 = 9; + + // Floating-point values of fixed width. + // + // Note: if f16s are not natively supported on the device, they will be + // converted to f16 from f32 at arbirary points in the computation. + F16 = 10; + F32 = 11; + + // Truncated 16 bit floating-point format. This is similar to IEEE's 16 bit + // floating-point format, but uses 1 bit for the sign, 8 bits for the exponent + // and 7 bits for the mantissa. + BF16 = 16; + + F64 = 12; + + // FP8 dtypes, as described in this paper: https://arxiv.org/abs/2209.05433 + // + // F8E5M2 has 5 exponent bits and 2 mantissa bits, and is similar to the + // existing IEEE types. + // + // F8E4M3 has 4 exponent bits and 3 mantissa bits, and is similar to the + // existing IEEE types. + // + // F8E4M3FN has 4 exponent bits and 3 mantissa bits. The "FN" means only + // Finite and NaN values are supported. Unlike IEEE types, infinities are not + // supported. NaN is represented when the exponent and mantissa bits are all + // 1s. All other values are finite. + // + // F8E4M3B11FNUZ has 4 exponent bits and 3 mantissa bits and a bias of 11. The + // "FNUZ" means only Finite and NaN values are supported; zero is unsigned. + // Unlike IEEE types, infinities are not supported. NaN is represented when + // the exponent and mantissa bits are all 0s with a sign bit of 1. All other + // values are finite. + // + // F8E3M4 has 3 exponent bits and 4 mantissa bits, and is similar to the + // existing IEEE types. + // + // Support for these dtypes is under development. They do not yet work + // properly in most cases. + // TODO(b/259609697): Fully support FP8. + F8E5M2 = 19; + F8E4M3 = 28; + F8E4M3FN = 20; + F8E4M3B11FNUZ = 23; + F8E3M4 = 29; + + // FP8 dtypes, as described in this paper: https://arxiv.org/abs/2206.02915 + // + // F8E5M2FNUZ has 5 exponent bits and 2 mantissa bits. + // F8E4M3FNUZ has 4 exponent bits and 3 mantissa bits. + // + // The "FNUZ" means only Finite and NaN values are supported; zero is + // unsigned. Unlike IEEE types, infinities are not supported. NaN is + // represented when the exponent and mantissa bits are all 0s with a sign bit + // of 1. All other values are finite. + // + // These differences mean there's an additional exponent value available. To + // keep the same dynamic range as an IEEE-like FP8 type, the exponent is + // biased one more than would be expected given the number of exponent bits + // (8 for Float8E4M3FNUZ and 16 for Float8E5M2FNUZ). + F8E5M2FNUZ = 24; + F8E4M3FNUZ = 25; + + // Complex values of fixed width. + C64 = 15; // Paired F32 (real, imag), as in std::complex. + C128 = 18; // Paired F64 (real, imag), as in std::complex. + + // A tuple is a polymorphic sequence; e.g. a shape that holds different + // sub-shapes. They are used for things like returning multiple values from a + // computation; e.g. a computation that returns weights and biases may have a + // signature that results in a tuple like (f32[784x2000], f32[2000]) + // + // If a shape proto has the tuple element type, it may not have any entries + // in the dimensions field. + TUPLE = 13; + + // An opaque type used for passing context-specific data to a custom + // operation. Shapes of this primitive type will have empty dimensions and + // tuple_shapes fields. + // + // (OPAQUE would be a better name for this identifier, but that conflicts with + // a macro defined in windows.h.) + OPAQUE_TYPE = 14; + + // A token type threaded between side-effecting operations. Shapes of this + // primitive type will have empty dimensions and tuple_shapes fields. + TOKEN = 17; + + // Next = 30 +} +// LINT.ThenChange( +// https://www.tensorflow.org/code/tensorflow/compiler/xla/tools/driver.cc +// ) + +// Describes the padding configuration for Pad operation. The padding amount on +// both edges as well as between the elements are specified for each dimension. +message PaddingConfig { + // Describes the padding configuration for a dimension. + message PaddingConfigDimension { + // Padding amount on the low-end (next to the index 0). May be negative. + int64 edge_padding_low = 1; + + // Padding amount on the high-end (next to the highest index). May be + // negative. + int64 edge_padding_high = 2; + + // Padding amount between the elements. May not be negative. + int64 interior_padding = 3; + } + + // The padding configuration for all dimensions. + repeated PaddingConfigDimension dimensions = 1; +} + +// A DimLevelType indicates the encoding method for a dimension in an array. +// The semantics of this field are identical to those of the MLIR SparseTensor +// dialect. +// This should be kept in sync with the SparseTensor DimLevelType enum: +// https://github.com/llvm/llvm-project/blob/5674a3c88088e668b684326c2194a6282e8270ff/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td#L86 +enum DimLevelType { + // The corresponding dimension is Dense, every entry is stored. + DIM_DENSE = 0; + // The corresponding dimension is Compressed, only nonzeros are stored. + DIM_COMPRESSED = 1; + // The corresponding dimension contains a single coordinate, no sibling + // elements for each parent. + DIM_SINGLETON = 2; + // The corresponding dimension is Compressed, but with potential trailing + // zeros, thus an extra upper bound (high) is used to exclude those zeros. + // E.g., indices = [1, 2, 0, 0, 3, 4, 0, 0], position = [(0, 2), (4, 6)]. + DIM_LOOSE_COMPRESSED = 3; +} + +// Describes a tile used in tiling-based layout. Refer to +// g3doc/third_party/xla/docs/tiled_layout.md for details about tiling-based +// layout. +message TileProto { + // Number of elements in each dimension of the tile. It's ordered from the + // most major dimension of the tile to the most minor dimension of the tile. + // The dimensions correspond to a suffix of the dimensions of the shape being + // tiled. + repeated int64 dimensions = 1; +} + +// Describes how data should be split between different memories. +message SplitConfigProto { + // The dimension that is split. + int64 dimension = 1; + // The indices where each split point occurs. For example, if the dimension + // size is 1024, a split_indices value of {512} indicates a two-way split of + // data through the middle. + repeated int64 split_indices = 2; +} + +// A layout describes how the array is placed in (1D) memory space. This +// includes the minor-to-major ordering of dimensions within a shape. +// +// Clients must specify the layouts of input Literals to the +// computation. Layouts specified in interior operations which take Shapes (for +// example, Convert) are ignored. +// +// See the XLA documentation for more information on shapes and layouts. +// +// LINT.IfChange +message LayoutProto { + // The dimension level type list for this array, specifying the way in which + // each array dimension is represented in memory. If this list is empty, the + // array is assumed to be dense. + repeated DimLevelType dim_level_types = 9; + + // Whether each dimension is unique or ordered. Each of the following lists + // must be empty, or have one entry for each entry of dim_level_types. If + // either list is empty, all dimensions are assumed to be unique and ordered, + // respectively. Entries in this list may not be false for some DimLevelType + // values (such as DIM_DENSE in particular). + repeated bool dim_unique = 13; + repeated bool dim_ordered = 14; + + // Sequence of dimension numbers, from minor (fastest varying index) to major + // (slowest varying index). This field is required. + repeated int64 minor_to_major = 1; + + // A sequence of tiles, starting from the tile that's applied first to the + // Shape. + // + // TODO(b/119839262): implement tiling in each backend or add Unimplemented + // error. + repeated TileProto tiles = 6; + + // The shape is padded at the end to multiple of, in terms of number of + // elements. This is useful when tiling does not bring the shape to certain + // desired granules. Tiling effectively pads/reshapes/transposes the shape + // to another shape. This field pads the total number of elements of that + // new shape to a multiple of certain number of elements. This is useful such + // as we want a layout which does not tile the data but still requires it to + // be padded to certain number of elements. + int64 tail_padding_alignment_in_elements = 16; + + // (Optional) Bit size of each element. When unspecified or being 0, default + // to ShapeUtil::ByteSizeOfPrimitiveType. + int64 element_size_in_bits = 7; + + // Memory space where this array resides. The integer field is interpreted in + // a backend-specific manner. + int64 memory_space = 8; + + // The integer types to be used for indices and pointers. These fields must + // not be used unless the layout represents a sparse array. The PrimitiveType + // must correspond to an unsigned integer (U8, U16, U32, or U64). + // If not provided, the compiler will use the largest unsigned integer + // that is naturally supported by the target device (U32 or U64 in currently + // supported devices). + PrimitiveType index_primitive_type = 11; + PrimitiveType pointer_primitive_type = 12; + + // The physical, on-device shape used to represent the shape this layout + // belongs to. Only used for sparse arrays. + // The layout(s) contained within the physical shape should not also contain + // a physical shape. + ShapeProto physical_shape = 10; + + // The dynamic shape metadata size in bytes in front of the shape data. The + // field may be non-zero for a static shape whose associated buffer is for a + // dynamic shape, e.g. a result of SliceToDynamic. + int64 dynamic_shape_metadata_prefix_bytes = 15; + + // The split configurations which describe if/how the data is split between + // different memories. + repeated SplitConfigProto split_configs = 17; + + // Important: if any field is added, be sure to modify ShapeUtil::Equal() and + // LayoutUtil::Hash appropriately to account for the new field. + + reserved 2; + reserved "padded_dimensions"; + reserved 3; + reserved "padding_value"; + reserved 4; + reserved "format"; + reserved 5; + reserved "max_sparse_elements"; +} +// LINT.ThenChange( \ +// https://www.tensorflow.org/code/tensorflow/compiler/xla/shape_util.cc, \ +// https://www.tensorflow.org/code/tensorflow/compiler/xla/layout_util.cc) + +// A shape describes the number of dimensions in the array, the size of each +// dimension, and the primitive component type. +// +// Tuples are a special case in that they have rank zero and have tuple_shapes +// defined. +// +// See the XLA documentation for more information on shapes and layouts. +// +// LINT.IfChange +message ShapeProto { + reserved 1; + reserved "rank"; + + // The element type for this shape. + PrimitiveType element_type = 2; + + // The size (number of elements) for each dimension, or an upper bound on the + // size if the dimension is dynamic. In XLA, dimensions are numbered from 0 + // to N-1 for an N-dimensional array. The first element of 'dimensions' is the + // size of dimension 0, the second element is the size of dimension 1, and so + // forth. Empty list indicates a scalar. + // + // If the respective element in 'is_dimension_dynamic' is true then the value + // in this field represents an upper bound on the size of the dimension. + repeated int64 dimensions = 3; + + // For tuples only, the shapes of constituent shapes in the tuple sequence. + repeated ShapeProto tuple_shapes = 4; + + // The layout used to back this shape. + LayoutProto layout = 5; + + // For arrays, this indicates whether or not each dimension is + // dynamically-sized. The number of elements in this repeated field should be + // zero (indicating that no dimensions are dynamic) or equal to the number of + // elements in the 'dimensions' field. + repeated bool is_dynamic_dimension = 6; + + // Important: if any field is added, be sure to modify ShapeUtil::Equal(), + // ShapeUtil::Compatible() and ShapeUtil::Hash() appropriately to account for + // the new field. +} +// LINT.ThenChange( \ +// https://www.tensorflow.org/code/tensorflow/compiler/xla/shape_util.cc) + +// Shape of the parameters and output of a computation (like a traditional +// function signature). +message ProgramShapeProto { + repeated ShapeProto parameters = 1; + ShapeProto result = 2; + repeated string parameter_names = 3; +} + +// Statistics of a computation. +message ComputationStats { + // The number of floating point operations in the computation. + double flop_count = 1; + + // The number of transcendental operations (e.g., exp) in the computation. + double transcendental_count = 2; +} + +// The type optimization profiles in use for Op-level optimizations. +enum ProfileType { + INVALID = 0; + WINDOW = 1; + FLAG = 2; + INTEGER = 3; +} + +// The source of the optimization profile. +enum ProfileSource { + PROFILE_SOURCE_UNKNOWN_SOURCE = 0; + PROFILE_SOURCE_EMBEDDED = 1; + PROFILE_SOURCE_REMOTE = 2; +} + +// The compilation event that triggered the use of the profile. +enum CompilationEvent { + COMPILATION_EVENT_UNKNOWN_EVENT = 0; + COMPILATION_EVENT_FIRST_COMPILATION = 1; + COMPILATION_EVENT_RECOMPILATION = 2; +} + +// Symbolization metadata for HLO Instructions. +// +// This metadata is used for debugging XLA code generation, as well as +// performance profiling of XLA-generated executables. +message OpMetadata { + // The framework op name that generated this XLA op. + // + // Frameworks that build on top of XLA should mirror the names of their ops + // back to users by specifying the op_type. In this way, even if the + // framework's "ops" are implemented as multiple XLA HLO Ops, they can be + // grouped appropriately. (e.g. if a SoftMax layer is emitted into XLA as + // multiple ops, then each op should have the op_type be "SoftMax".) + string op_type = 1; + // The user-specified name of the op. + // + // This name is often unique within a computation. Note: some frameworks + // add auto-generated names if the user does not provide one. + string op_name = 2; + // Indicate a file and line that this op is associated to in a user's program. + // + // e.g. it could be the file and line of user code that generated the op. + string source_file = 3; + int32 source_line = 4; + + // Deprecated, use [ProfileInfo][profile_type] instead. + repeated ProfileType profile_type = 5 [deprecated = true]; + + reserved 6; + reserved "creation_pass_id"; + + reserved 7; + reserved "logical_creation_pass_id"; + + // The footprint of the generated code for the instruction. + int64 size_of_generated_code_in_bytes = 8; + // The size of the working set, i.e., the amount of memory, used by the + // instruction in a compiler-managed fast device memory. + int64 size_of_memory_working_set_in_bytes = 9; + + // Information about the optimization profile that this operation contains. + message ProfileInfo { + // The type of optimization profiles that this operation contains. + repeated ProfileType profile_type = 1; + // Speedup of tuned config compared to default config. + // TODO(b/203817882) Set the relative_speedup. + double relative_speedup = 2; + // The source of the optimization profiles that this operation contains. + ProfileSource profile_source = 3; + // The compilation event that triggered the use of the profiles. + CompilationEvent compilation_event = 4; + } + + // Profile information for the Op. + ProfileInfo profile_info = 10; + + // Deduplicated HLO name for this op. In some cases, we can have multiple + // instructions (e.g. fusions) that are considered duplicates. We want to + // group them together under the same name so that we can group them together + // during analysis (e.g. HLO Op Profile tool in Xprof). + // E.g. If we have fusion.1, fusion.2, and fusion.3 marked as duplicates, + // fusion.2 and fusion.3 will have deduplicated_name = fusion.1 + string deduplicated_name = 12; + + // Whether to preserve the layout of the HLO op. + bool preserve_layout = 13; + + // 1-based position of the frame in frames flat array. + // Ids are 1-based to keep 0 value as representation of non-set property. + int32 stack_frame_id = 15; + + // Instruction name available upon scheduling. + string scheduling_name = 16; + + reserved 14; +} + +// Profile data from the execution of a computation. +message ExecutionProfile { + // Whether the executable was read from the compilation cache. + bool compilation_cache_hit = 1; + + // The time in milliseconds spent to compile the computation. This only set if + // the executable was not read from the compilation cache + // (compilation_cache_hit == false). + int64 compile_time_ms = 2; + + // The number of cycles spent for the computation. This does not include the + // time taken for the data transfers between the host and the device. This is + // a target-dependent field and only used for debugging purposes. + int64 compute_cycle_count = 3; + + // The time in nanoseconds spent for the computation, without data transfer. + int64 compute_time_ns = 4; + + // The time in nanoseconds spent for the entire computation, including the + // result data transfer time. Current implementation does not spend any cycles + // for the input data transfer since the memory is initialized with the proper + // values before the execution. + int64 compute_and_transfer_time_ns = 5; + + // The size of the binary code in the executable. + int64 executable_size_in_bytes = 6; + + // Whether this profile was drawn from a cache of profiles instead of from + // execution on the hardware. + bool profile_cache_hit = 7; + + // Whether a warm-up run of the computation was executed before the + // measured execution. + bool warmup_run_executed = 8; +} + +// Handle given to a user that represents an execution that the user launched +// asynchronously on the device. +message ExecutionHandle { + int64 handle = 1; +} + +// Handle given to a user that represents a globally accessible allocation. +// Contrast this against a ComputationDataHandle, which is not globally +// accessible, since it only exists within a specific computation. +message GlobalDataHandle { + int64 handle = 1; +} + +// Handle given to a user that represents a replicated virtual device. Each +// replicated device represents N physical devices for execution where N is the +// number of replicas. +message DeviceHandle { + int64 handle = 1; + + // The number of model-parallel virtual devices that communicate via XLA + // Send/Recv instructions. + int64 device_count = 2; +} + +// Handle given to a user to represent a channel between two computations +// via a Send and Recv instruction pair. Channels are unbuffered, so Send +// Send instructions will be blocked until the data is transferred. +message ChannelHandle { + int64 handle = 1; + enum ChannelType { + // Invalid primitive type to serve as default. + CHANNEL_TYPE_INVALID = 0; + + // A channel for sending data between devices. + DEVICE_TO_DEVICE = 1; + + // A channel for sending data from the device to the host. Can only be used + // with a Send operation. + DEVICE_TO_HOST = 2; + + // A channel for sending data from the host to the device. Can only be used + // with a Recv operation. + HOST_TO_DEVICE = 3; + } + ChannelType type = 2; +} + +// DeviceAssignmentProto is a serialized form of DeviceAssignment class, which +// represents the device ids assigned to a set of replicated computations. +// See xla::DeviceAssignment class comment for more details. +message DeviceAssignmentProto { + int32 replica_count = 1; + int32 computation_count = 2; + + // Each logical computation runs on replica_count physical devices. + // ComputationDevice represents the device ids assinged to the replicas. + message ComputationDevice { + repeated int64 replica_device_ids = 1; + } + repeated ComputationDevice computation_devices = 3; +} + +// Literals are used when the server and client need to exchange materialized +// data / results. Literals are also used to describe constants used in +// computations. +// +// Transfers to/from the client are encoded in literal form, and the structure +// of the repeated fields is implied by the shape. +message LiteralProto { + ShapeProto shape = 1; + repeated bool preds = 2; + bytes s2s = 26; + bytes s4s = 21; + bytes s8s = 15; + bytes u2s = 27; + bytes u4s = 22; + bytes u8s = 3; + repeated int32 s32s = 4; + repeated int64 s64s = 5; + repeated uint32 u32s = 6; + repeated uint64 u64s = 7; + repeated float f32s = 8; + repeated double f64s = 9; + repeated float c64s = 12; // Stored as interleaved real, imag floats. + repeated double c128s = 18; // Stored as interleaved real, imag doubles. + repeated LiteralProto tuple_literals = 10; + // The F16s, BF16s, U16s and S16s are encoded in little endian byte order + bytes f16s = 11; + bytes bf16s = 13; + bytes u16s = 16; + bytes s16s = 17; + bytes f8e5m2s = 19; + bytes f8e4m3s = 28; + bytes f8e4m3fns = 20; + bytes f8e4m3b11fnuzs = 23; + bytes f8e5m2fnuzs = 24; + bytes f8e4m3fnuzs = 25; + bytes f8e3m4s = 29; + repeated int64 sparse_indices = 14; + // Next = 30 +} + +message WindowDimension { + // The size of the window in this dimension. For a rectangle, this would be + // the width or height. + int64 size = 1; + + // The stride at which the window moves across the base area in this + // dimension. In other words, this is the spacing between different + // positions of the window in this dimension. + int64 stride = 2; + + // If positive, means the amount of padding to add to the base area at the low + // end of this dimension; if negative, its negative means the number of + // elements removed from the low end of this dimension. For example, in the + // horizontal dimension of a rectangle, this would be the number of padding + // values to pad on the left, given that indices increase when going right. + // The actual padding value depends upon the context. Convolution pads with + // zeros. ReduceWindow and SelectAndScatter pads with the reduce function's + // init value. + int64 padding_low = 3; + + // As padding_low, but on the high end of this dimension. For example, in the + // horizontal dimension of a rectangle, this would be the number of values to + // pad on the right, given that indices increase when going right. + int64 padding_high = 4; + + // Dilation factor of the sliding window in this dimension. A dilation factor + // of 1 means no dilation. window_dilation - 1 no-op entries ("holes") are + // implicitly placed between each kernel element. This value may not be less + // than 1. See documentation for convolution. + int64 window_dilation = 5; + + // Dilation factor of the base area in this dimension. A dilation factor of 1 + // means no dilation. base_dilation - 1 no-op entries ("holes") are implicitly + // placed between each base area element. This value may not be less than 1. + // See documentation for convolution. + int64 base_dilation = 6; + + // Window reversal means that this dimension was logically reversed before the + // operation. + bool window_reversal = 7; +} + +// Describes the windowing in an operation such as convolution. +// +// The window is moved across a base area and for each position of the +// window a computation is performed. The field below describes the +// window and the movement of the window across a base area. +message Window { + repeated WindowDimension dimensions = 1; +} + +// Describes the dimension numbers for a gather operation. +// +// See https://www.tensorflow.org/performance/xla/operation_semantics#gather for +// more details. +message GatherDimensionNumbers { + // "Window indices" is a term for a set of indices that index into the + // interior of a dynamic-slice from the input tensor, the starting indices for + // which were computed from output_gather_dims (see the operation semantic for + // how this is defined) and the start_indices tensor. + // + // The window indices for a specific output index Out is computed as: + // + // i = 0 + // for (k : [0, input_tensor_shape.rank)) + // window_indices[k] = + // if k in collapsed_slice_dims + // then 0 + // else Out[offset_dims[i++]] + repeated int64 offset_dims = 1; + repeated int64 collapsed_slice_dims = 2; + + // This is interpreted as a map from i to start_index_map[i]. It + // transforms the gather index looked up from the start_indices tensor into + // the starting index in the input space. + repeated int64 start_index_map = 3; + + // The dimension in the start_indices input that contains the starting + // indices. + int64 index_vector_dim = 4; + + // This is the batch dimensions in the operand. + repeated int64 operand_batching_dims = 5; + + // This is the batch dimensions in the index, and it should be the same size + // as operand_batching_dims. + repeated int64 start_indices_batching_dims = 6; +} + +// Describes the dimension numbers for a scatter operation. +// +// All the fields are similar to the corresponding fields in +// GatherDimensionNumbers. Differences are noted below. +message ScatterDimensionNumbers { + // The set of dimensions in the updates shape that are window dimensions. + repeated int64 update_window_dims = 1; + // The set of window dimensions that must be inserted into the updates shape. + repeated int64 inserted_window_dims = 2; + + repeated int64 scatter_dims_to_operand_dims = 3; + int64 index_vector_dim = 4; + + // This is the batch dimension in the input. + repeated int64 input_batching_dims = 5; + + // This is the batch dimension in the index. + repeated int64 scatter_indices_batching_dims = 6; +} + +message ConvolutionDimensionNumbers { + // The number of the dimension that represents batch in the input. + int64 input_batch_dimension = 7; + + // The number of the dimension that represents features in the input. + int64 input_feature_dimension = 8; + + // The dimension numbers for the spatial dimensions that the window + // moves through in the input. + repeated int64 input_spatial_dimensions = 11; + + // The number of the dimension that represents input features in the + // convolutional kernel (rhs). + int64 kernel_input_feature_dimension = 3; + + // The number of the dimension that represents output features in + // the convolutional kernel (rhs). + int64 kernel_output_feature_dimension = 4; + + // The dimension numbers for the spatial dimensions that the window + // moves through in the kernel (rhs). window.strides(0) is the + // stride in the kernel_spatial_dimensions(0) dimension. + repeated int64 kernel_spatial_dimensions = 6; + + // The number of the dimension that represents batch in the output. + int64 output_batch_dimension = 9; + + // The number of the dimension that represents features in the output. + int64 output_feature_dimension = 10; + + // The dimension numbers for the spatial dimensions that the window + // moves through in the output. + repeated int64 output_spatial_dimensions = 12; + + // Next = 13 +} + +enum PaddingType { + PADDING_INVALID = 0; + PADDING_VALID = 1; // Only valid portion of the base are covered. + PADDING_SAME = 2; // Extra is added to produce same output size as the input. +} + +enum FftType { + FFT = 0; // Forward FFT; complex in, complex out. + IFFT = 1; // Inverse FFT; complex in, complex out. + RFFT = 2; // Forward real FFT; real in, fft_length / 2 + 1 complex out + IRFFT = 3; // Inverse real FFT; fft_length / 2 + 1 complex in, + // fft_length real out +} + +message DotDimensionNumbers { + // The dimension numbers that represent the 'lhs' contracting dimensions. + repeated int64 lhs_contracting_dimensions = 1; + // The dimension numbers that represent the 'rhs' contracting dimensions. + repeated int64 rhs_contracting_dimensions = 2; + // The dimension numbers that represent the 'lhs' batch dimensions. + repeated int64 lhs_batch_dimensions = 3; + // The dimension numbers that represent the 'rhs' batch dimensions. + repeated int64 rhs_batch_dimensions = 4; +} + +message RaggedDotDimensionNumbers { + // The contracting and batch dimensions of the 'lhs' and 'rhs'. + DotDimensionNumbers dot_dimension_numbers = 1; + // The dimension numbers that represent the 'lhs' ragged dimensions. + repeated int64 lhs_ragged_dimensions = 2; + // The dimension numbers that represent the 'rhs' group dimensions. + repeated int64 rhs_group_dimensions = 3; +} + +enum SparsityType { + SPARSITY_INVALID = 0; + + // Structured N:M sparsity. + SPARSITY_STRUCTURED_N_M = 1; + + // Next: 2 +} + +// Contains sparsity metadata for a sparse dot operation. +// The only supported type atm is structured 2:4 sparsity, which is natively +// supported on NVidia GPUs. +// Restrictions: +// - only one operand of the dot operation may be sparse; +// - only the contracting dimension may be sparse. +message SparsityDescriptor { + SparsityType type = 1; + + // Sparse operand index (0 or 1). + int32 index = 2; + // Sparse dimension number. + int32 dimension = 3; + + // Structured N:M sparsity (N < M). + int32 n = 4; + int32 m = 5; + + // Next: 6 +} + +enum RandomDistribution { + RNG_INVALID = 0; + + // Creates a uniform-distribution-generated random number on the semi-open + // interval [parameter[0], parameter[1]). + RNG_UNIFORM = 1; + + // Creates a normal-distribution-generated random number with mean + // parameter[0] and standard deviation parameter[1]. + RNG_NORMAL = 2; + + // Next: 4 +} + +enum RandomAlgorithm { + RNG_DEFAULT = 0; // Backend dependent default algorithm. + RNG_THREE_FRY = 1; + RNG_PHILOX = 2; + // Next: 2 +} + +message TriangularSolveOptions { + // If true, solves ax = b. If false, solves xa = b. + bool left_side = 1; + + // If true, 'a' is lower triangular. If false, 'a' is upper triangular. + bool lower = 2; + + // If true, the diagonal elements of 'a' are assumed to be 1 and not accessed. + bool unit_diagonal = 3; + + // Should we transpose or use the adjoint of 'a'? + enum Transpose { + TRANSPOSE_INVALID = 0; + NO_TRANSPOSE = 1; // Don't transpose 'a'. + TRANSPOSE = 2; // Transpose 'a'. + ADJOINT = 3; // Complex conjugate and transpose 'a'. + } + Transpose transpose_a = 4; +} + +message CholeskyOptions { + // If true, uses the lower triangle of `a`. If false, uses the upper triangle + // of `a`. + bool lower = 1; +} + +// Attributes of the sort custom call (cub::DeviceRadixSort). +message SortOptions { + bool descending = 1; +} + +// Generic map of attributes used to pass hints / configuration options from +// the Python frontend to the XLA backend. +message FrontendAttributes { + map map = 1; +} + +// Represents a single statistic to track. +message Statistic { + // Must be a single word consisting of any alphanumeric characters + string stat_name = 1; + // Must be within a range of [0, 100], in order for the graph dumper to + // properly render the statistic onto the graph. + double stat_val = 2; +} + +// Represents the information needed to visualize propagation statistics when +// rendering an HLO graph. This includes an array of statistics as well as the +// index of the statistic to render. +message StatisticsViz { + int64 stat_index_to_visualize = 1; + repeated Statistic statistics = 2; +} + +// LINT.IfChange +message OpSharding { + enum Type { + // This sharding is replicated across all devices (implies maximal, + // all other fields are unused). + REPLICATED = 0; + // This sharding is maximal - one device runs the entire operation. + MAXIMAL = 1; + // This sharding is a tuple - only the tuple_shardings field is valid. + TUPLE = 2; + // None of the above; tile_shape and tile_assignment are both used. + OTHER = 3; + // This op is manually sharded: the shapes are already partitioned and the + // partitioner should not change this op. + MANUAL = 4; + // This sharding is a placeholder sharding with lowest precedence, it can be + // overwriten by any other shardings. + UNKNOWN = 5; + } + Type type = 1; + // The shape of the sharded tile. + ShapeProto tile_shape = 2; + // The shape of the tile assignment tensor - this must be the same rank as + // tile_shape and the product of its dimensions must equal + // tile_assignment_devices.size(). + repeated int64 tile_assignment_dimensions = 3; + // Flattened list of device IDs. The order of flattening is the same as used + // by IndexUtil::MultiToLinearIndex(tile_assignment_shape). + // Only one of tile_assignment_devices and iota_dimensions shall be non-empty. + repeated int64 tile_assignment_devices = 4; + // If type == TUPLE, the sub-shardings, one per leaf node in the tuple shape, + // in pre-order. The tuple shape could be nested; here we store just a + // flattened list of all leaves in the tuple shape. Note that the tuple shape + // is not stored here; shardings do not store the shapes to which they are + // applied, this is inferred from the instruction this sharding gets attached + // to. + repeated OpSharding tuple_shardings = 5; + + // Only used for OTHER type. If true, data is sharded according to other + // dimensions of tile_assignment(), but replicated across devices along the + // last dimension. (Experimental) + bool replicate_on_last_tile_dim = 6; + // This field is used to track the source of this sharding, usually derived + // from instructions. Multple metadata may be populated if sharding is + // combined with other shardings. Metadata are to not be populated when + // type == TUPLE and instead metadata should be set on individual tuple + // elements. + repeated OpMetadata metadata = 7; + + // This field is used to represented the sharding type of each subgroup. + // For example, sharding={devices=[2,2,2,2]0,1,2,...,15 last_tile_dims={ + // replicate, manual, unreduced}} means that each of the last 3 dimensions + // in [2,2,2,2] represents a subgrouping in replicate, manual, + // unreduced sharding type respectively. + repeated Type last_tile_dims = 8; + + // Dimensions used to reshape the 1D iota array of device IDs. + // Only one of tile_assignment_devices and iota_reshape_dims shall be + // non-empty. + repeated int64 iota_reshape_dims = 9; + + // Dimension permutations to transposed the iota array reshaped to + // iota_reshape_dims. This must have the same size as iota_reshape_dims. + repeated int32 iota_transpose_perm = 10; + + // This field decides whether this op is in a shard group. + bool is_shard_group = 11; + + // This field is used to store the unique id of the shard group. + int64 shard_group_id = 12; + + // Used to decide whether this op is to be sharded like some other ops, or to + // which other ops will be sharded like. + enum ShardGroupType { + // This op will be sharded exactly the same as the other op. (hard + // restriction) + AS = 0; + // This op will try to allow sharding propagation within the same group even + // there is no data dependencies among them, but there is no guarantee that + // the final shardings within the same group will be exactly the same. (soft + // restriction) + LIKE = 1; + } + + ShardGroupType shard_group_type = 13; +} +// LINT.ThenChange() + +// Describes the replica groups in a cross replica op (e.g., all-reduce and +// all-to-all). +message ReplicaGroup { + // The ids of the replicas that belongs to the same group. The ordering of the + // ids matters in some ops (e.g., all-to-all). + repeated int64 replica_ids = 1; +} + +// Represents a list of replica groups (a list of list of devices) with +// reshaping and transposing an iota array (iota tile assignment). Can be used +// to represent certain common patterns of device lists in a compact, scalable +// format. +message IotaReplicaGroupListProto { + // Number of replica groups. + int64 num_replica_groups = 1; + + // Number of devices per group. + int64 num_devices_per_group = 2; + + // The dimensions used to reshape the 1D iota array of device IDs. + repeated int64 iota_reshape_dims = 3; + + // The dimension permutations to transposed the iota array reshaped to + // iota_reshape_dims. This must have the same size as iota_reshape_dims. + repeated int32 iota_transpose_perm = 4; +} + +// Represents a series of devices participating in a collective operation (e.g., +// all-reduce and all-to-all). While this directly translates to a list of +// replica groups, it may be used to represent these lists in a compact form. +message CollectiveDeviceListProto { + // ReplicaGroupV1: List of replica groups. Legacy way of representing device + // lists. + repeated ReplicaGroup replica_groups = 1; + + // ReplicaGroupV2: Represents a list of replica groups with reshaping and + // transposing an iota array. + IotaReplicaGroupListProto iota_replica_group_list = 2; +} + +// Describes the source target pair in the collective permute op. +message SourceTarget { + int64 source = 1; + int64 target = 2; +} + +// Describes the types of accuracy the user can request for unary ops with +// multiple implementations. +message ResultAccuracy { + enum Mode { + DEFAULT = 0; + HIGHEST = 1; + } + message Tolerance { + // Absolute error tolerance for unary instructions. + double atol = 1; + // Relative error tolerance for unary instructions. + double rtol = 2; + // The error in ulps (units in the last place) is relative to machine + // precision. + int64 ulps = 3; + } + oneof specs { + // Choose either DEFAULT or HIGHEST precision implementation. + Mode mode = 1; + Tolerance tolerance = 2; + } +} + +// Used to indicate the precision configuration. It has backend specific +// meaning. +message PrecisionConfig { + enum Precision { + DEFAULT = 0; + HIGH = 1; + HIGHEST = 2; + // Each U8/S8 value in a tensor actually represents 2 nibble values. + PACKED_NIBBLE = 3; + + // Next: 4 + } + + // The algorithm used to evaluate the instruction. + // + // The naming convention for the dot instruction is + // ALG_DOT_{A_TYPE}_{B_TYPE}_{ACCUM_TYPE}[_X{NUM_OPS}] where A_TYPE, B_TYPE + // and ACCUM_TYPE correspond to the types in the "primitive dot operations" + // (such as TensorCore operations) and NUM_OPS is the number of such + // operations used per "primitive tile". When the NUM_OPS + // field is skipped, it is assumed to be 1. The types mentioned in the name + // are independent of the storage types. + // + // In general ATYPE and BTYPE are the precisions that the LHS and RHS of the + // operation are rounded to and ACCUMTYPE is the accumulation type. If a + // backend does not support the given algorithm, an error is raised. The + // Algorithm enum is intended to eventually replace the Precision enum. + // + enum Algorithm { + // If the algorithm is `ALG_UNSET`, we will decide the algorithm based on + // the operand_precision values (for now). + ALG_UNSET = 0; + // The storage type can be any 8-bit floating point type. + ALG_DOT_ANY_F8_ANY_F8_F32 = 1; + // The storage type can be any 8-bit floating point type. Intermediate + // results will not periodically be promoted to a higher precision. This + // corresponds to CUBLASLT_MATMUL_DESC_FAST_ACCUM. Triton's + // maxNumImpreciseAcc=32 setting may be similar. + ALG_DOT_ANY_F8_ANY_F8_F32_FAST_ACCUM = 2; + ALG_DOT_F16_F16_F16 = 3; + ALG_DOT_F16_F16_F32 = 4; + ALG_DOT_BF16_BF16_BF16 = 5; + ALG_DOT_BF16_BF16_F32 = 6; + // An algorithm which uses 3 BF16_BF16_F32 matmuls to achieve better + // precision. + ALG_DOT_BF16_BF16_F32_X3 = 7; + // An algorithm which uses 6 BF16_BF16_F32 matmuls to achieve better + // precision (similar to F32). + ALG_DOT_BF16_BF16_F32_X6 = 8; + ALG_DOT_TF32_TF32_F32 = 9; + // An algorithm which uses 3 TF32_TF32_F32 matmuls to achieve better + // precision (similar to F32). + ALG_DOT_TF32_TF32_F32_X3 = 10; + ALG_DOT_F32_F32_F32 = 11; + ALG_DOT_F64_F64_F64 = 12; + + // Next: 13 + } + + repeated Precision operand_precision = 1; + + // Currently doesn't do anything, but we plan to support it for dot and + // possibly more instructions. + // + // TODO(b/316147294): Support this on GPU and add this to StableHLO as well. + // + // If this is set, then `operand_precision` should be set to DEFAULT and it + // will be ignored. + Algorithm algorithm = 2; + + // Next: 8 +} + +// Describes whether all data-parallelism replicas will receive the same +// parameter data at each buffer. +message ParameterReplication { + // A list of boolean values for the flattened leaf buffers. Each value + // indicates whether the corresponding leaf buffer is replicated. + // + // If this field is empty, it means no buffer is replicated. Otherwise, the + // number of elements in this field must match the number of leaf buffers in + // the HLO instruction's shape. + repeated bool replicated_at_leaf_buffers = 1; +} + +// A backend-config for kWhile loops that stores the loop's trip count, if it is +// known. +// +// This is useful for backends that can implement a `for i in 0..N` loop more +// efficiently than a `while` loop. For example, on GPUs, we can implement a +// `for i in 0..N` loop by enqueueing the kernels for the loop body N times, +// whereas implementing a `while` loop requires a host-device sync on each +// iteration. +message WhileLoopBackendConfig { + message KnownTripCount { + int64 n = 1; + } + // This indirection lets us distinguish between known-trip-count == 0 and + // unknown-trip-count. + KnownTripCount known_trip_count = 1; +} + +// Specifies a pair of output/operand buffers that alias each other for +// kCustomCall and kFusion +message OutputOperandAliasing { + repeated int64 output_shape_index = 1; + int64 operand_index = 2; + repeated int64 operand_shape_index = 3; +} + +message OriginalArrayProto { + repeated int64 leaf_shape_index = 1; + string instruction_name = 2; + repeated int64 shape_index = 3; +} + +message OriginalValueProto { + repeated OriginalArrayProto leaves = 1; +}