From 12cd0e42482990dcd006fa7fc6841df368d5414e Mon Sep 17 00:00:00 2001 From: Li Zhang Date: Thu, 26 Oct 2023 11:09:21 +0000 Subject: [PATCH] sycn before using async data --- src/turbomind/layers/sampling_layers/TopKSamplingLayer.cu | 2 ++ src/turbomind/layers/sampling_layers/TopPSamplingLayer.cu | 1 + 2 files changed, 3 insertions(+) diff --git a/src/turbomind/layers/sampling_layers/TopKSamplingLayer.cu b/src/turbomind/layers/sampling_layers/TopKSamplingLayer.cu index 63dae19888..de53cfd904 100644 --- a/src/turbomind/layers/sampling_layers/TopKSamplingLayer.cu +++ b/src/turbomind/layers/sampling_layers/TopKSamplingLayer.cu @@ -22,6 +22,7 @@ #include "src/turbomind/layers/sampling_layers/TopKSamplingLayer.h" #include "src/turbomind/macro.h" #include "src/turbomind/models/llama/llama_utils.h" +#include "src/turbomind/utils/cuda_utils.h" #include "src/turbomind/utils/logger.h" #include "src/turbomind/utils/memory_utils.h" @@ -202,6 +203,7 @@ void TopKSamplingLayer::setup(const size_t batch_size, const size_t beam_widt cudaAutoCpy(skip_decode_, skip_decode_buf_, batch_size, stream_); uint* runtime_top_ks = new uint[batch_size]; cudaAutoCpy(runtime_top_ks, runtime_top_k_buf_, batch_size, stream_); + check_cuda_error(cudaStreamSynchronize(stream_)); runtime_max_top_k_ = static_cast(*std::max_element(runtime_top_ks, runtime_top_ks + batch_size)); delete[] runtime_top_ks; } diff --git a/src/turbomind/layers/sampling_layers/TopPSamplingLayer.cu b/src/turbomind/layers/sampling_layers/TopPSamplingLayer.cu index 8e7e97314f..ef05708526 100644 --- a/src/turbomind/layers/sampling_layers/TopPSamplingLayer.cu +++ b/src/turbomind/layers/sampling_layers/TopPSamplingLayer.cu @@ -249,6 +249,7 @@ void TopPSamplingLayer::setup(const size_t batch_size, const size_t beam_widt cudaAutoCpy(skip_decode_, skip_decode_buf_, batch_size, stream_); float* runtime_top_ps = new float[batch_size]; cudaAutoCpy(runtime_top_ps, runtime_top_p_buf_, batch_size, stream_); + check_cuda_error(cudaStreamSynchronize(stream_)); runtime_max_top_p_ = *std::max_element(runtime_top_ps, runtime_top_ps + batch_size); delete[] runtime_top_ps; }