Return `std::shared_ptr<const xla::PjRtLayout>` from IFRT and PjRt instead of `std::unique_ptr<xla::PjRtLayout>`

PiperOrigin-RevId: 711892970
junwhanahn authored and Google-ML-Automation committed Jan 4, 2025
1 parent 95ac9e3 commit c358bd9
Showing 37 changed files with 129 additions and 135 deletions.
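
The caller-visible effect is primarily an ownership change: `PjRtBuffer::layout()`, IFRT `Array::layout()`, and the executables' `GetParameterLayouts()`/`GetOutputLayouts()` now hand out a shared, immutable `PjRtLayout` that can be cached and aliased without copying, and `GetXlaLayoutUnsafe()` now takes the layout by const reference. A minimal caller-side sketch (the `InspectBufferLayout` helper is illustrative, not part of this commit):

```cpp
#include <memory>

#include "xla/layout.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/pjrt_layout.h"

// Hypothetical helper showing how call sites adapt to the new return type.
std::shared_ptr<const xla::PjRtLayout> InspectBufferLayout(
    const xla::PjRtBuffer& buffer) {
  // Previously each call returned a freshly allocated
  // std::unique_ptr<xla::PjRtLayout>; now the buffer shares one immutable
  // layout object with every caller.
  std::shared_ptr<const xla::PjRtLayout> layout = buffer.layout();

  // GetXlaLayoutUnsafe() takes a const PjRtLayout& after this change, so the
  // shared_ptr is dereferenced at the call site (as dlpack.cc does below).
  xla::Layout xla_layout = xla::GetXlaLayoutUnsafe(*layout);
  (void)xla_layout;  // e.g. feed into stride computations.

  return layout;  // Retaining the pointer does not copy the layout.
}
```
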
2 changes: 1 addition & 1 deletion xla/pjrt/c/pjrt_c_api_wrapper_impl.cc
@@ -1797,7 +1797,7 @@ PJRT_Error* PJRT_Buffer_GetMemoryLayout(
   absl::MutexLock lock(&args->buffer->mu);
   if (!layout_data.has_value()) {
     // TODO(skyewm): change PJRT C API to also use opaque layout type
-    std::unique_ptr<xla::PjRtLayout> pjrt_layout =
+    std::shared_ptr<const xla::PjRtLayout> pjrt_layout =
         args->buffer->buffer->layout();
     xla::PjRtXlaLayout* pjrt_xla_layout =
         tensorflow::down_cast<xla::PjRtXlaLayout*>(pjrt_layout.get());

11 changes: 6 additions & 5 deletions xla/pjrt/pjrt_c_api_client.cc
@@ -2020,16 +2020,17 @@ absl::Span<const int64_t> PjRtCApiBuffer::dimensions() const {
   return absl::Span<const int64_t>(args.dims, args.num_dims);
 }
 
-std::unique_ptr<PjRtLayout> PjRtCApiBuffer::layout() const {
+std::shared_ptr<const PjRtLayout> PjRtCApiBuffer::layout() const {
   {
     absl::MutexLock lock(&mu_);
-    if (!layout_.has_value()) {
+    if (layout_ == nullptr) {
       const PJRT_Api* c_api = pjrt_c_api();
       PJRT_Layouts_Extension* extension =
           pjrt::FindExtension<PJRT_Layouts_Extension>(
               c_api, PJRT_Extension_Type::PJRT_Extension_Type_Layouts);
       if (extension == nullptr) {
-        layout_.emplace(LayoutUtil::MakeDescendingLayout(dimensions().size()));
+        layout_ = std::make_shared<PjRtXlaLayout>(
+            LayoutUtil::MakeDescendingLayout(dimensions().size()));
       } else {
         std::unique_ptr<PJRT_Layouts_MemoryLayout,
                         pjrt::PJRT_Layouts_MemoryLayoutDeleter>
@@ -2057,11 +2058,11 @@ std::unique_ptr<PjRtLayout> PjRtCApiBuffer::layout() const {
         absl::StatusOr<PjRtXlaLayout> pjrt_xla_layout =
             PjRtXlaLayout::Deserialize(serialized_layout);
         TF_CHECK_OK(pjrt_xla_layout.status());
-        layout_.emplace(*pjrt_xla_layout);
+        layout_ = std::make_shared<PjRtXlaLayout>(*std::move(pjrt_xla_layout));
       }
     }
   }
-  return std::make_unique<PjRtXlaLayout>(*layout_);
+  return layout_;
 }
 
 bool PjRtCApiBuffer::has_dynamic_dimensions() const {

4 changes: 2 additions & 2 deletions xla/pjrt/pjrt_c_api_client.h
@@ -485,7 +485,7 @@ class PjRtCApiBuffer : public PjRtBuffer {
 
   absl::Span<const int64_t> dimensions() const override;
 
-  std::unique_ptr<PjRtLayout> layout() const override;
+  std::shared_ptr<const PjRtLayout> layout() const override;
 
   // PJRT C API doesn't support tuple buffers.
   bool IsTuple() const override { return false; }
@@ -583,7 +583,7 @@ class PjRtCApiBuffer : public PjRtBuffer {
   // we set on `readiness_event` modifies `readiness_promise_`.
   std::shared_ptr<PjRtFuture<>::Promise> readiness_promise_;
   // Set and cached the first time layout() is called.
-  mutable std::optional<PjRtXlaLayout> layout_;
+  mutable std::shared_ptr<const PjRtXlaLayout> layout_;
   // Set and cached the first time is_dynamic_dimension() is called.
   mutable std::optional<absl::InlinedVector<bool, InlineRank()>>
       is_dynamic_dimension_;

8 changes: 4 additions & 4 deletions xla/pjrt/pjrt_client.h
@@ -1121,12 +1121,12 @@ class PjRtBuffer {
     return on_device_shape().dimensions();
   }
 
-  // The on-device memory layout of this buffer. Returned via unique_ptr to make
+  // The on-device memory layout of this buffer. Returned via shared_ptr to make
   // memory management easier -- PjRtLayout is an abstract base class, so cannot
   // be easily copied.
-  virtual std::unique_ptr<PjRtLayout> layout() const {
+  virtual std::shared_ptr<const PjRtLayout> layout() const {
     CHECK(on_device_shape().has_layout());
-    return std::make_unique<PjRtXlaLayout>(on_device_shape().layout());
+    return std::make_shared<PjRtXlaLayout>(on_device_shape().layout());
   }
 
   // PjRtBuffers can either represent a single array buffer or a tuple of array
@@ -1249,7 +1249,7 @@ class PjRtBuffer {
     } else {
       device_shape = ShapeUtil::MakeShape(element_type(), literal_dims);
       // TODO(b/327524065): use PjRtLayout directly instead of xla::Layout
-      *device_shape.mutable_layout() = GetXlaLayoutUnsafe(layout());
+      *device_shape.mutable_layout() = GetXlaLayoutUnsafe(*layout());
     }
   } else {
     // TODO(skyewm): does anything need to create tuple literals? The PJRT C

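The comment in `pjrt_client.h` above captures the motivation: `PjRtLayout` is an abstract base class, so callers cannot copy it by value, and the old `unique_ptr` contract forced every implementation to hand out a freshly allocated object per call. A standalone sketch (not part of this commit) of the shared-ownership model the new signature enables:

```cpp
#include <memory>

#include "xla/layout_util.h"
#include "xla/pjrt/pjrt_layout.h"

int main() {
  // One immutable layout instance, built here from a plain xla::Layout.
  std::shared_ptr<const xla::PjRtLayout> layout =
      std::make_shared<xla::PjRtXlaLayout>(
          xla::LayoutUtil::MakeDescendingLayout(/*rank=*/2));

  // Any number of consumers may alias it; the last reference released frees
  // it. With unique_ptr, sharing required a deep copy per consumer or a raw
  // pointer with unclear ownership.
  std::shared_ptr<const xla::PjRtLayout> alias = layout;
  return alias.get() == layout.get() ? 0 : 1;
}
```
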
8 changes: 4 additions & 4 deletions xla/pjrt/pjrt_executable.cc
@@ -422,7 +422,7 @@ PjRtExecutable::GetOutputDimensions() const {
   return output_dimensions;
 }
 
-absl::StatusOr<std::vector<std::unique_ptr<PjRtLayout>>>
+absl::StatusOr<std::vector<std::shared_ptr<const PjRtLayout>>>
 PjRtExecutable::GetParameterLayouts() const {
   TF_ASSIGN_OR_RETURN(std::vector<std::shared_ptr<HloModule>> hlo_modules,
                       GetHloModules());
@@ -439,15 +439,15 @@ PjRtExecutable::GetParameterLayouts() const {
   ComputationLayout comp_layout = hlo_modules[0]->entry_computation_layout();
   TF_ASSIGN_OR_RETURN(std::vector<Layout> layouts,
                       comp_layout.FlattenedParameterLayouts());
-  std::vector<std::unique_ptr<PjRtLayout>> result;
+  std::vector<std::shared_ptr<const PjRtLayout>> result;
   result.reserve(layouts.size());
   for (const Layout& layout : layouts) {
     result.push_back(std::make_unique<PjRtXlaLayout>(layout));
   }
   return result;
 }
 
-absl::StatusOr<std::vector<std::unique_ptr<PjRtLayout>>>
+absl::StatusOr<std::vector<std::shared_ptr<const PjRtLayout>>>
 PjRtExecutable::GetOutputLayouts() const {
   TF_ASSIGN_OR_RETURN(std::vector<std::shared_ptr<HloModule>> hlo_modules,
                       GetHloModules());
@@ -464,7 +464,7 @@ PjRtExecutable::GetOutputLayouts() const {
   ComputationLayout comp_layout = hlo_modules[0]->entry_computation_layout();
   TF_ASSIGN_OR_RETURN(std::vector<Layout> layouts,
                       comp_layout.FlattenedResultLayouts());
-  std::vector<std::unique_ptr<PjRtLayout>> result;
+  std::vector<std::shared_ptr<const PjRtLayout>> result;
   result.reserve(layouts.size());
   for (const Layout& layout : layouts) {
     result.push_back(std::make_unique<PjRtXlaLayout>(layout));

4 changes: 2 additions & 2 deletions xla/pjrt/pjrt_executable.h
@@ -335,11 +335,11 @@ class PjRtExecutable {
   GetOutputDimensions() const;
 
   // Returns the layout of each input parameter.
-  virtual absl::StatusOr<std::vector<std::unique_ptr<PjRtLayout>>>
+  virtual absl::StatusOr<std::vector<std::shared_ptr<const PjRtLayout>>>
   GetParameterLayouts() const;
 
   // Returns the layout of each output.
-  virtual absl::StatusOr<std::vector<std::unique_ptr<PjRtLayout>>>
+  virtual absl::StatusOr<std::vector<std::shared_ptr<const PjRtLayout>>>
   GetOutputLayouts() const;
 
   // Returns a list of lists of memory kind strings for output. The returned

8 changes: 3 additions & 5 deletions xla/pjrt/pjrt_layout.h
@@ -16,7 +16,6 @@ limitations under the License.
 #ifndef XLA_PJRT_PJRT_LAYOUT_H_
 #define XLA_PJRT_PJRT_LAYOUT_H_
 
-#include <memory>
 #include <string>
 #include <utility>
 
@@ -99,10 +98,9 @@ class PjRtXlaLayout : public PjRtLayout {
 
 // TODO(b/327524065): make callers use PjRtLayout directly instead of assuming
 // an xla::Layout and get rid of this function.
-inline Layout GetXlaLayoutUnsafe(
-    const std::unique_ptr<PjRtLayout>& pjrt_layout) {
-  PjRtXlaLayout* xla_layout =
-      tensorflow::down_cast<PjRtXlaLayout*>(pjrt_layout.get());
+inline Layout GetXlaLayoutUnsafe(const PjRtLayout& pjrt_layout) {
+  const PjRtXlaLayout* xla_layout =
+      tensorflow::down_cast<const PjRtXlaLayout*>(&pjrt_layout);
   CHECK(xla_layout != nullptr) << "Got unexpected layout type";
   return xla_layout->xla_layout();
 }

2 changes: 1 addition & 1 deletion xla/python/dlpack.cc
@@ -418,7 +418,7 @@ absl::StatusOr<nb::capsule> BufferToDLPackManagedTensor(
                                      pjrt_buffer->dimensions().end());
 
   // TODO(b/327524065): use PjRtLayout directly instead of xla::Layout
-  Layout xla_layout = GetXlaLayoutUnsafe(pjrt_buffer->layout());
+  Layout xla_layout = GetXlaLayoutUnsafe(*pjrt_buffer->layout());
   pack->strides = StridesForShape(pjrt_buffer->element_type(),
                                   pjrt_buffer->dimensions(), xla_layout);
 

2 changes: 1 addition & 1 deletion xla/python/ifrt/array.h
@@ -76,7 +76,7 @@ class Array : public llvm::RTTIExtends<Array, Value> {
   // The device memory layout for each shard of the Array. All shards are
   // assumed to have the same layout. Cannot be nullptr; implementations should
   // return UNIMPLEMENTED instead.
-  virtual absl::StatusOr<std::unique_ptr<PjRtLayout>> layout() const = 0;
+  virtual absl::StatusOr<std::shared_ptr<const PjRtLayout>> layout() const = 0;
 
   // Breaks an array up into per-device arrays. This is the elimination
   // counterpart of `Client::AssembleArrayFromSingleDeviceArrays()`.

2 changes: 1 addition & 1 deletion xla/python/ifrt/client.h
@@ -241,7 +241,7 @@ class Client : public llvm::RTTIExtends<Client, llvm::RTTIRoot> {
   // single-shard dimensions `dims`.
   // TODO(hyeontaek): Change the API to take `Shape` and `Sharding` instead of
   // single-shard dimensions and device.
-  virtual absl::StatusOr<std::unique_ptr<xla::PjRtLayout>>
+  virtual absl::StatusOr<std::shared_ptr<const PjRtLayout>>
   GetDefaultLayoutForDevice(DType dtype, absl::Span<const int64_t> dims,
                             Device* device) const = 0;
 

8 changes: 4 additions & 4 deletions xla/python/ifrt/executable.h
@@ -78,10 +78,10 @@ class Executable : public llvm::RTTIExtends<Executable, llvm::RTTIRoot> {
   // Returns a list of output `OpSharding`.
   virtual std::optional<std::vector<OpSharding>> GetOutputShardings() const = 0;
   // Returns a list of parameter layouts.
-  virtual absl::StatusOr<std::vector<std::unique_ptr<xla::PjRtLayout>>>
+  virtual absl::StatusOr<std::vector<std::shared_ptr<const xla::PjRtLayout>>>
   GetParameterLayouts() const = 0;
   // Returns a list of output/result layouts.
-  virtual absl::StatusOr<std::vector<std::unique_ptr<xla::PjRtLayout>>>
+  virtual absl::StatusOr<std::vector<std::shared_ptr<const xla::PjRtLayout>>>
   GetOutputLayouts() const = 0;
   // Returns an `HloModule` (optimized) per partition.
   virtual absl::StatusOr<std::vector<std::shared_ptr<HloModule>>>
@@ -187,10 +187,10 @@ class LoadedExecutable
   // Returns a list of output OpSharding.
   virtual std::optional<std::vector<OpSharding>> GetOutputShardings() const = 0;
   // Returns a list of parameter layouts.
-  virtual absl::StatusOr<std::vector<std::unique_ptr<xla::PjRtLayout>>>
+  virtual absl::StatusOr<std::vector<std::shared_ptr<const xla::PjRtLayout>>>
   GetParameterLayouts() const = 0;
   // Returns a list of output/result layouts.
-  virtual absl::StatusOr<std::vector<std::unique_ptr<xla::PjRtLayout>>>
+  virtual absl::StatusOr<std::vector<std::shared_ptr<const xla::PjRtLayout>>>
   GetOutputLayouts() const = 0;
   // Return an HloModule (optimized) per partition.
   virtual absl::StatusOr<std::vector<std::shared_ptr<HloModule>>>

7 changes: 4 additions & 3 deletions xla/python/ifrt/mock.cc
@@ -78,9 +78,10 @@ MockArray::MockArray(tsl::RCReference<xla::ifrt::Array> delegated)
         return delegated_->shared_ptr_sharding();
       });
   ON_CALL(*this, layout)
-      .WillByDefault([this]() -> absl::StatusOr<std::unique_ptr<PjRtLayout>> {
-        return delegated_->layout();
-      });
+      .WillByDefault(
+          [this]() -> absl::StatusOr<std::shared_ptr<const PjRtLayout>> {
+            return delegated_->layout();
+          });
   ON_CALL(*this, DisassembleIntoSingleDeviceArrays(_))
       .WillByDefault([this](ArrayCopySemantics semantics) {
         return delegated_->DisassembleIntoSingleDeviceArrays(semantics);

12 changes: 6 additions & 6 deletions xla/python/ifrt/mock.h
@@ -76,7 +76,7 @@ class MockArray : public llvm::RTTIExtends<MockArray, Array> {
   MOCK_METHOD(const Sharding&, sharding, (), (const, final));
   MOCK_METHOD(absl::Nonnull<std::shared_ptr<const Sharding>>,
               shared_ptr_sharding, (), (const, final));
-  MOCK_METHOD(absl::StatusOr<std::unique_ptr<PjRtLayout>>, layout, (),
+  MOCK_METHOD(absl::StatusOr<std::shared_ptr<const PjRtLayout>>, layout, (),
               (const, final));
   MOCK_METHOD(absl::StatusOr<std::vector<tsl::RCReference<Array>>>,
               DisassembleIntoSingleDeviceArrays, (ArrayCopySemantics semantics),
@@ -173,7 +173,7 @@ class MockClient : public llvm::RTTIExtends<MockClient, Client> {
   MOCK_METHOD(absl::StatusOr<std::shared_ptr<Topology>>, GetTopologyForDevices,
               (const tsl::RCReference<xla::ifrt::DeviceList>& devices),
               (const, final));
-  MOCK_METHOD(absl::StatusOr<std::unique_ptr<xla::PjRtLayout>>,
+  MOCK_METHOD(absl::StatusOr<std::shared_ptr<const PjRtLayout>>,
               GetDefaultLayoutForDevice,
               (xla::ifrt::DType dtype, absl::Span<const int64_t> dims,
                xla::ifrt::Device* device),
@@ -264,9 +264,9 @@ class MockExecutable : public llvm::RTTIExtends<MockExecutable, Executable> {
               (const, final));
   MOCK_METHOD(std::optional<std::vector<OpSharding>>, GetOutputShardings, (),
               (const, final));
-  MOCK_METHOD(absl::StatusOr<std::vector<std::unique_ptr<xla::PjRtLayout>>>,
+  MOCK_METHOD(absl::StatusOr<std::vector<std::shared_ptr<const PjRtLayout>>>,
               GetParameterLayouts, (), (const, final));
-  MOCK_METHOD(absl::StatusOr<std::vector<std::unique_ptr<xla::PjRtLayout>>>,
+  MOCK_METHOD(absl::StatusOr<std::vector<std::shared_ptr<const PjRtLayout>>>,
               GetOutputLayouts, (), (const, final));
   MOCK_METHOD(absl::StatusOr<std::vector<std::shared_ptr<HloModule>>>,
               GetHloModules, (), (const, final));
@@ -293,9 +293,9 @@ class MockLoadedExecutable
               (const, final));
   MOCK_METHOD(std::optional<std::vector<OpSharding>>, GetOutputShardings, (),
               (const, final));
-  MOCK_METHOD(absl::StatusOr<std::vector<std::unique_ptr<xla::PjRtLayout>>>,
+  MOCK_METHOD(absl::StatusOr<std::vector<std::shared_ptr<const PjRtLayout>>>,
               GetParameterLayouts, (), (const, final));
-  MOCK_METHOD(absl::StatusOr<std::vector<std::unique_ptr<xla::PjRtLayout>>>,
+  MOCK_METHOD(absl::StatusOr<std::vector<std::shared_ptr<const PjRtLayout>>>,
               GetOutputLayouts, (), (const, final));
   MOCK_METHOD(absl::StatusOr<std::vector<std::vector<absl::string_view>>>,
               GetOutputMemoryKinds, (), (const, final));

2 changes: 1 addition & 1 deletion xla/python/ifrt_proxy/client/array.h
@@ -112,7 +112,7 @@ class Array final : public llvm::RTTIExtends<Array, xla::ifrt::Array> {
   std::shared_ptr<const Sharding> shared_ptr_sharding() const override {
     return sharding_;
   }
-  absl::StatusOr<std::unique_ptr<PjRtLayout>> layout() const override {
+  absl::StatusOr<std::shared_ptr<const PjRtLayout>> layout() const override {
    return absl::UnimplementedError(
        "Array::layout() not implemented for IFRT proxy");
  };

7 changes: 4 additions & 3 deletions xla/python/ifrt_proxy/client/client.h
@@ -140,9 +140,10 @@ class Client final : public llvm::RTTIExtends<Client, xla::ifrt::Client> {
     return absl::UnimplementedError(
         "GetTopologyForDevices is not supported for the IFRT proxy client.");
   }
-  absl::StatusOr<std::unique_ptr<xla::PjRtLayout>> GetDefaultLayoutForDevice(
-      xla::ifrt::DType dtype, absl::Span<const int64_t> dims,
-      xla::ifrt::Device* device) const override {
+  absl::StatusOr<std::shared_ptr<const xla::PjRtLayout>>
+  GetDefaultLayoutForDevice(xla::ifrt::DType dtype,
+                            absl::Span<const int64_t> dims,
+                            xla::ifrt::Device* device) const override {
     return absl::UnimplementedError(
         "GetDefaultLayout is not supported for the IFRT proxy client.");
   }

27 changes: 7 additions & 20 deletions xla/python/ifrt_proxy/client/executable.cc
@@ -310,10 +310,11 @@ LoadedExecutable::LoadedExecutable(
 
   auto parse_layouts =
       [](const LoadedExecutableMetadataResponse::LayoutList& list) {
-        std::vector<xla::Layout> layouts;
+        std::vector<std::shared_ptr<const xla::PjRtLayout>> layouts;
         layouts.reserve(list.layouts_size());
         for (const auto& layout : list.layouts()) {
-          layouts.push_back(xla::Layout::CreateFromProto(layout));
+          layouts.push_back(std::make_shared<xla::PjRtXlaLayout>(
+              xla::Layout::CreateFromProto(layout)));
         }
         return layouts;
       };
@@ -433,34 +434,20 @@ std::optional<std::vector<OpSharding>> LoadedExecutable::GetOutputShardings()
   return (*info)->output_shardings;
 }
 
-absl::StatusOr<std::vector<std::unique_ptr<xla::PjRtLayout>>>
+absl::StatusOr<std::vector<std::shared_ptr<const xla::PjRtLayout>>>
 LoadedExecutable::GetParameterLayouts() const {
   tsl::profiler::TraceMe traceme_ifrt_entrypoint(
       "IfrtProxyEntrypointLoadedExecutableGetParameterLayouts");
   TF_ASSIGN_OR_RETURN(auto info, metadata_future_.Await());
-  TF_RETURN_IF_ERROR(info->parameter_layouts.status());
-
-  std::vector<std::unique_ptr<xla::PjRtLayout>> result;
-  result.reserve(info->parameter_layouts->size());
-  for (const xla::Layout& layout : *info->parameter_layouts) {
-    result.push_back(std::make_unique<xla::PjRtXlaLayout>(layout));
-  }
-  return result;
+  return info->parameter_layouts;
 }
 
-absl::StatusOr<std::vector<std::unique_ptr<xla::PjRtLayout>>>
+absl::StatusOr<std::vector<std::shared_ptr<const xla::PjRtLayout>>>
 LoadedExecutable::GetOutputLayouts() const {
   tsl::profiler::TraceMe traceme_ifrt_entrypoint(
      "IfrtProxyEntrypointLoadedExecutableGetOutputLayouts");
   TF_ASSIGN_OR_RETURN(auto info, metadata_future_.Await());
-  TF_RETURN_IF_ERROR(info->output_layouts.status());
-
-  std::vector<std::unique_ptr<xla::PjRtLayout>> result;
-  result.reserve(info->output_layouts->size());
-  for (const xla::Layout& layout : *info->output_layouts) {
-    result.push_back(std::make_unique<xla::PjRtXlaLayout>(layout));
-  }
-  return result;
+  return info->output_layouts;
 }
 
 absl::StatusOr<std::vector<std::vector<absl::string_view>>>

11 changes: 7 additions & 4 deletions xla/python/ifrt_proxy/client/executable.h
@@ -35,6 +35,7 @@
 #include "xla/hlo/ir/hlo_module.h"
 #include "xla/layout.h"
 #include "xla/pjrt/pjrt_executable.h"
+#include "xla/pjrt/pjrt_layout.h"
 #include "xla/python/ifrt/array.h"
 #include "xla/python/ifrt/attribute_map.h"
 #include "xla/python/ifrt/client.h"
@@ -77,9 +78,9 @@ class LoadedExecutable final
 
   std::optional<std::vector<OpSharding>> GetParameterShardings() const override;
   std::optional<std::vector<OpSharding>> GetOutputShardings() const override;
-  absl::StatusOr<std::vector<std::unique_ptr<xla::PjRtLayout>>>
+  absl::StatusOr<std::vector<std::shared_ptr<const xla::PjRtLayout>>>
   GetParameterLayouts() const override;
-  absl::StatusOr<std::vector<std::unique_ptr<xla::PjRtLayout>>>
+  absl::StatusOr<std::vector<std::shared_ptr<const xla::PjRtLayout>>>
   GetOutputLayouts() const override;
   absl::StatusOr<std::vector<std::vector<absl::string_view>>>
   GetOutputMemoryKinds() const override;
@@ -105,8 +106,10 @@ class LoadedExecutable final
     std::optional<std::vector<xla::OpSharding>> parameter_shardings;
     std::optional<std::vector<xla::OpSharding>> output_shardings;
 
-    absl::StatusOr<std::vector<xla::Layout>> parameter_layouts;
-    absl::StatusOr<std::vector<xla::Layout>> output_layouts;
+    absl::StatusOr<std::vector<std::shared_ptr<const xla::PjRtLayout>>>
+        parameter_layouts;
+    absl::StatusOr<std::vector<std::shared_ptr<const xla::PjRtLayout>>>
+        output_layouts;
 
     // Elements in `output_memory_kinds` point to elements in `memory_kinds`.
     // Required since `GetOutputMemoryKinds()` returns `absl::string_view`.

(Diffs for the remaining changed files in this commit are not shown here.)