diff --git a/xla/pjrt/BUILD b/xla/pjrt/BUILD index f443bdf46d1b9..dd9292c6e676f 100644 --- a/xla/pjrt/BUILD +++ b/xla/pjrt/BUILD @@ -242,6 +242,7 @@ cc_library( hdrs = ["pjrt_client_test.h"], deps = [ ":pjrt_client", + ":pjrt_compiler", "//xla:shape_util", "//xla:test", "//xla:xla_data_proto_cc", diff --git a/xla/pjrt/cpu/BUILD b/xla/pjrt/cpu/BUILD index 7c928106ee92a..9da4e491dab05 100644 --- a/xla/pjrt/cpu/BUILD +++ b/xla/pjrt/cpu/BUILD @@ -140,6 +140,7 @@ cc_library( ":cpu_topology", ":tracked_tfrt_cpu_device_buffer", "//xla:array", + "//xla:cpu_function_runtime", "//xla:debug_options_flags", "//xla:executable_run_options", "//xla:literal", @@ -192,6 +193,7 @@ cc_library( "//xla/tsl/concurrency:ref_count", "//xla/tsl/lib/strings:proto_serialization", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:dynamic_annotations", "@com_google_absl//absl/container:flat_hash_map", diff --git a/xla/pjrt/cpu/cpu_client.cc b/xla/pjrt/cpu/cpu_client.cc index def5ea3aaab7b..4e96e1d2f39c7 100644 --- a/xla/pjrt/cpu/cpu_client.cc +++ b/xla/pjrt/cpu/cpu_client.cc @@ -31,6 +31,7 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/base/casts.h" #include "absl/base/dynamic_annotations.h" #include "absl/container/flat_hash_map.h" #include "absl/container/inlined_vector.h" @@ -52,6 +53,7 @@ limitations under the License. 
#include "xla/backends/cpu/runtime/thunk_executor.h"
 #include "xla/client/executable_build_options.h"
 #include "xla/client/xla_computation.h"
+#include "xla/cpu_function_runtime.h"
 #include "xla/debug_options_flags.h"
 #include "xla/executable_run_options.h"
 #include "xla/hlo/ir/hlo_computation.h"
@@ -867,6 +869,13 @@ absl::StatusOr> TfrtCpuClient::Compile(
   return Compile(xla_computation, options);
 }
 
+// Returns true if `ptr` is a multiple of cpu_function_runtime::MinAlign().
+// CreateViewOfDeviceBuffer rejects pointers that fail this check.
+static bool is_aligned_data(const void* ptr) {
+  const auto address = absl::bit_cast<std::uintptr_t>(ptr);
+  return address % cpu_function_runtime::MinAlign() == 0;
+}
+
 absl::StatusOr>
 TfrtCpuClient::CreateViewOfDeviceBuffer(
     void* device_ptr, const Shape& shape, PjRtDevice* device,
@@ -877,6 +886,13 @@ TfrtCpuClient::CreateViewOfDeviceBuffer(
         "TfrtCpuClient::CreateViewOfDeviceBuffer does not support `stream` "
         "argument.");
   }
+  if (!is_aligned_data(device_ptr)) {
+    VLOG(1) << "Can't create a view of buffer with unaligned data, ptr: "
+            << device_ptr << " is not aligned to "
+            << cpu_function_runtime::MinAlign() << " bytes.";
+    return InvalidArgument(
+        "Can't create a view of buffer with unaligned data.");
+  }
   absl::InlinedVector, 4> buffers;
   size_t byte_size = ShapeUtil::ByteSizeOf(shape);
   auto non_owning_buffer =
diff --git a/xla/pjrt/pjrt_client_test.cc b/xla/pjrt/pjrt_client_test.cc
index cdaadf57295ca..abe38a1ce0259 100644
--- a/xla/pjrt/pjrt_client_test.cc
+++ b/xla/pjrt/pjrt_client_test.cc
@@ -28,6 +28,7 @@ limitations under the License.
 #include "xla/client/xla_builder.h"
 #include "xla/client/xla_computation.h"
 #include "xla/pjrt/pjrt_client.h"
+#include "xla/pjrt/pjrt_compiler.h"
 #include "xla/service/hlo_parser.h"
 #include "xla/shape.h"
 #include "xla/shape_util.h"
@@ -486,6 +487,36 @@ TEST(PjRtClientTest, CopyToDeviceAsyncExternalCpuOnly) {
   }
 }
 
+TEST(PjRtClientTest, CreateViewOfUnalignedBufferReturnsErrorCpuOnly) {
+  TF_ASSERT_OK_AND_ASSIGN(auto client, GetClient());
+
+  // Skip non-CPU platforms before making any assertions about devices.
+  if (client->platform_id() != CpuId()) return;
+
+  // Only the first addressable device is used below.
+  ASSERT_FALSE(client->addressable_devices().empty());
+
+  std::vector<int32_t> data(5, 0);
+  auto* data_ptr = data.data();
+
+  // Pointer to the second element is shifted by 4 bytes (size of int32_t)
+  // from the start of the allocation, so it is expected to be unaligned.
+  auto* unaligned_ptr = data_ptr + 1;
+
+  // Shape with a size smaller than the original data vector, because the
+  // 'unaligned_ptr' points to the second element.
+  Shape shape = ShapeUtil::MakeShape(S32, {4});
+
+  // Attempt to create a view of the unaligned buffer. Expect an error.
+  auto result = client->CreateViewOfDeviceBuffer(
+      unaligned_ptr, shape, client->addressable_devices()[0],
+      /*on_delete_callback=*/std::function<void()>());
+
+  ASSERT_FALSE(result.ok());
+  EXPECT_THAT(result.status().message(),
+              ::testing::HasSubstr("unaligned data"));
+}
+
 absl::StatusOr> MakeFloatBuffer(
     PjRtClient* client, const std::vector& data,
     absl::Span dimensions) {