Skip to content

Commit

Permalink
[XLA:CPU] Return error when trying to create a view of an unaligned b…
Browse files Browse the repository at this point in the history
…uffer.

PiperOrigin-RevId: 681431840
  • Loading branch information
Adam-Banas authored and Google-ML-Automation committed Oct 2, 2024
1 parent f9822ba commit 18a8af1
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 0 deletions.
1 change: 1 addition & 0 deletions xla/pjrt/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ cc_library(
hdrs = ["pjrt_client_test.h"],
deps = [
":pjrt_client",
":pjrt_compiler",
"//xla:shape_util",
"//xla:test",
"//xla:xla_data_proto_cc",
Expand Down
2 changes: 2 additions & 0 deletions xla/pjrt/cpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ cc_library(
":cpu_topology",
":tracked_tfrt_cpu_device_buffer",
"//xla:array",
"//xla:cpu_function_runtime",
"//xla:debug_options_flags",
"//xla:executable_run_options",
"//xla:literal",
Expand Down Expand Up @@ -192,6 +193,7 @@ cc_library(
"//xla/tsl/concurrency:ref_count",
"//xla/tsl/lib/strings:proto_serialization",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/base:dynamic_annotations",
"@com_google_absl//absl/container:flat_hash_map",
Expand Down
14 changes: 14 additions & 0 deletions xla/pjrt/cpu/cpu_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ limitations under the License.
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/base/casts.h"
#include "absl/base/dynamic_annotations.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/inlined_vector.h"
Expand All @@ -52,6 +53,7 @@ limitations under the License.
#include "xla/backends/cpu/runtime/thunk_executor.h"
#include "xla/client/executable_build_options.h"
#include "xla/client/xla_computation.h"
#include "xla/cpu_function_runtime.h"
#include "xla/debug_options_flags.h"
#include "xla/executable_run_options.h"
#include "xla/hlo/ir/hlo_computation.h"
Expand Down Expand Up @@ -867,6 +869,11 @@ absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>> TfrtCpuClient::Compile(
return Compile(xla_computation, options);
}

static bool is_aligned_data(void* ptr) {
return (absl::bit_cast<std::uintptr_t>(ptr) &
(cpu_function_runtime::MinAlign() - 1)) == 0;
}

absl::StatusOr<std::unique_ptr<PjRtBuffer>>
TfrtCpuClient::CreateViewOfDeviceBuffer(
void* device_ptr, const Shape& shape, PjRtDevice* device,
Expand All @@ -877,6 +884,13 @@ TfrtCpuClient::CreateViewOfDeviceBuffer(
"TfrtCpuClient::CreateViewOfDeviceBuffer does not support `stream` "
"argument.");
}
if (!is_aligned_data(device_ptr)) {
VLOG(1) << "Can't create a view of buffer with unaligned data, ptr: "
<< device_ptr << " is not aligned to "
<< cpu_function_runtime::MinAlign() << " bytes.";
return InvalidArgument(
"Can't create a view of buffer with unaligned data.");
}
absl::InlinedVector<tsl::AsyncValueRef<MaybeOwningCpuMemory>, 4> buffers;
size_t byte_size = ShapeUtil::ByteSizeOf(shape);
auto non_owning_buffer =
Expand Down
29 changes: 29 additions & 0 deletions xla/pjrt/pjrt_client_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ limitations under the License.
#include "xla/client/xla_builder.h"
#include "xla/client/xla_computation.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/pjrt_compiler.h"
#include "xla/service/hlo_parser.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
Expand Down Expand Up @@ -486,6 +487,34 @@ TEST(PjRtClientTest, CopyToDeviceAsyncExternalCpuOnly) {
}
}

TEST(PjRtClientTest, CreateViewOfUnalignedBufferReturnsErrorCpuOnly) {
TF_ASSERT_OK_AND_ASSIGN(auto client, GetClient());
ASSERT_GT(client->addressable_devices().size(), 1);

// Skip non-CPU platforms.
if (client->platform_id() != CpuId()) return;

std::vector<int32_t> data(5, 0);
auto* data_ptr = data.data();

// Pointer to the second element is always unaligned, because it's shifted by
// 4 bytes (size of int32_t) from the original pointer.
auto* unaligned_ptr = data_ptr + 1;

// Shape with a size smaller than the original data vector, because the
// 'unaligned_ptr' points to the second element.
Shape shape = ShapeUtil::MakeShape(S32, {4});

// Attempt to create a view of the unaligned buffer. Expect an error.
auto result = client->CreateViewOfDeviceBuffer(
unaligned_ptr, shape, client->addressable_devices()[0],
/*on_delete_callback=*/std::function<void()>());

ASSERT_FALSE(result.ok());
EXPECT_THAT(result.status().message(),
::testing::HasSubstr("unaligned data"));
}

absl::StatusOr<std::unique_ptr<PjRtBuffer>> MakeFloatBuffer(
PjRtClient* client, const std::vector<float>& data,
absl::Span<const int64_t> dimensions) {
Expand Down

0 comments on commit 18a8af1

Please sign in to comment.