Commit 07257b7

handleOptArg with stream support
lzhangzz committed Oct 26, 2023
1 parent a295fbf

Showing 3 changed files with 14 additions and 9 deletions.
13 changes: 10 additions & 3 deletions src/turbomind/kernels/gpt_kernels.h
@@ -21,6 +21,7 @@
 #include <unordered_map>
 
 #include "src/turbomind/utils/Tensor.h"
+#include "src/turbomind/utils/cuda_utils.h"
 #include "src/turbomind/utils/memory_utils.h"
 
 namespace turbomind {
@@ -131,14 +132,20 @@ void invokeFindContextDups(int* shared_contexts,
                            cudaStream_t stream = 0);
 
 template<typename T>
-void handleOptArg(TensorMap* input_tensors, const std::string& arg_name, T* d_ptr, T default_value, size_t size)
+void handleOptArg(TensorMap*         input_tensors,
+                  const std::string& arg_name,
+                  T*                 d_ptr,
+                  T                  default_value,
+                  size_t             size,
+                  cudaStream_t       stream = {})
 {
     if (input_tensors->isExist(arg_name)) {
         FT_CHECK(input_tensors->at(arg_name).size() == size);
-        cudaH2Dcpy(d_ptr, input_tensors->at(arg_name).getPtr<const T>(), size);
+        check_cuda_error(cudaMemcpyAsync(
+            d_ptr, input_tensors->at(arg_name).getPtr<const T>(), sizeof(T) * size, cudaMemcpyDefault, stream));
     }
     else {
-        deviceFill(d_ptr, size, default_value);
+        deviceFill(d_ptr, size, default_value, stream);
     }
 }
 
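
The trailing stream parameter defaults to {} (the null stream), so existing call sites compile unchanged. Below is a minimal usage sketch of the updated signature; the TensorMap named inputs, the buffer names, and the sizes are illustrative assumptions, not code from this commit:

    // Hypothetical call site: fill a device buffer from an optional input
    // tensor on a caller-owned stream. `inputs`, batch_size, and the
    // default end id are made up for illustration.
    cudaStream_t stream;
    check_cuda_error(cudaStreamCreate(&stream));

    size_t batch_size = 8;
    int*   d_end_ids  = nullptr;
    check_cuda_error(cudaMalloc(&d_end_ids, sizeof(int) * batch_size));

    // Copies inputs["end_id"] into d_end_ids if the tensor is present,
    // otherwise fills the buffer with the default value; either way the
    // work is enqueued on `stream` and the host is not blocked.
    handleOptArg(&inputs, "end_id", d_end_ids, /*default_value=*/2, batch_size, stream);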
3 changes: 1 addition & 2 deletions src/turbomind/models/llama/LlamaBatch.cc
@@ -385,8 +385,7 @@ void LlamaBatch<T>::initializeSampling(int infer_request_count)
         }
     }
 
-    handleOptArg(&inputs_, "end_id", end_ids_buf_, llama_->end_id_, batch_size_);
-    cudaStreamSynchronize(0);
+    handleOptArg(&inputs_, "end_id", end_ids_buf_, llama_->end_id_, batch_size_, stream_);
 }
 
 template<typename T>
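
Dropping cudaStreamSynchronize(0) is safe because the copy or fill is now enqueued on stream_, and CUDA runs work submitted to a single stream in issue order. A self-contained sketch of that guarantee, with illustrative names rather than anything from LlamaBatch:

    #include <cuda_runtime.h>

    __global__ void consume(const int* end_ids, int n)
    {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) {
            // Sees the value written by the preceding async copy, because
            // both operations were issued to the same stream.
            (void)end_ids[i];
        }
    }

    void example(cudaStream_t stream, int* d_end_ids, const int* h_end_ids, int n)
    {
        // Copy, then kernel, on one stream: they execute in issue order,
        // so no host-side synchronization is needed in between.
        cudaMemcpyAsync(d_end_ids, h_end_ids, sizeof(int) * n, cudaMemcpyDefault, stream);
        consume<<<(n + 255) / 256, 256, 0, stream>>>(d_end_ids, n);
    }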
7 changes: 3 additions & 4 deletions src/turbomind/utils/memory_utils.cu
@@ -93,10 +93,9 @@ template void deviceFree(__nv_fp8_e4m3*& ptr);
 template<typename T>
 void deviceFill(T* devptr, size_t size, T value, cudaStream_t stream)
 {
-    T* arr = new T[size];
-    std::fill(arr, arr + size, value);
-    check_cuda_error(cudaMemcpyAsync(devptr, arr, sizeof(T) * size, cudaMemcpyHostToDevice, stream));
-    delete[] arr;
+    std::unique_ptr<T[]> arr(new T[size]);
+    std::fill(arr.get(), arr.get() + size, value);
+    check_cuda_error(cudaMemcpyAsync(devptr, arr.get(), sizeof(T) * size, cudaMemcpyDefault, stream));
 }
 
 template void deviceFill(float* devptr, size_t size, float value, cudaStream_t stream);
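
One subtlety: arr is pageable host memory and is freed as soon as deviceFill returns, which is safe because cudaMemcpyAsync with a pageable source returns only after the data has been staged for the transfer. An alternative that avoids the temporary host buffer altogether is a small fill kernel; the sketch below is an illustration, not something this commit adds:

    // Hypothetical device-side fill: write the value directly on the GPU
    // instead of staging a host array through cudaMemcpyAsync.
    template<typename T>
    __global__ void fillKernel(T* ptr, size_t size, T value)
    {
        size_t i = blockIdx.x * (size_t)blockDim.x + threadIdx.x;
        if (i < size) {
            ptr[i] = value;
        }
    }

    template<typename T>
    void deviceFillOnDevice(T* devptr, size_t size, T value, cudaStream_t stream)
    {
        const unsigned block = 256;
        const unsigned grid  = (unsigned)((size + block - 1) / block);
        fillKernel<<<grid, block, 0, stream>>>(devptr, size, value);
    }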
