From c5346ef5527edcddf194e6d33bcde082f544e8b1 Mon Sep 17 00:00:00 2001 From: Jane Liu Date: Thu, 12 Sep 2024 17:34:27 -0700 Subject: [PATCH] PR #16921: [PJRT:GPU] Treat GPU collective memory space as device memory space Imported from GitHub PR https://github.com/openxla/xla/pull/16921 This is a regression fix when using --xla_gpu_enable_nccl_user_buffers=true. Return device memory space when collective memory space is used as an output on GPU. Copybara import of the project: -- 8113e6fbe23d5902ecdd406793555c602c1b7f81 by Jane Liu : Treat collective memory space as device memory space when using as an output -- b5e43d6455adc49f5ac99a9a9e95cf495eb46170 by Jane Liu : fix the test Merging this change closes #16921 COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/16921 from zhenying-liu:nccl-buffer-output b5e43d6455adc49f5ac99a9a9e95cf495eb46170 PiperOrigin-RevId: 674073363 --- xla/pjrt/gpu/se_gpu_pjrt_client_test.cc | 62 +++++++++++++++++++++++++ xla/pjrt/pjrt_stream_executor_client.cc | 2 + 2 files changed, 64 insertions(+) diff --git a/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc b/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc index c4d82ce9d3c4b..ac6f85bd4b950 100644 --- a/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc +++ b/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc @@ -1131,6 +1131,24 @@ constexpr char const* kD2HProgramTupleOutput = R"( } )"; +constexpr char const* kCollectiveMemorySpaceOutput = R"( + + HloModule jit__psum, entry_computation_layout={(s32[1,4]{1,0})->s32[4]{0}} + + region_0.3 { + Arg_0.0 = s32[] parameter(0) + Arg_1.0 = s32[] parameter(1) + ROOT add.0 = s32[] add(Arg_0.0, Arg_1.0) + } + + ENTRY main.10_spmd { + param = s32[1,4]{1,0} parameter(0) + reshape = s32[4]{0} reshape(param) + ROOT all-reduce = s32[4]{0} all-reduce(reshape), channel_id=1, to_apply=region_0.3 + } + +)"; + } // namespace TEST(StreamExecutorGpuClientTest, ExecutePinnedHostOutputTest) { @@ -1197,6 +1215,50 @@ TEST(StreamExecutorGpuClientTest, ExecutablePinnedHostOutputMemoryKindTest) { 
EXPECT_EQ(memory_kinds[0][0], "pinned_host");
 }

+// With NCCL user buffers enabled, an output laid out in the collective memory
+// space must still be reported, and handed back, as "device" memory.
+TEST(StreamExecutorGpuClientTest,
+     ExecutableCollectiveMemoryOutputMemoryKindTest) {
+  TF_ASSERT_OK_AND_ASSIGN(auto client,
+                          GetStreamExecutorGpuClient(GpuClientOptions()));
+  xla::CompileOptions options;
+  options.executable_build_options.mutable_debug_options()
+      ->set_xla_gpu_enable_nccl_user_buffers(true);
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      auto executable,
+      CompileExecutable(kCollectiveMemorySpaceOutput, *client, options));
+  std::vector<int32_t> data{1, 2, 3, 4};  // Matches the s32[1,4] parameter.
+  // Build the input shape with the default memory space set explicitly.
+  Shape shape = ShapeUtil::MakeShapeWithDenseLayout(S32, {1, 4},
+                                                    /*major_to_minor=*/{1, 0});
+  shape.mutable_layout()->set_memory_space(Layout::kDefaultMemorySpace);
+
+  auto device = client->addressable_devices()[0];
+  TF_EXPECT_OK(device->default_memory_space());
+  TF_ASSERT_OK_AND_ASSIGN(
+      auto input, client->BufferFromHostBuffer(
+                      data.data(), shape.element_type(), shape.dimensions(),
+                      /*byte_strides=*/std::nullopt,
+                      PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall,
+                      /*on_done_with_host_buffer=*/nullptr, device));
+  EXPECT_EQ(input->memory_space()->kind(), "device");
+
+  TF_ASSERT_OK_AND_ASSIGN(auto memory_kinds,
+                          executable->GetOutputMemoryKinds());
+  ASSERT_EQ(memory_kinds.size(), 1);     // Fatal: indexed right below.
+  ASSERT_EQ(memory_kinds[0].size(), 1);  // Fatal: indexed right below.
+  EXPECT_EQ(memory_kinds[0][0], "device");
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      auto result, executable->Execute({{input.get()}}, ExecuteOptions()));
+  std::vector<std::unique_ptr<xla::PjRtBuffer>>& result_buffers = result[0];
+  EXPECT_EQ(result_buffers[0]->memory_space()->kind(), "device");
+  Shape result_shape = result_buffers[0]->on_device_shape();
+  auto memory_space = result_shape.layout().memory_space();
+  EXPECT_EQ(memory_space, 1);  // presumably kGenericFastMemorySpace; confirm
+}
+
 TEST(StreamExecutorGpuClientTest,
      ExecutablePinnedHostTupleOutputMemoryKindTest) {
   TF_ASSERT_OK_AND_ASSIGN(auto client,
diff --git
a/xla/pjrt/pjrt_stream_executor_client.cc b/xla/pjrt/pjrt_stream_executor_client.cc index d79b461f1ff86..c9e1c61cd56a3 100644 --- a/xla/pjrt/pjrt_stream_executor_client.cc +++ b/xla/pjrt/pjrt_stream_executor_client.cc @@ -2286,6 +2286,7 @@ absl::StatusOr> OutputBufferHelper( device->default_memory_space().value_or(nullptr); if (shape.has_layout()) { switch (shape.layout().memory_space()) { + case Layout::kGenericFastMemorySpace: case Layout::kDefaultMemorySpace: // Nothing to do, we have already set the default memory space. break; @@ -3322,6 +3323,7 @@ absl::StatusOr MemoryKindFromSimpleShape( switch (shape.layout().memory_space()) { case Layout::kHostMemorySpace: return PinnedHostMemorySpace::kKind; + case Layout::kGenericFastMemorySpace: case Layout::kDefaultMemorySpace: return default_memory_kind; default: