From 91b65817a32ebb90072af0e4a19567907d81c66d Mon Sep 17 00:00:00 2001
From: Junwhan Ahn
Date: Fri, 3 Jan 2025 17:13:33 -0800
Subject: [PATCH] Change IFRT and PjRt layout API to return
 `std::shared_ptr<const xla::PjRtLayout>` instead of
 `std::unique_ptr<xla::PjRtLayout>`

The current API design that uses `std::unique_ptr<xla::PjRtLayout>` has two
issues:

* The API requires `xla::PjRtLayout` to be copied in some scenarios, e.g.,
  `xla::ifrt::Array` internally stores a layout and returns a copy of it every
  time `layout()` is called. This forces implementations to break the
  abstraction boundary, because `xla::PjRtLayout` is an abstract class and
  `std::unique_ptr` is not copyable. The current implementation either stores
  an `xla::Layout` and creates an `xla::PjRtLayout` every time, or downcasts
  `xla::PjRtLayout` to `xla::PjRtXlaLayout` to perform the copy.

* `xla::Layout` is expensive to copy (`sizeof(xla::Layout)` is 248 bytes as of
  2025-01-03), and copying `xla::PjRtXlaLayout` requires copying or moving an
  `xla::Layout`.

To address these two problems, this CL changes the PjRt and IFRT APIs that
return `xla::PjRtLayout` to instead return
`std::shared_ptr<const xla::PjRtLayout>`, so that PjRt layouts can be cheaply
copied. Similar patterns have been used in other places such as
`xla::ifrt::Sharding` and `xla::PjRtExecutable::GetHloModules()`.

Some implementations have been updated to take advantage of this change. For
example, `PjRtCApiBuffer::layout()` no longer performs a layout copy and
instead reuses an internally cached `std::shared_ptr<const xla::PjRtLayout>`
instance.

PiperOrigin-RevId: 711892970
---
 xla/pjrt/c/pjrt_c_api_wrapper_impl.cc         |  8 +++---
 xla/pjrt/c/pjrt_c_api_wrapper_impl.h          |  2 +-
 xla/pjrt/pjrt_c_api_client.cc                 | 11 ++++---
 xla/pjrt/pjrt_c_api_client.h                  |  4 +--
 xla/pjrt/pjrt_client.h                        |  6 ++---
 xla/pjrt/pjrt_executable.cc                   |  8 +++---
 xla/pjrt/pjrt_executable.h                    |  4 +--
 xla/pjrt/pjrt_layout.h                        |  6 ++---
 xla/python/ifrt/array.h                       |  2 +-
 xla/python/ifrt/client.h                      |  2 +-
 xla/python/ifrt/executable.h                  |  8 +++---
 xla/python/ifrt/mock.cc                       |  7 ++---
 xla/python/ifrt/mock.h                        | 12 ++++-----
 xla/python/ifrt_proxy/client/array.h          |  2 +-
 xla/python/ifrt_proxy/client/client.h         |  7 ++---
 xla/python/ifrt_proxy/client/executable.cc    | 27 +++++--------------
 xla/python/ifrt_proxy/client/executable.h     | 11 +++++---
 .../ifrt_proxy/client/executable_test.cc      | 21 ++++++++-------
 xla/python/ifrt_proxy/server/ifrt_backend.cc  |  4 +--
 .../ifrt_proxy/server/ifrt_backend_test.cc    | 10 +++----
 xla/python/jax_jit.cc                         |  6 ++---
 xla/python/jax_jit.h                          |  2 +-
 xla/python/pjrt_ifrt/basic_string_array.cc    |  6 +++--
 xla/python/pjrt_ifrt/basic_string_array.h     |  3 ++-
 xla/python/pjrt_ifrt/pjrt_array.cc            |  8 +++---
 xla/python/pjrt_ifrt/pjrt_array.h             |  2 +-
 xla/python/pjrt_ifrt/pjrt_client.cc           |  2 +-
 xla/python/pjrt_ifrt/pjrt_client.h            |  6 ++---
 xla/python/pjrt_ifrt/pjrt_executable.h        |  8 +++---
 xla/python/py_array.cc                        |  1 +
 xla/python/py_array.h                         |  2 +-
 xla/python/py_client.cc                       |  3 ++-
 xla/python/py_compile_only_client.cc          |  2 +-
 xla/python/py_executable.cc                   |  4 +--
 xla/python/py_executable.h                    |  8 +++---
 .../functional_hlo_runner.cc                  |  4 +--
 36 files changed, 114 insertions(+), 115 deletions(-)

diff --git a/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc b/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc
index ec697b08af7841..64aa20bac3c0e2 100644
--- a/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc
+++ b/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc
@@ -1797,10 +1797,10 @@ PJRT_Error* PJRT_Buffer_GetMemoryLayout(
   absl::MutexLock lock(&args->buffer->mu);
   if (!layout_data.has_value()) {
     // TODO(skyewm): change PJRT C API to also use opaque layout type
-    std::unique_ptr<xla::PjRtLayout> pjrt_layout =
+    std::shared_ptr<const xla::PjRtLayout>
pjrt_layout = args->buffer->buffer->layout(); - xla::PjRtXlaLayout* pjrt_xla_layout = - tensorflow::down_cast(pjrt_layout.get()); + const xla::PjRtXlaLayout* pjrt_xla_layout = + tensorflow::down_cast(pjrt_layout.get()); CHECK(pjrt_xla_layout != nullptr) << "Got unexpected layout type"; const xla::Layout& xla_layout = pjrt_xla_layout->xla_layout(); @@ -2283,7 +2283,7 @@ PJRT_Error* PJRT_Layouts_PJRT_Client_GetDefaultLayout( args->client->client->GetDefaultLayout( pjrt::ConvertFromPjRtBufferType(args->type), {args->dims, args->num_dims})); - auto pjrt_xla_layout = std::make_unique(xla_layout); + auto pjrt_xla_layout = std::make_shared(xla_layout); args->layout = new PJRT_Layouts_MemoryLayout{std::move(pjrt_xla_layout)}; return nullptr; } diff --git a/xla/pjrt/c/pjrt_c_api_wrapper_impl.h b/xla/pjrt/c/pjrt_c_api_wrapper_impl.h index 0ebecc0c251734..04463410ee7e08 100644 --- a/xla/pjrt/c/pjrt_c_api_wrapper_impl.h +++ b/xla/pjrt/c/pjrt_c_api_wrapper_impl.h @@ -218,7 +218,7 @@ struct PJRT_CopyToDeviceStream { }; struct PJRT_Layouts_MemoryLayout { - std::unique_ptr layout; + std::shared_ptr layout; }; struct PJRT_Layouts_SerializedLayout { diff --git a/xla/pjrt/pjrt_c_api_client.cc b/xla/pjrt/pjrt_c_api_client.cc index a1b8966bd34e9b..18ca751766412b 100644 --- a/xla/pjrt/pjrt_c_api_client.cc +++ b/xla/pjrt/pjrt_c_api_client.cc @@ -2020,16 +2020,17 @@ absl::Span PjRtCApiBuffer::dimensions() const { return absl::Span(args.dims, args.num_dims); } -std::unique_ptr PjRtCApiBuffer::layout() const { +std::shared_ptr PjRtCApiBuffer::layout() const { { absl::MutexLock lock(&mu_); - if (!layout_.has_value()) { + if (layout_ == nullptr) { const PJRT_Api* c_api = pjrt_c_api(); PJRT_Layouts_Extension* extension = pjrt::FindExtension( c_api, PJRT_Extension_Type::PJRT_Extension_Type_Layouts); if (extension == nullptr) { - layout_.emplace(LayoutUtil::MakeDescendingLayout(dimensions().size())); + layout_ = std::make_shared( + LayoutUtil::MakeDescendingLayout(dimensions().size())); } else { std::unique_ptr @@ -2057,11 +2058,11 @@ std::unique_ptr PjRtCApiBuffer::layout() const { absl::StatusOr pjrt_xla_layout = PjRtXlaLayout::Deserialize(serialized_layout); TF_CHECK_OK(pjrt_xla_layout.status()); - layout_.emplace(*pjrt_xla_layout); + layout_ = std::make_shared(*std::move(pjrt_xla_layout)); } } } - return std::make_unique(*layout_); + return layout_; } bool PjRtCApiBuffer::has_dynamic_dimensions() const { diff --git a/xla/pjrt/pjrt_c_api_client.h b/xla/pjrt/pjrt_c_api_client.h index 46304e6d46bcef..03e41ec3985903 100644 --- a/xla/pjrt/pjrt_c_api_client.h +++ b/xla/pjrt/pjrt_c_api_client.h @@ -485,7 +485,7 @@ class PjRtCApiBuffer : public PjRtBuffer { absl::Span dimensions() const override; - std::unique_ptr layout() const override; + std::shared_ptr layout() const override; // PJRT C API doesn't support tuple buffers. bool IsTuple() const override { return false; } @@ -583,7 +583,7 @@ class PjRtCApiBuffer : public PjRtBuffer { // we set on `readiness_event` modifies `readiness_promise_`. std::shared_ptr::Promise> readiness_promise_; // Set and cached the first time layout() is called. - mutable std::optional layout_; + mutable std::shared_ptr layout_; // Set and cached the first time is_dynamic_dimension() is called. 
mutable std::optional> is_dynamic_dimension_; diff --git a/xla/pjrt/pjrt_client.h b/xla/pjrt/pjrt_client.h index 26c777b1fdd4ef..0b1da9ef4660a1 100644 --- a/xla/pjrt/pjrt_client.h +++ b/xla/pjrt/pjrt_client.h @@ -1121,12 +1121,12 @@ class PjRtBuffer { return on_device_shape().dimensions(); } - // The on-device memory layout of this buffer. Returned via unique_ptr to make + // The on-device memory layout of this buffer. Returned via shared_ptr to make // memory management easier -- PjRtLayout is an abstract base class, so cannot // be easily copied. - virtual std::unique_ptr layout() const { + virtual std::shared_ptr layout() const { CHECK(on_device_shape().has_layout()); - return std::make_unique(on_device_shape().layout()); + return std::make_shared(on_device_shape().layout()); } // PjRtBuffers can either represent a single array buffer or a tuple of array diff --git a/xla/pjrt/pjrt_executable.cc b/xla/pjrt/pjrt_executable.cc index e2fa5e53f9bfee..def2f0edd24b8d 100644 --- a/xla/pjrt/pjrt_executable.cc +++ b/xla/pjrt/pjrt_executable.cc @@ -422,7 +422,7 @@ PjRtExecutable::GetOutputDimensions() const { return output_dimensions; } -absl::StatusOr>> +absl::StatusOr>> PjRtExecutable::GetParameterLayouts() const { TF_ASSIGN_OR_RETURN(std::vector> hlo_modules, GetHloModules()); @@ -439,7 +439,7 @@ PjRtExecutable::GetParameterLayouts() const { ComputationLayout comp_layout = hlo_modules[0]->entry_computation_layout(); TF_ASSIGN_OR_RETURN(std::vector layouts, comp_layout.FlattenedParameterLayouts()); - std::vector> result; + std::vector> result; result.reserve(layouts.size()); for (const Layout& layout : layouts) { result.push_back(std::make_unique(layout)); @@ -447,7 +447,7 @@ PjRtExecutable::GetParameterLayouts() const { return result; } -absl::StatusOr>> +absl::StatusOr>> PjRtExecutable::GetOutputLayouts() const { TF_ASSIGN_OR_RETURN(std::vector> hlo_modules, GetHloModules()); @@ -464,7 +464,7 @@ PjRtExecutable::GetOutputLayouts() const { ComputationLayout comp_layout = hlo_modules[0]->entry_computation_layout(); TF_ASSIGN_OR_RETURN(std::vector layouts, comp_layout.FlattenedResultLayouts()); - std::vector> result; + std::vector> result; result.reserve(layouts.size()); for (const Layout& layout : layouts) { result.push_back(std::make_unique(layout)); diff --git a/xla/pjrt/pjrt_executable.h b/xla/pjrt/pjrt_executable.h index 07715fe0dbae79..fc4f76ef4776a8 100644 --- a/xla/pjrt/pjrt_executable.h +++ b/xla/pjrt/pjrt_executable.h @@ -335,11 +335,11 @@ class PjRtExecutable { GetOutputDimensions() const; // Returns the layout of each input parameter. - virtual absl::StatusOr>> + virtual absl::StatusOr>> GetParameterLayouts() const; // Returns the layout of each output. - virtual absl::StatusOr>> + virtual absl::StatusOr>> GetOutputLayouts() const; // Returns a list of lists of memory kind strings for output. The returned diff --git a/xla/pjrt/pjrt_layout.h b/xla/pjrt/pjrt_layout.h index eea9b861690860..005881e4634849 100644 --- a/xla/pjrt/pjrt_layout.h +++ b/xla/pjrt/pjrt_layout.h @@ -100,9 +100,9 @@ class PjRtXlaLayout : public PjRtLayout { // TODO(b/327524065): make callers use PjRtLayout directly instead of assuming // an xla::Layout and get rid of this function. 
inline Layout GetXlaLayoutUnsafe( - const std::unique_ptr& pjrt_layout) { - PjRtXlaLayout* xla_layout = - tensorflow::down_cast(pjrt_layout.get()); + const std::shared_ptr& pjrt_layout) { + const PjRtXlaLayout* xla_layout = + tensorflow::down_cast(pjrt_layout.get()); CHECK(xla_layout != nullptr) << "Got unexpected layout type"; return xla_layout->xla_layout(); } diff --git a/xla/python/ifrt/array.h b/xla/python/ifrt/array.h index 2a4ff23b1fdb1d..e31a2600352324 100644 --- a/xla/python/ifrt/array.h +++ b/xla/python/ifrt/array.h @@ -76,7 +76,7 @@ class Array : public llvm::RTTIExtends { // The device memory layout for each shard of the Array. All shards are // assumed to have the same layout. Cannot be nullptr; implementations should // return UNIMPLEMENTED instead. - virtual absl::StatusOr> layout() const = 0; + virtual absl::StatusOr> layout() const = 0; // Breaks an array up into per-device arrays. This is the elimination // counterpart of `Client::AssembleArrayFromSingleDeviceArrays()`. diff --git a/xla/python/ifrt/client.h b/xla/python/ifrt/client.h index 441aa66781a462..01eab2f3492e9a 100644 --- a/xla/python/ifrt/client.h +++ b/xla/python/ifrt/client.h @@ -241,7 +241,7 @@ class Client : public llvm::RTTIExtends { // single-shard dimensions `dims`. // TODO(hyeontaek): Change the API to take `Shape` and `Sharding` instead of // single-shard dimensions and device. - virtual absl::StatusOr> + virtual absl::StatusOr> GetDefaultLayoutForDevice(DType dtype, absl::Span dims, Device* device) const = 0; diff --git a/xla/python/ifrt/executable.h b/xla/python/ifrt/executable.h index 5332768c885b9c..9bf0128ed7e0b8 100644 --- a/xla/python/ifrt/executable.h +++ b/xla/python/ifrt/executable.h @@ -78,10 +78,10 @@ class Executable : public llvm::RTTIExtends { // Returns a list of output `OpSharding`. virtual std::optional> GetOutputShardings() const = 0; // Returns a list of parameter layouts. - virtual absl::StatusOr>> + virtual absl::StatusOr>> GetParameterLayouts() const = 0; // Returns a list of output/result layouts. - virtual absl::StatusOr>> + virtual absl::StatusOr>> GetOutputLayouts() const = 0; // Returns an `HloModule` (optimized) per partition. virtual absl::StatusOr>> @@ -187,10 +187,10 @@ class LoadedExecutable // Returns a list of output OpSharding. virtual std::optional> GetOutputShardings() const = 0; // Returns a list of parameter layouts. - virtual absl::StatusOr>> + virtual absl::StatusOr>> GetParameterLayouts() const = 0; // Returns a list of output/result layouts. - virtual absl::StatusOr>> + virtual absl::StatusOr>> GetOutputLayouts() const = 0; // Return an HloModule (optimized) per partition. 
virtual absl::StatusOr>> diff --git a/xla/python/ifrt/mock.cc b/xla/python/ifrt/mock.cc index d62646bf5b78ad..09cfa924e46e99 100644 --- a/xla/python/ifrt/mock.cc +++ b/xla/python/ifrt/mock.cc @@ -78,9 +78,10 @@ MockArray::MockArray(tsl::RCReference delegated) return delegated_->shared_ptr_sharding(); }); ON_CALL(*this, layout) - .WillByDefault([this]() -> absl::StatusOr> { - return delegated_->layout(); - }); + .WillByDefault( + [this]() -> absl::StatusOr> { + return delegated_->layout(); + }); ON_CALL(*this, DisassembleIntoSingleDeviceArrays(_)) .WillByDefault([this](ArrayCopySemantics semantics) { return delegated_->DisassembleIntoSingleDeviceArrays(semantics); diff --git a/xla/python/ifrt/mock.h b/xla/python/ifrt/mock.h index 11ba98cc96326a..2009c048cbb588 100644 --- a/xla/python/ifrt/mock.h +++ b/xla/python/ifrt/mock.h @@ -76,7 +76,7 @@ class MockArray : public llvm::RTTIExtends { MOCK_METHOD(const Sharding&, sharding, (), (const, final)); MOCK_METHOD(absl::Nonnull>, shared_ptr_sharding, (), (const, final)); - MOCK_METHOD(absl::StatusOr>, layout, (), + MOCK_METHOD(absl::StatusOr>, layout, (), (const, final)); MOCK_METHOD(absl::StatusOr>>, DisassembleIntoSingleDeviceArrays, (ArrayCopySemantics semantics), @@ -173,7 +173,7 @@ class MockClient : public llvm::RTTIExtends { MOCK_METHOD(absl::StatusOr>, GetTopologyForDevices, (const tsl::RCReference& devices), (const, final)); - MOCK_METHOD(absl::StatusOr>, + MOCK_METHOD(absl::StatusOr>, GetDefaultLayoutForDevice, (xla::ifrt::DType dtype, absl::Span dims, xla::ifrt::Device* device), @@ -264,9 +264,9 @@ class MockExecutable : public llvm::RTTIExtends { (const, final)); MOCK_METHOD(std::optional>, GetOutputShardings, (), (const, final)); - MOCK_METHOD(absl::StatusOr>>, + MOCK_METHOD(absl::StatusOr>>, GetParameterLayouts, (), (const, final)); - MOCK_METHOD(absl::StatusOr>>, + MOCK_METHOD(absl::StatusOr>>, GetOutputLayouts, (), (const, final)); MOCK_METHOD(absl::StatusOr>>, GetHloModules, (), (const, final)); @@ -293,9 +293,9 @@ class MockLoadedExecutable (const, final)); MOCK_METHOD(std::optional>, GetOutputShardings, (), (const, final)); - MOCK_METHOD(absl::StatusOr>>, + MOCK_METHOD(absl::StatusOr>>, GetParameterLayouts, (), (const, final)); - MOCK_METHOD(absl::StatusOr>>, + MOCK_METHOD(absl::StatusOr>>, GetOutputLayouts, (), (const, final)); MOCK_METHOD(absl::StatusOr>>, GetOutputMemoryKinds, (), (const, final)); diff --git a/xla/python/ifrt_proxy/client/array.h b/xla/python/ifrt_proxy/client/array.h index 2a9ccdf17bea32..5c4b42475f36c7 100644 --- a/xla/python/ifrt_proxy/client/array.h +++ b/xla/python/ifrt_proxy/client/array.h @@ -112,7 +112,7 @@ class Array final : public llvm::RTTIExtends { std::shared_ptr shared_ptr_sharding() const override { return sharding_; } - absl::StatusOr> layout() const override { + absl::StatusOr> layout() const override { return absl::UnimplementedError( "Array::layout() not implemented for IFRT proxy"); }; diff --git a/xla/python/ifrt_proxy/client/client.h b/xla/python/ifrt_proxy/client/client.h index 3732b5ddd832d7..0f1323e1abeaa9 100644 --- a/xla/python/ifrt_proxy/client/client.h +++ b/xla/python/ifrt_proxy/client/client.h @@ -140,9 +140,10 @@ class Client final : public llvm::RTTIExtends { return absl::UnimplementedError( "GetTopologyForDevices is not supported for the IFRT proxy client."); } - absl::StatusOr> GetDefaultLayoutForDevice( - xla::ifrt::DType dtype, absl::Span dims, - xla::ifrt::Device* device) const override { + absl::StatusOr> + GetDefaultLayoutForDevice(xla::ifrt::DType dtype, + absl::Span 
dims, + xla::ifrt::Device* device) const override { return absl::UnimplementedError( "GetDefaultLayout is not supported for the IFRT proxy client."); } diff --git a/xla/python/ifrt_proxy/client/executable.cc b/xla/python/ifrt_proxy/client/executable.cc index 81ef43ec5c0f3b..6de9e3757eeff3 100644 --- a/xla/python/ifrt_proxy/client/executable.cc +++ b/xla/python/ifrt_proxy/client/executable.cc @@ -310,10 +310,11 @@ LoadedExecutable::LoadedExecutable( auto parse_layouts = [](const LoadedExecutableMetadataResponse::LayoutList& list) { - std::vector layouts; + std::vector> layouts; layouts.reserve(list.layouts_size()); for (const auto& layout : list.layouts()) { - layouts.push_back(xla::Layout::CreateFromProto(layout)); + layouts.push_back(std::make_shared( + xla::Layout::CreateFromProto(layout))); } return layouts; }; @@ -433,34 +434,20 @@ std::optional> LoadedExecutable::GetOutputShardings() return (*info)->output_shardings; } -absl::StatusOr>> +absl::StatusOr>> LoadedExecutable::GetParameterLayouts() const { tsl::profiler::TraceMe traceme_ifrt_entrypoint( "IfrtProxyEntrypointLoadedExecutableGetParameterLayouts"); TF_ASSIGN_OR_RETURN(auto info, metadata_future_.Await()); - TF_RETURN_IF_ERROR(info->parameter_layouts.status()); - - std::vector> result; - result.reserve(info->parameter_layouts->size()); - for (const xla::Layout& layout : *info->parameter_layouts) { - result.push_back(std::make_unique(layout)); - } - return result; + return info->parameter_layouts; } -absl::StatusOr>> +absl::StatusOr>> LoadedExecutable::GetOutputLayouts() const { tsl::profiler::TraceMe traceme_ifrt_entrypoint( "IfrtProxyEntrypointLoadedExecutableGetOutputLayouts"); TF_ASSIGN_OR_RETURN(auto info, metadata_future_.Await()); - TF_RETURN_IF_ERROR(info->output_layouts.status()); - - std::vector> result; - result.reserve(info->output_layouts->size()); - for (const xla::Layout& layout : *info->output_layouts) { - result.push_back(std::make_unique(layout)); - } - return result; + return info->output_layouts; } absl::StatusOr>> diff --git a/xla/python/ifrt_proxy/client/executable.h b/xla/python/ifrt_proxy/client/executable.h index 5ce5292d5a76b8..0af4a14a3e80b6 100644 --- a/xla/python/ifrt_proxy/client/executable.h +++ b/xla/python/ifrt_proxy/client/executable.h @@ -35,6 +35,7 @@ #include "xla/hlo/ir/hlo_module.h" #include "xla/layout.h" #include "xla/pjrt/pjrt_executable.h" +#include "xla/pjrt/pjrt_layout.h" #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/attribute_map.h" #include "xla/python/ifrt/client.h" @@ -77,9 +78,9 @@ class LoadedExecutable final std::optional> GetParameterShardings() const override; std::optional> GetOutputShardings() const override; - absl::StatusOr>> + absl::StatusOr>> GetParameterLayouts() const override; - absl::StatusOr>> + absl::StatusOr>> GetOutputLayouts() const override; absl::StatusOr>> GetOutputMemoryKinds() const override; @@ -105,8 +106,10 @@ class LoadedExecutable final std::optional> parameter_shardings; std::optional> output_shardings; - absl::StatusOr> parameter_layouts; - absl::StatusOr> output_layouts; + absl::StatusOr>> + parameter_layouts; + absl::StatusOr>> + output_layouts; // Elements in `output_memory_kinds` point to elements in `memory_kinds`. // Required since `GetOutputMemoryKinds()` returns `absl::string_view`. 
diff --git a/xla/python/ifrt_proxy/client/executable_test.cc b/xla/python/ifrt_proxy/client/executable_test.cc index 70bb1791d3d8f6..3972429fb38147 100644 --- a/xla/python/ifrt_proxy/client/executable_test.cc +++ b/xla/python/ifrt_proxy/client/executable_test.cc @@ -158,19 +158,20 @@ TEST_F(LoadedExecutableTest, Metadata) { ASSERT_OK_AND_ASSIGN(auto parameter_layouts, executable.GetParameterLayouts()); EXPECT_EQ(parameter_layouts.size(), 2); + EXPECT_EQ(tensorflow::down_cast( + parameter_layouts[0].get()) + ->xla_layout(), + xla::LayoutUtil::MakeDescendingLayout(/*rank=*/1)); + EXPECT_EQ(tensorflow::down_cast( + parameter_layouts[1].get()) + ->xla_layout(), + xla::LayoutUtil::MakeDescendingLayout(/*rank=*/2)); + ASSERT_OK_AND_ASSIGN(auto output_layouts, executable.GetOutputLayouts()); + EXPECT_EQ(output_layouts.size(), 1); EXPECT_EQ( - tensorflow::down_cast(parameter_layouts[0].get()) - ->xla_layout(), - xla::LayoutUtil::MakeDescendingLayout(/*rank=*/1)); - EXPECT_EQ( - tensorflow::down_cast(parameter_layouts[1].get()) + tensorflow::down_cast(output_layouts[0].get()) ->xla_layout(), xla::LayoutUtil::MakeDescendingLayout(/*rank=*/2)); - ASSERT_OK_AND_ASSIGN(auto output_layouts, executable.GetOutputLayouts()); - EXPECT_EQ(output_layouts.size(), 1); - EXPECT_EQ(tensorflow::down_cast(output_layouts[0].get()) - ->xla_layout(), - xla::LayoutUtil::MakeDescendingLayout(/*rank=*/2)); EXPECT_THAT(executable.GetOutputMemoryKinds(), IsOkAndHolds(ElementsAre(ElementsAre("foo")))); } diff --git a/xla/python/ifrt_proxy/server/ifrt_backend.cc b/xla/python/ifrt_proxy/server/ifrt_backend.cc index e26a6cb5c44e5d..b36f84fabcacc8 100644 --- a/xla/python/ifrt_proxy/server/ifrt_backend.cc +++ b/xla/python/ifrt_proxy/server/ifrt_backend.cc @@ -1287,7 +1287,7 @@ IfrtBackend::HandleLoadedExecutableMetadataRequest( parameter_layouts.ok()) { auto* const layouts = metadata_resp->mutable_parameter_layouts_list()->mutable_layouts(); - for (const std::unique_ptr& parameter_layout : + for (const std::shared_ptr& parameter_layout : *parameter_layouts) { // TODO(b/329165105): use PjRtLayout::Serialize instead const xla::PjRtXlaLayout* layout = @@ -1305,7 +1305,7 @@ IfrtBackend::HandleLoadedExecutableMetadataRequest( output_layouts.ok()) { auto* const layouts = metadata_resp->mutable_output_layouts_list()->mutable_layouts(); - for (const std::unique_ptr& output_layout : + for (const std::shared_ptr& output_layout : *output_layouts) { // TODO(b/329165105): use PjRtLayout::Serialize instead const xla::PjRtXlaLayout* layout = diff --git a/xla/python/ifrt_proxy/server/ifrt_backend_test.cc b/xla/python/ifrt_proxy/server/ifrt_backend_test.cc index f3fa9f991ea056..fd3c35e6831f03 100644 --- a/xla/python/ifrt_proxy/server/ifrt_backend_test.cc +++ b/xla/python/ifrt_proxy/server/ifrt_backend_test.cc @@ -1243,16 +1243,16 @@ TEST_P(IfrtBackendHandlerTest, LoadedExecutableMetadata) { EXPECT_CALL(*executable, GetOutputShardings()) .WillOnce(Return(std::vector{op_sharding1})); - std::vector> parameter_layouts; - parameter_layouts.push_back(std::make_unique( + std::vector> parameter_layouts; + parameter_layouts.push_back(std::make_shared( xla::LayoutUtil::MakeDescendingLayout(/*rank=*/1))); - parameter_layouts.push_back(std::make_unique( + parameter_layouts.push_back(std::make_shared( xla::LayoutUtil::MakeDescendingLayout(/*rank=*/2))); EXPECT_CALL(*executable, GetParameterLayouts()) .WillOnce(Return(std::move(parameter_layouts))); - std::vector> output_layouts; - output_layouts.push_back(std::make_unique( + std::vector> output_layouts; + 
output_layouts.push_back(std::make_shared( xla::LayoutUtil::MakeDescendingLayout(/*rank=*/2))); EXPECT_CALL(*executable, GetOutputLayouts()) .WillOnce(Return(std::move(output_layouts))); diff --git a/xla/python/jax_jit.cc b/xla/python/jax_jit.cc index 46041be0e7eb8d..e6d7ee51ab5f1f 100644 --- a/xla/python/jax_jit.cc +++ b/xla/python/jax_jit.cc @@ -197,7 +197,7 @@ std::string CallSignature::DebugString() const { out->append(s.DebugString()); }; auto layout_formatter = [](std::string* out, - const std::shared_ptr& l) { + const std::shared_ptr& l) { if (l != nullptr) { out->append(l->ToString()); } else { @@ -252,8 +252,8 @@ bool CallSignature::operator==(const CallSignature& other) const { absl::c_equal(dynamic_arg_shardings, other.dynamic_arg_shardings, ShardingEqual) && absl::c_equal(dynamic_arg_layouts, other.dynamic_arg_layouts, - [](const std::shared_ptr& a, - const std::shared_ptr& b) { + [](const std::shared_ptr& a, + const std::shared_ptr& b) { return (a && b) ? *a == *b : a == b; }) && (global_extra_jit_context.has_value() == diff --git a/xla/python/jax_jit.h b/xla/python/jax_jit.h index 4fb3775ef823c0..59d35abf0daa18 100644 --- a/xla/python/jax_jit.h +++ b/xla/python/jax_jit.h @@ -196,7 +196,7 @@ struct CallSignature { std::vector dynamic_arg_shardings; // The layout of the jax.Array arguments. - std::vector> dynamic_arg_layouts; + std::vector> dynamic_arg_layouts; absl::InlinedVector committed_args; diff --git a/xla/python/pjrt_ifrt/basic_string_array.cc b/xla/python/pjrt_ifrt/basic_string_array.cc index d3b9fd1be984f5..14914090b5912d 100644 --- a/xla/python/pjrt_ifrt/basic_string_array.cc +++ b/xla/python/pjrt_ifrt/basic_string_array.cc @@ -147,6 +147,7 @@ BasicStringArray::BasicStringArray(Client* client, Shape shape, : client_(client), shape_(std::move(shape)), sharding_(std::move(sharding)), + layout_(std::make_shared()), buffers_(std::move(buffers)), ready_future_(std::move(ready_future)), on_done_with_buffer_(std::move(on_done_with_buffer)) {} @@ -446,12 +447,13 @@ absl::StatusOr> BasicStringArray::FullyReplicatedShard( std::move(buffers_future), std::move(on_done_with_buffer)); } -absl::StatusOr> BasicStringArray::layout() const { +absl::StatusOr> BasicStringArray::layout() + const { absl::MutexLock lock(&mu_); if (is_deleted_) { return absl::FailedPreconditionError("Array has already been deleted"); } - return std::make_unique(); + return layout_; } std::string BasicStringArray::DebugString() const { diff --git a/xla/python/pjrt_ifrt/basic_string_array.h b/xla/python/pjrt_ifrt/basic_string_array.h index a430cfa73fdd26..b3c6ef0caf7e45 100644 --- a/xla/python/pjrt_ifrt/basic_string_array.h +++ b/xla/python/pjrt_ifrt/basic_string_array.h @@ -121,7 +121,7 @@ class BasicStringArray final return sharding_; } - absl::StatusOr> layout() const override; + absl::StatusOr> layout() const override; absl::StatusOr>> DisassembleIntoSingleDeviceArrays(ArrayCopySemantics semantics) override; @@ -172,6 +172,7 @@ class BasicStringArray final Client* client_; Shape shape_; std::shared_ptr sharding_; + std::shared_ptr layout_; Future buffers_; Future<> ready_future_; diff --git a/xla/python/pjrt_ifrt/pjrt_array.cc b/xla/python/pjrt_ifrt/pjrt_array.cc index 0c04f21a533464..724703bf47d207 100644 --- a/xla/python/pjrt_ifrt/pjrt_array.cc +++ b/xla/python/pjrt_ifrt/pjrt_array.cc @@ -553,7 +553,7 @@ bool PjRtArray::IsDeleted() const { std::string PjRtArray::DebugString() const { DCHECK(this); - absl::StatusOr> layout_ptr = layout(); + absl::StatusOr> layout_ptr = layout(); std::string layout_str = 
layout_ptr.ok() ? (*layout_ptr)->ToString() : ""; @@ -566,12 +566,12 @@ std::string PjRtArray::DebugString() const { // TODO(b/330198879): populate layout at construction instead of accessing PJRT // buffer directly for consistency with Pathways. -absl::StatusOr> PjRtArray::layout() const { +absl::StatusOr> PjRtArray::layout() const { CHECK(!pjrt_buffers_.empty()); - std::unique_ptr layout = pjrt_buffers_[0]->layout(); + std::shared_ptr layout = pjrt_buffers_[0]->layout(); #ifndef NDEBUG for (int i = 1; i < pjrt_buffers_.size(); ++i) { - std::unique_ptr layout_i = pjrt_buffers_[i]->layout(); + std::shared_ptr layout_i = pjrt_buffers_[i]->layout(); DCHECK(*layout == *layout_i) << "PjRtArray has mismatched layouts across shards! " << "shard 0: " << layout->ToString() << ", shard " << i << ": " diff --git a/xla/python/pjrt_ifrt/pjrt_array.h b/xla/python/pjrt_ifrt/pjrt_array.h index d14747fea550ea..7a88f708248393 100644 --- a/xla/python/pjrt_ifrt/pjrt_array.h +++ b/xla/python/pjrt_ifrt/pjrt_array.h @@ -151,7 +151,7 @@ class PjRtArray final return sharding_; } - absl::StatusOr> layout() const override; + absl::StatusOr> layout() const override; absl::StatusOr>> DisassembleIntoSingleDeviceArrays(ArrayCopySemantics semantics) override; diff --git a/xla/python/pjrt_ifrt/pjrt_client.cc b/xla/python/pjrt_ifrt/pjrt_client.cc index dca9f6381e2e45..171adfa6e9b10e 100644 --- a/xla/python/pjrt_ifrt/pjrt_client.cc +++ b/xla/python/pjrt_ifrt/pjrt_client.cc @@ -1116,7 +1116,7 @@ absl::StatusOr> PjRtClient::GetTopologyForDevices( topology)); } -absl::StatusOr> +absl::StatusOr> PjRtClient::GetDefaultLayoutForDevice(DType dtype, absl::Span dims, Device* device) const { diff --git a/xla/python/pjrt_ifrt/pjrt_client.h b/xla/python/pjrt_ifrt/pjrt_client.h index 4849f5329e9e07..3f87a7139bddb2 100644 --- a/xla/python/pjrt_ifrt/pjrt_client.h +++ b/xla/python/pjrt_ifrt/pjrt_client.h @@ -259,9 +259,9 @@ class PjRtClient final absl::StatusOr> GetTopologyForDevices( const tsl::RCReference& devices) const override; - absl::StatusOr> GetDefaultLayoutForDevice( - DType dtype, absl::Span dims, - Device* device) const override; + absl::StatusOr> + GetDefaultLayoutForDevice(DType dtype, absl::Span dims, + Device* device) const override; absl::StatusOr LookupPjRtDevice( xla::PjRtDevice* pjrt_device) const override; diff --git a/xla/python/pjrt_ifrt/pjrt_executable.h b/xla/python/pjrt_ifrt/pjrt_executable.h index ce83ee0da24de1..cb75494a5a4599 100644 --- a/xla/python/pjrt_ifrt/pjrt_executable.h +++ b/xla/python/pjrt_ifrt/pjrt_executable.h @@ -116,13 +116,13 @@ class PjRtExecutable final return pjrt_executable_->GetOutputShardings(); } - absl::StatusOr>> + absl::StatusOr>> GetParameterLayouts() const override { DCHECK(this); return pjrt_executable_->GetParameterLayouts(); } - absl::StatusOr>> + absl::StatusOr>> GetOutputLayouts() const override { DCHECK(this); return pjrt_executable_->GetOutputLayouts(); @@ -242,13 +242,13 @@ class PjRtLoadedExecutable final return pjrt_loaded_executable_->GetOutputShardings(); } - absl::StatusOr>> + absl::StatusOr>> GetParameterLayouts() const override { DCHECK(this); return pjrt_loaded_executable_->GetParameterLayouts(); } - absl::StatusOr>> + absl::StatusOr>> GetOutputLayouts() const override { DCHECK(this); return pjrt_loaded_executable_->GetOutputLayouts(); diff --git a/xla/python/py_array.cc b/xla/python/py_array.cc index a8899b8ea144fe..e917dc3e4294dd 100644 --- a/xla/python/py_array.cc +++ b/xla/python/py_array.cc @@ -47,6 +47,7 @@ limitations under the License. 
#include "llvm/Support/Casting.h" #include "nanobind/nanobind.h" #include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep #include "nanobind/stl/string.h" // IWYU pragma: keep #include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep diff --git a/xla/python/py_array.h b/xla/python/py_array.h index 61987eb985e003..d3bf0ca3337966 100644 --- a/xla/python/py_array.h +++ b/xla/python/py_array.h @@ -171,7 +171,7 @@ class PyArray : public nanobind::object { const nanobind::object& sharding() const { return GetStorage().sharding; } - absl::StatusOr> layout() { + absl::StatusOr> layout() { return ifrt_array()->layout(); } diff --git a/xla/python/py_client.cc b/xla/python/py_client.cc index f900fe09170092..6d9cf48173aaff 100644 --- a/xla/python/py_client.cc +++ b/xla/python/py_client.cc @@ -777,7 +777,8 @@ PyType_Slot PyClient::slots_[] = { .def( "get_default_layout", [](PyClient& self, nb_dtype dtype, nb::sequence shard_shape, - nb_class_ptr device) -> std::unique_ptr { + nb_class_ptr device) + -> std::shared_ptr { ifrt::DType ifrt_type = xla::ValueOrThrow(DtypeToIfRtDType(dtype)); std::vector dims = SequenceToVector(shard_shape); return xla::ValueOrThrow( diff --git a/xla/python/py_compile_only_client.cc b/xla/python/py_compile_only_client.cc index d366ef93c096bf..a31e732a84ee11 100644 --- a/xla/python/py_compile_only_client.cc +++ b/xla/python/py_compile_only_client.cc @@ -336,7 +336,7 @@ class CompileOnlyIfRtClient final return topology_; } - absl::StatusOr> GetDefaultLayoutForDevice( + absl::StatusOr> GetDefaultLayoutForDevice( ifrt::DType dtype, absl::Span dims, ifrt::Device* device) const override { TF_ASSIGN_OR_RETURN(PrimitiveType element_type, ToPrimitiveType(dtype)); diff --git a/xla/python/py_executable.cc b/xla/python/py_executable.cc index 7326521695c7bc..bd582d3035cf58 100644 --- a/xla/python/py_executable.cc +++ b/xla/python/py_executable.cc @@ -415,13 +415,13 @@ PyLoadedExecutable::GetOutputMemoryKinds() const { return ifrt_loaded_executable_->GetOutputMemoryKinds(); } -absl::StatusOr>> +absl::StatusOr>> PyLoadedExecutable::GetParameterLayouts() const { nb::gil_scoped_release gil_release; return ifrt_loaded_executable_->GetParameterLayouts(); } -absl::StatusOr>> +absl::StatusOr>> PyLoadedExecutable::GetOutputLayouts() const { nb::gil_scoped_release gil_release; return ifrt_loaded_executable_->GetOutputLayouts(); diff --git a/xla/python/py_executable.h b/xla/python/py_executable.h index 9af7a4a7839702..480f33d99d95a9 100644 --- a/xla/python/py_executable.h +++ b/xla/python/py_executable.h @@ -189,11 +189,11 @@ class PyLoadedExecutable { absl::StatusOr>> GetOutputMemoryKinds() const; - absl::StatusOr>> GetParameterLayouts() - const; + absl::StatusOr>> + GetParameterLayouts() const; - absl::StatusOr>> GetOutputLayouts() - const; + absl::StatusOr>> + GetOutputLayouts() const; std::optional> GetParameterShardings() const; diff --git a/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc b/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc index 023252fd8c690b..3101f288cf6775 100644 --- a/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc +++ b/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc @@ -1307,13 +1307,13 @@ FunctionalHloRunner::CopyArgumentsToDevice( TF_RET_CHECK(!shape.IsTuple()) << "Param tuple without flattened_arguments"; return non_tuple_memory_space(shape); }; - TF_ASSIGN_OR_RETURN(const std::vector>& + TF_ASSIGN_OR_RETURN(const 
std::vector<std::shared_ptr<const PjRtLayout>>&
                          executable_parameter_pjrt_layouts,
                      executable->GetParameterLayouts());
   std::vector<Layout> executable_parameter_layouts;
   executable_parameter_layouts.reserve(
       executable_parameter_pjrt_layouts.size());
-  for (const std::unique_ptr<PjRtLayout>& pjrt_layout :
+  for (const std::shared_ptr<const PjRtLayout>& pjrt_layout :
        executable_parameter_pjrt_layouts) {
     executable_parameter_layouts.push_back(
         xla::GetXlaLayoutUnsafe(pjrt_layout));
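
For readers skimming the patch, the self-contained sketch below (not part of the change) illustrates why returning `std::shared_ptr<const ...>` from an abstract-layout accessor avoids the copies described in the commit message. All class names are placeholders standing in for `xla::PjRtLayout`, `xla::PjRtXlaLayout`, and `xla::PjRtBuffer`; they are not the real XLA types.

```cpp
// Hypothetical illustration of the pattern this patch adopts: an accessor on
// an abstract layout type returns a cached shared_ptr<const ...> instead of a
// fresh unique_ptr, so callers get the same instance without deep copies or
// downcasts. Placeholder names only.
#include <memory>
#include <string>
#include <utility>

class AbstractLayout {  // stands in for xla::PjRtLayout
 public:
  virtual ~AbstractLayout() = default;
  virtual std::string ToString() const = 0;
};

class ConcreteLayout : public AbstractLayout {  // stands in for xla::PjRtXlaLayout
 public:
  explicit ConcreteLayout(std::string minor_to_major)
      : minor_to_major_(std::move(minor_to_major)) {}
  std::string ToString() const override { return minor_to_major_; }

 private:
  // Stands in for the relatively large xla::Layout payload that made deep
  // copies through a unique_ptr-returning API expensive.
  std::string minor_to_major_;
};

class Buffer {  // stands in for xla::PjRtBuffer / xla::ifrt::Array
 public:
  Buffer() : layout_(std::make_shared<ConcreteLayout>("{1,0}")) {}

  // Hands out the cached layout; only the shared_ptr control block is
  // touched, and no downcast is needed to clone the concrete object.
  std::shared_ptr<const AbstractLayout> layout() const { return layout_; }

 private:
  std::shared_ptr<const AbstractLayout> layout_;
};

int main() {
  Buffer buffer;
  // Repeated calls return the same cached instance, which callers may keep
  // alive independently of the buffer.
  std::shared_ptr<const AbstractLayout> a = buffer.layout();
  std::shared_ptr<const AbstractLayout> b = buffer.layout();
  return (a.get() == b.get() && a->ToString() == "{1,0}") ? 0 : 1;
}
```

This mirrors how `PjRtCApiBuffer::layout()` now returns its internally cached instance rather than constructing a copy on every call.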