fix RunWithExternalStream context switch bug (PaddlePaddle#57629)
* fix RunWithExternalStream context switch bug
ming1753 authored Sep 25, 2023
1 parent 602d73a commit 2fb4236
Showing 4 changed files with 72 additions and 0 deletions.
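For context, a minimal sketch of the call pattern this commit fixes, assuming the public paddle_infer C++ API as exercised by the test added below (the model path and the lack of CUDA error handling are placeholders):

#include <cuda_runtime.h>
#include "paddle_inference_api.h"  // assumed header for the paddle_infer API

void RunOnTwoStreams() {
  // Bind the predictor to an initial CUDA stream.
  paddle_infer::Config config;
  config.SetModel("/path/to/model");  // placeholder path
  config.EnableUseGpu(100, 0);
  cudaStream_t stream;
  cudaStreamCreate(&stream);
  config.SetExecStream(stream);
  auto predictor = paddle_infer::CreatePredictor(config);
  predictor->Run();  // executes on `stream`

  // Running on a different stream makes ExpRunWithExternalStream rebuild the
  // predictor's private device context; before this fix, the context cached
  // in paddle::experimental::DeviceContextPool still pointed at the old one.
  cudaStream_t external_stream;
  cudaStreamCreate(&external_stream);
  paddle_infer::experimental::InternalUtils::RunWithExternalStream(
      predictor.get(), external_stream);
}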
3 changes: 3 additions & 0 deletions paddle/fluid/inference/api/analysis_predictor.cc
@@ -56,6 +56,7 @@
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/phi/api/include/context_pool.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/common/backend.h"
#include "paddle/phi/common/data_type.h"
@@ -2219,6 +2220,8 @@ bool AnalysisPredictor::ExpRunWithExternalStream(const gpuStream_t stream) {
          UpdatePrivateDeviceContext(gpu_context, gpu_resource, place_);
          return std::unique_ptr<phi::DeviceContext>(gpu_context);
        }));
    auto &pool = paddle::experimental::DeviceContextPool::Instance();
    pool.SyncDeviceContext(place_);
  }

  return ZeroCopyRun();
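Note: when the incoming stream differs from the predictor's current one, the block above rebuilds the private phi device context via UpdatePrivateDeviceContext. The two added lines then call the new SyncDeviceContext (implemented in context_pool.cc below) so that paddle::experimental::DeviceContextPool refreshes its cached pointer, which would otherwise still reference the context created for the previous stream.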
2 changes: 2 additions & 0 deletions paddle/phi/api/include/context_pool.h
@@ -71,6 +71,8 @@ class PADDLE_API DeviceContextPool {

  phi::DeviceContext* GetMutable(const Place& place);

  void SyncDeviceContext(const Place& place);

  template <AllocationType T>
  const typename DefaultDeviceContextType<T>::TYPE* Get(const Place& place) {
    return reinterpret_cast<const typename DefaultDeviceContextType<T>::TYPE*>(
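A sketch of calling the new method directly, assuming a CUDA build (phi::GPUPlace is the phi place type for a given device id):

// Refresh the cached DeviceContext for GPU 0 after the underlying
// context has been replaced, e.g. by a stream switch.
auto& pool = paddle::experimental::DeviceContextPool::Instance();
pool.SyncDeviceContext(phi::GPUPlace(0));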
12 changes: 12 additions & 0 deletions paddle/phi/api/lib/context_pool.cc
@@ -26,6 +26,18 @@ limitations under the License. */
namespace paddle {
namespace experimental {

void DeviceContextPool::SyncDeviceContext(const Place& place) {
  if (!phi::DeviceContextPool::IsInitialized()) {
    phi::memory_utils::InitDevices();
  }
  // Fetch the DeviceContext for this place and refresh the cached entry.
  auto* dev_ctx = phi::DeviceContextPool::Instance().Get(place);
  {
    std::lock_guard<std::mutex> lock(mutex_);
    context_map_[place] = dev_ctx;
  }
}

DeviceContextPool& DeviceContextPool::Instance() {
  static DeviceContextPool g_device_context_pool;
  return g_device_context_pool;
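Design note: as the context_map_ cache suggests, this pool's getters lazily cache one context pointer per place and keep returning the cached value, so a context replaced elsewhere would never be re-read. SyncDeviceContext instead re-queries phi::DeviceContextPool unconditionally and overwrites the map entry under mutex_, which is what makes the stream switch in analysis_predictor.cc visible to phi API users.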
55 changes: 55 additions & 0 deletions test/cpp/inference/api/analysis_predictor_tester.cc
@@ -663,6 +663,61 @@ TEST(Predictor, Streams) {
    CHECK_NE(stream, stream2);
  }
}

TEST(Tensor, RunWithExternalStream) {
  Config config;
  config.SetModel(FLAGS_dirname);
  config.EnableUseGpu(100, 0);
  cudaStream_t stream;
  cudaStreamCreate(&stream);
  config.SetExecStream(stream);
  auto predictor = CreatePredictor(config);

  auto w0 = predictor->GetInputHandle("firstw");
  auto w1 = predictor->GetInputHandle("secondw");
  auto w2 = predictor->GetInputHandle("thirdw");
  auto w3 = predictor->GetInputHandle("forthw");

  std::vector<std::vector<int64_t>> input_data(4, {0, 1, 2, 3});
  std::vector<int64_t*> input_gpu(4, nullptr);

  for (size_t i = 0; i < 4; ++i) {
    cudaMalloc(reinterpret_cast<void**>(&input_gpu[i]), 4 * sizeof(int64_t));
    cudaMemcpy(input_gpu[i],
               input_data[i].data(),
               4 * sizeof(int64_t),
               cudaMemcpyHostToDevice);
  }

  w0->ShareExternalData<int64_t>(input_gpu[0], {4, 1}, PlaceType::kGPU);
  w1->ShareExternalData<int64_t>(input_gpu[1], {4, 1}, PlaceType::kGPU);
  w2->ShareExternalData<int64_t>(input_gpu[2], {4, 1}, PlaceType::kGPU);
  w3->ShareExternalData<int64_t>(input_gpu[3], {4, 1}, PlaceType::kGPU);

  auto out = predictor->GetOutputHandle("fc_1.tmp_2");
  auto out_shape = out->shape();
  float* out_data = nullptr;
  auto out_size =
      std::accumulate(
          out_shape.begin(), out_shape.end(), 1, std::multiplies<int>()) *
      sizeof(float);
  cudaMalloc(reinterpret_cast<void**>(&out_data), out_size);
  out->ShareExternalData<float>(out_data, out_shape, PlaceType::kGPU);

  cudaStream_t external_stream;
  cudaStreamCreate(&external_stream);
  Config tmp_config(config);
  tmp_config.SetExecStream(external_stream);
  predictor->Run();
  paddle_infer::experimental::InternalUtils::RunWithExternalStream(
      predictor.get(), external_stream);

  PlaceType place;
  int size = 0;
  out->data<float>(&place, &size);
  LOG(INFO) << "output size: " << size / sizeof(float);
  predictor->TryShrinkMemory();
}
#endif

TEST(AnalysisPredictor, OutputTensorHookFunc) {
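After RunWithExternalStream returns, the output lives in the caller-provided GPU buffer on external_stream. A sketch of copying it back to the host for inspection, reusing the test's out_data, out_size, and external_stream (error checks omitted):

std::vector<float> host_out(out_size / sizeof(float));
cudaMemcpyAsync(host_out.data(), out_data, out_size,
                cudaMemcpyDeviceToHost, external_stream);
cudaStreamSynchronize(external_stream);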
