From 2fb42362b1120d53a56ef79d6c1f3e24590e9b85 Mon Sep 17 00:00:00 2001
From: ming1753 <61511741+ming1753@users.noreply.github.com>
Date: Mon, 25 Sep 2023 11:40:21 +0800
Subject: [PATCH] fix RunWithExternalStream context switch bug (#57629)

* fix RunWithExternalStream context switch bug
---
 .../fluid/inference/api/analysis_predictor.cc |  3 +
 paddle/phi/api/include/context_pool.h         |  2 +
 paddle/phi/api/lib/context_pool.cc            | 12 ++++
 .../api/analysis_predictor_tester.cc          | 55 +++++++++++++++++++
 4 files changed, 72 insertions(+)

diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
index 70da22a3240e98..f30e2c560b57ff 100644
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -56,6 +56,7 @@
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/fluid/platform/place.h"
 #include "paddle/fluid/platform/profiler.h"
+#include "paddle/phi/api/include/context_pool.h"
 #include "paddle/phi/api/include/tensor.h"
 #include "paddle/phi/common/backend.h"
 #include "paddle/phi/common/data_type.h"
@@ -2219,6 +2220,8 @@ bool AnalysisPredictor::ExpRunWithExternalStream(const gpuStream_t stream) {
           UpdatePrivateDeviceContext(gpu_context, gpu_resource, place_);
           return std::unique_ptr<phi::DeviceContext>(gpu_context);
         }));
+    auto &pool = paddle::experimental::DeviceContextPool::Instance();
+    pool.SyncDeviceContext(place_);
   }
 
   return ZeroCopyRun();
diff --git a/paddle/phi/api/include/context_pool.h b/paddle/phi/api/include/context_pool.h
index 7afe17ba8419d3..6b6fe290d6d288 100644
--- a/paddle/phi/api/include/context_pool.h
+++ b/paddle/phi/api/include/context_pool.h
@@ -71,6 +71,8 @@ class PADDLE_API DeviceContextPool {
 
   phi::DeviceContext* GetMutable(const Place& place);
 
+  void SyncDeviceContext(const Place& place);
+
   template <typename T>
   const typename DefaultDeviceContextType<T>::TYPE* Get(const Place& place) {
     return reinterpret_cast<const typename DefaultDeviceContextType<T>::TYPE*>(
diff --git a/paddle/phi/api/lib/context_pool.cc b/paddle/phi/api/lib/context_pool.cc
index 292bd8a7e47aa5..8066147025117a 100644
--- a/paddle/phi/api/lib/context_pool.cc
+++ b/paddle/phi/api/lib/context_pool.cc
@@ -26,6 +26,18 @@ limitations under the License. */
 namespace paddle {
 namespace experimental {
 
+void DeviceContextPool::SyncDeviceContext(const Place& place) {
+  if (!phi::DeviceContextPool::IsInitialized()) {
+    phi::memory_utils::InitDevices();
+  }
+  // Only get and cache the specific DeviceContext when it is actually needed.
+  auto* dev_ctx = phi::DeviceContextPool::Instance().Get(place);
+  {
+    std::lock_guard<std::mutex> lock(mutex_);
+    context_map_[place] = dev_ctx;
+  }
+}
+
 DeviceContextPool& DeviceContextPool::Instance() {
   static DeviceContextPool g_device_context_pool;
   return g_device_context_pool;
diff --git a/test/cpp/inference/api/analysis_predictor_tester.cc b/test/cpp/inference/api/analysis_predictor_tester.cc
index 35c07c3a83790c..f32d509d62d8b3 100644
--- a/test/cpp/inference/api/analysis_predictor_tester.cc
+++ b/test/cpp/inference/api/analysis_predictor_tester.cc
@@ -663,6 +663,61 @@ TEST(Predictor, Streams) {
     CHECK_NE(stream, stream2);
   }
 }
+
+TEST(Tensor, RunWithExternalStream) {
+  Config config;
+  config.SetModel(FLAGS_dirname);
+  config.EnableUseGpu(100, 0);
+  cudaStream_t stream;
+  cudaStreamCreate(&stream);
+  config.SetExecStream(stream);
+  auto predictor = CreatePredictor(config);
+
+  auto w0 = predictor->GetInputHandle("firstw");
+  auto w1 = predictor->GetInputHandle("secondw");
+  auto w2 = predictor->GetInputHandle("thirdw");
+  auto w3 = predictor->GetInputHandle("forthw");
+
+  std::vector<std::vector<int64_t>> input_data(4, {0, 1, 2, 3});
+  std::vector<int64_t*> input_gpu(4, nullptr);
+
+  for (size_t i = 0; i < 4; ++i) {
+    cudaMalloc(reinterpret_cast<void**>(&input_gpu[i]), 4 * sizeof(int64_t));
+    cudaMemcpy(input_gpu[i],
+               input_data[i].data(),
+               4 * sizeof(int64_t),
+               cudaMemcpyHostToDevice);
+  }
+
+  w0->ShareExternalData(input_gpu[0], {4, 1}, PlaceType::kGPU);
+  w1->ShareExternalData(input_gpu[1], {4, 1}, PlaceType::kGPU);
+  w2->ShareExternalData(input_gpu[2], {4, 1}, PlaceType::kGPU);
+  w3->ShareExternalData(input_gpu[3], {4, 1}, PlaceType::kGPU);
+
+  auto out = predictor->GetOutputHandle("fc_1.tmp_2");
+  auto out_shape = out->shape();
+  float* out_data = nullptr;
+  auto out_size =
+      std::accumulate(
+          out_shape.begin(), out_shape.end(), 1, std::multiplies<int>()) *
+      sizeof(float);
+  cudaMalloc(reinterpret_cast<void**>(&out_data), out_size);
+  out->ShareExternalData(out_data, out_shape, PlaceType::kGPU);
+
+  cudaStream_t external_stream;
+  cudaStreamCreate(&external_stream);
+  Config tmp_config(config);
+  tmp_config.SetExecStream(external_stream);
+  predictor->Run();
+  paddle_infer::experimental::InternalUtils::RunWithExternalStream(
+      predictor.get(), external_stream);
+
+  PlaceType place;
+  int size = 0;
+  out->data<float>(&place, &size);
+  LOG(INFO) << "output size: " << size / sizeof(float);
+  predictor->TryShrinkMemory();
+}
 #endif
 
 TEST(AnalysisPredictor, OutputTensorHookFunc) {
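--
Note: below is a minimal caller-side sketch of the stream-switch scenario the
new test exercises, showing where the fixed code path is hit. It is
illustrative only: the model path is a placeholder, input/output tensor setup
is elided (see the test above for the full sequence), and it assumes a CUDA
build of the Paddle inference library.

#include <cuda_runtime.h>

#include "paddle_inference_api.h"

int main() {
  paddle_infer::Config config;
  config.SetModel("/path/to/model_dir");  // placeholder model directory
  config.EnableUseGpu(100, 0);

  // Bind the predictor to a user-owned stream at creation time.
  cudaStream_t stream;
  cudaStreamCreate(&stream);
  config.SetExecStream(stream);
  auto predictor = paddle_infer::CreatePredictor(config);

  // ... feed inputs via GetInputHandle()/ShareExternalData() ...
  predictor->Run();  // executes on `stream`

  // Switch execution to a different stream. Before this patch the
  // predictor rebuilt its private fluid-side DeviceContext, but the
  // pointer cached in paddle::experimental::DeviceContextPool went
  // stale; the new SyncDeviceContext(place_) call refreshes that cache.
  cudaStream_t external_stream;
  cudaStreamCreate(&external_stream);
  paddle_infer::experimental::InternalUtils::RunWithExternalStream(
      predictor.get(), external_stream);

  cudaStreamSynchronize(external_stream);
  return 0;
}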