feat: support temperature 0

zhyncs committed Mar 10, 2024
1 parent 8cf5a73 commit 61cb292
Showing 2 changed files with 50 additions and 6 deletions.
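
In short: "temperature 0" means greedy decoding. As the sampling temperature goes to 0, softmax(logits / temperature) puts all probability mass on the highest-scoring token, so sampling reduces to an argmax over the vocabulary. The commit therefore keeps the Medusa head logits in an internal device buffer and returns only the top-1 token id per head and sequence instead of the raw logits. A minimal CPU sketch of the k = 1 selection involved (illustrative only, not code from this commit):

    #include <algorithm>
    #include <cstddef>
    #include <vector>

    // Greedy (temperature-0) pick: for each row of [rows, vocab] logits,
    // return the index of the maximum entry, i.e. the k = 1 case of top-k.
    std::vector<int> greedy_ids(const std::vector<float>& logits, std::size_t rows, std::size_t vocab)
    {
        std::vector<int> ids(rows);
        for (std::size_t r = 0; r < rows; ++r) {
            const float* row = logits.data() + r * vocab;
            ids[r] = static_cast<int>(std::max_element(row, row + vocab) - row);
        }
        return ids;
    }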
51 changes: 46 additions & 5 deletions src/turbomind/models/medusa_plugin/medusa_head.cc
@@ -3,6 +3,7 @@
 // Zhiwei Bao <[email protected]>

 #include "src/turbomind/models/medusa_plugin/medusa_head.h"
+#include "src/turbomind/kernels/sampling_topk_kernels.h"
 #include "src/turbomind/models/llama/LlamaNcclGuard.h"
 #include "src/turbomind/utils/Tensor.h"
 #include "src/turbomind/utils/cublasMMWrapper.h"
@@ -36,14 +37,18 @@ void MedusaHead<T>::forward(TensorMap* output_tensors,
                             const TensorMap*       input_tensors,
                             const MedusaWeight<T>& medusa_weight)
 {
-    const size_t batch_size = input_tensors->at("medusa_head_input").shape[0];
-    const T* hidden_states = input_tensors->at("medusa_head_input").getPtr<T>();
-    T* medusa_head_logits_ptr = output_tensors->at("medusa_head_output").getPtr<T>();
+    const size_t batch_size = input_tensors->at("medusa_head_input").shape[0];
+    const T* hidden_states = input_tensors->at("medusa_head_input").getPtr<T>();
+    int* h_topk_output_ids = output_tensors->at("medusa_head_output").getPtr<int>();
+
+    allocate_buffer(batch_size);
     // TODO parallelize this loop
     for (int i = 0; i < medusa_num_heads_; i++) {
-        T* medusa_head_logits = medusa_head_logits_ptr + i * batch_size * vocab_size_;
+        T* medusa_head_logits = medusa_head_logits_buf_ + i * batch_size * vocab_size_;
         forward(medusa_head_logits, hidden_states, batch_size, medusa_weight, i);
     }
+
+    top_k(h_topk_output_ids, medusa_head_logits_buf_, batch_size * medusa_num_heads_);
 }

 template<typename T>
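
Note the interface change in this hunk: "medusa_head_output" now carries int token ids rather than T logits. Each head writes its logits into the staging buffer at offset i * batch_size * vocab_size_, and the batched top-k treats that buffer as batch_size * medusa_num_heads_ rows, so the returned ids should come back head-major. A sketch of the indexing this layout implies (the helper is ours, not from the commit):

    // Assumed head-major layout: all of head 0's batch entries come first.
    inline int medusa_topk_id(const int* h_topk_output_ids, int head, int batch_idx, int batch_size)
    {
        return h_topk_output_ids[head * batch_size + batch_idx];
    }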
@@ -53,7 +58,6 @@ void MedusaHead<T>::forward(T* medusa_head_output,
                             const MedusaWeight<T>& medusa_weight,
                             int head_id)
 {
-    allocate_buffer(batch_size);
     // TODO support multi medusa_num_layers
     resblock_->forward(resblock_buf_, medusa_head_input, batch_size, medusa_weight.get_resblocks_weights()[head_id][0]);
     linear_->forward(medusa_head_output, resblock_buf_, batch_size, medusa_weight.get_heads_weights()[head_id]);
@@ -72,6 +76,8 @@ void MedusaHead<T>::allocate_buffer(size_t batch_size)
 {
     resblock_buf_ =
         (T*)allocator_->reMalloc(resblock_buf_, sizeof(T) * batch_size * in_size_ / tensor_para_.world_size_, false);
+    medusa_head_logits_buf_ = (T*)allocator_->reMalloc(
+        medusa_head_logits_buf_, medusa_num_heads_ * sizeof(T) * batch_size * vocab_size_, false);
     is_allocated_buffer_ = true;
 }
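
The new staging buffer holds medusa_num_heads_ * batch_size * vocab_size_ elements of T, which can be sizable. A quick back-of-envelope check, with made-up example sizes (the real values depend on the model):

    #include <cstddef>
    #include <cstdio>

    int main()
    {
        // Hypothetical sizes, for illustration only.
        const std::size_t num_heads = 4, batch = 8, vocab = 32000;
        const std::size_t elem = 2;  // sizeof(half)
        const std::size_t bytes = num_heads * batch * vocab * elem;
        std::printf("%zu bytes (~%.1f MiB)\n", bytes, bytes / (1024.0 * 1024.0));
        return 0;
    }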

@@ -80,10 +86,45 @@ void MedusaHead<T>::free_buffer()
 {
     if (is_free_buffer_after_forward_ && is_allocated_buffer_) {
         allocator_->free((void**)&resblock_buf_);
+        allocator_->free((void**)&workspace_buf_);
+        allocator_->free((void**)&medusa_head_logits_buf_);
         is_allocated_buffer_ = false;
     }
 }

+template<typename T>
+void MedusaHead<T>::top_k(int* h_topk_output_ids, const T* d_input_logits, const size_t batch_size, const int k)
+{
+    size_t workspace_size_now = 0;
+    invokeBatchTopKOnly(nullptr,
+                        workspace_size_now,
+                        d_input_logits,
+                        nullptr,
+                        k,
+                        nullptr,
+                        vocab_size_,
+                        nullptr,
+                        stream_,
+                        batch_size,
+                        nullptr);
+    workspace_buf_ = (void*)allocator_->reMalloc(workspace_buf_, workspace_size_now, false);
+    invokeBatchTopKOnly(workspace_buf_,
+                        workspace_size_now,
+                        d_input_logits,
+                        nullptr,
+                        k,
+                        nullptr,
+                        vocab_size_,
+                        nullptr,
+                        stream_,
+                        batch_size,
+                        nullptr);
+    int offset = (int)(ceil(batch_size * vocab_size_ / 4.)) * 4;
+    int output_size = (int)(ceil(batch_size * k / 4.)) * 4;
+    int* topk_output_ids = (int*)(((T*)workspace_buf_) + offset);
+    cudaMemcpy(h_topk_output_ids, topk_output_ids, sizeof(int) * output_size, cudaMemcpyDeviceToHost);
+}
+
 template class MedusaHead<float>;
 template class MedusaHead<half>;
 #ifdef ENABLE_BF16
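
top_k() follows the usual two-phase CUDA workspace idiom: the first invokeBatchTopKOnly call passes a null workspace and only reports the required byte count in workspace_size_now; the workspace is then grown via reMalloc and the identical second call performs the actual selection, after which a blocking cudaMemcpy brings the ids to the host. The closing pointer arithmetic assumes the kernel packs the workspace as the staged logits followed by the selected ids, each region padded to a multiple of four elements. A self-contained sketch of just that offset arithmetic, under the same padding assumption:

    #include <cstddef>
    #include <cstdio>

    // Mirror of the diff's (int)(ceil(n / 4.)) * 4 padding.
    static std::size_t round_up4(std::size_t n) { return (n + 3) / 4 * 4; }

    int main()
    {
        const std::size_t rows  = 4 * 8;   // batch_size * medusa_num_heads_ (example values)
        const std::size_t vocab = 32000;   // vocab_size_
        const std::size_t k     = 1;       // top-1 suffices for temperature 0

        const std::size_t id_offset   = round_up4(rows * vocab);  // where ids start, in T elements
        const std::size_t output_size = round_up4(rows * k);      // how many ints get copied out

        std::printf("ids at T-element offset %zu; copy %zu ints to host\n", id_offset, output_size);
        return 0;
    }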
5 changes: 4 additions & 1 deletion src/turbomind/models/medusa_plugin/medusa_head.h
@@ -38,6 +38,7 @@ class MedusaHead {
 private:
     void allocate_buffer(size_t batch_size);
     void free_buffer();
+    void top_k(int* h_topk_output_ids, const T* d_input_logits, const size_t batch_size, const int k = 1);

 private:
     size_t in_size_;
@@ -47,7 +48,9 @@
     std::unique_ptr<ResBlock<T>> resblock_;
     std::unique_ptr<LlamaLinear<T>> linear_;

-    T* resblock_buf_;
+    T* resblock_buf_;
+    void* workspace_buf_;
+    T* medusa_head_logits_buf_;

     cudaStream_t stream_;
     cublasMMWrapper* cublas_wrapper_;
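
One thing the diff does not show is where workspace_buf_ and medusa_head_logits_buf_ are first set to nullptr; calling reMalloc or free on an indeterminate pointer would be undefined behavior, so presumably the constructor zero-initializes them. A defensive variant using in-class initializers (our suggestion, not part of the commit):

    template<typename T>
    class MedusaHeadBuffers {  // illustrative skeleton, not the real MedusaHead
        T*    resblock_buf_           = nullptr;
        void* workspace_buf_          = nullptr;
        T*    medusa_head_logits_buf_ = nullptr;
    };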