From 34165713fd1ec82b2225440bd056018f4c275180 Mon Sep 17 00:00:00 2001
From: zhyncs
Date: Sun, 10 Mar 2024 21:23:00 +0800
Subject: [PATCH] feat: support temperature 0

---
 .../models/medusa_plugin/medusa_head.cc | 51 +++++++++++++++++--
 .../models/medusa_plugin/medusa_head.h  |  5 +-
 2 files changed, 50 insertions(+), 6 deletions(-)

diff --git a/src/turbomind/models/medusa_plugin/medusa_head.cc b/src/turbomind/models/medusa_plugin/medusa_head.cc
index d5c9652e5..9febb2264 100644
--- a/src/turbomind/models/medusa_plugin/medusa_head.cc
+++ b/src/turbomind/models/medusa_plugin/medusa_head.cc
@@ -3,6 +3,7 @@
 // Zhiwei Bao
 
 #include "src/turbomind/models/medusa_plugin/medusa_head.h"
+#include "src/turbomind/kernels/sampling_topk_kernels.h"
 #include "src/turbomind/models/llama/LlamaNcclGuard.h"
 #include "src/turbomind/utils/Tensor.h"
 #include "src/turbomind/utils/cublasMMWrapper.h"
@@ -36,14 +37,18 @@ void MedusaHead<T>::forward(TensorMap*             output_tensors,
                             const TensorMap*       input_tensors,
                             const MedusaWeight<T>& medusa_weight)
 {
-    const size_t batch_size             = input_tensors->at("medusa_head_input").shape[0];
-    const T*     hidden_states          = input_tensors->at("medusa_head_input").getPtr<T>();
-    T*           medusa_head_logits_ptr = output_tensors->at("medusa_head_output").getPtr<T>();
+    const size_t batch_size        = input_tensors->at("medusa_head_input").shape[0];
+    const T*     hidden_states     = input_tensors->at("medusa_head_input").getPtr<T>();
+    int*         h_topk_output_ids = output_tensors->at("medusa_head_output").getPtr<int>();
+
+    allocate_buffer(batch_size);
     // TODO parallelize this loop
     for (int i = 0; i < medusa_num_heads_; i++) {
-        T* medusa_head_logits = medusa_head_logits_ptr + i * batch_size * vocab_size_;
+        T* medusa_head_logits = medusa_head_logits_buf_ + i * batch_size * vocab_size_;
         forward(medusa_head_logits, hidden_states, batch_size, medusa_weight, i);
     }
+
+    top_k(h_topk_output_ids, medusa_head_logits_buf_, batch_size * medusa_num_heads_);
 }
 
 template<typename T>
@@ -53,7 +58,6 @@ void MedusaHead<T>::forward(T*                     medusa_head_output,
                             const MedusaWeight<T>& medusa_weight,
                             int                    head_id)
 {
-    allocate_buffer(batch_size);
     // TODO support multi medusa_num_layers
     resblock_->forward(resblock_buf_, medusa_head_input, batch_size, medusa_weight.get_resblocks_weights()[head_id][0]);
     linear_->forward(medusa_head_output, resblock_buf_, batch_size, medusa_weight.get_heads_weights()[head_id]);
@@ -72,6 +76,8 @@ void MedusaHead<T>::allocate_buffer(size_t batch_size)
 {
     resblock_buf_ =
         (T*)allocator_->reMalloc(resblock_buf_, sizeof(T) * batch_size * in_size_ / tensor_para_.world_size_, false);
+    medusa_head_logits_buf_ = (T*)allocator_->reMalloc(
+        medusa_head_logits_buf_, medusa_num_heads_ * sizeof(T) * batch_size * vocab_size_, false);
     is_allocated_buffer_ = true;
 }
 
@@ -80,10 +86,45 @@ void MedusaHead<T>::free_buffer()
 {
     if (is_free_buffer_after_forward_ && is_allocated_buffer_) {
         allocator_->free((void**)&resblock_buf_);
+        allocator_->free((void**)&workspace_buf_);
+        allocator_->free((void**)&medusa_head_logits_buf_);
         is_allocated_buffer_ = false;
     }
 }
 
+template<typename T>
+void MedusaHead<T>::top_k(int* h_topk_output_ids, const T* d_input_logits, const size_t batch_size, const int k)
+{
+    size_t workspace_size_now = 0;
+    invokeBatchTopKOnly(nullptr,
+                        workspace_size_now,
+                        d_input_logits,
+                        nullptr,
+                        k,
+                        nullptr,
+                        vocab_size_,
+                        nullptr,
+                        stream_,
+                        batch_size,
+                        nullptr);
+    workspace_buf_ = (void*)allocator_->reMalloc(workspace_buf_, workspace_size_now, false);
+    invokeBatchTopKOnly(workspace_buf_,
+                        workspace_size_now,
+                        d_input_logits,
+                        nullptr,
+                        k,
+                        nullptr,
+                        vocab_size_,
+                        nullptr,
+                        stream_,
+                        batch_size,
+                        nullptr);
+    int  offset          = (int)(ceil(batch_size * vocab_size_ / 4.)) * 4;
+    int  output_size     = (int)(ceil(batch_size * k / 4.)) * 4;
+    int* topk_output_ids = (int*)(((T*)workspace_buf_) + offset);
+    cudaMemcpy(h_topk_output_ids, topk_output_ids, sizeof(int) * output_size, cudaMemcpyDeviceToHost);
+}
+
 template class MedusaHead<float>;
 template class MedusaHead<half>;
 #ifdef ENABLE_BF16
diff --git a/src/turbomind/models/medusa_plugin/medusa_head.h b/src/turbomind/models/medusa_plugin/medusa_head.h
index 44d51035e..62c7f618e 100644
--- a/src/turbomind/models/medusa_plugin/medusa_head.h
+++ b/src/turbomind/models/medusa_plugin/medusa_head.h
@@ -38,6 +38,7 @@ class MedusaHead {
 private:
     void allocate_buffer(size_t batch_size);
     void free_buffer();
+    void top_k(int* h_topk_output_ids, const T* d_input_logits, const size_t batch_size, const int k = 1);
 
 private:
     size_t in_size_;
@@ -47,7 +48,9 @@ class MedusaHead {
     std::unique_ptr<ResBlock<T>>    resblock_;
     std::unique_ptr<LlamaLinear<T>> linear_;
 
-    T* resblock_buf_;
+    T*    resblock_buf_;
+    void* workspace_buf_;
+    T*    medusa_head_logits_buf_;
 
     cudaStream_t     stream_;
     cublasMMWrapper* cublas_wrapper_;
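
Note on the new MedusaHead<T>::top_k: invokeBatchTopKOnly is called twice in what
appears to be the usual two-phase workspace pattern. The first call passes a null
workspace so the kernel only reports its required byte size into
workspace_size_now; the buffer is then (re)allocated and the second, otherwise
identical call runs the batched top-k. With the default k = 1 this reduces to a
per-head argmax, which is what "temperature 0" (greedy decoding) needs. The final
cudaMemcpy assumes a workspace layout that can be read off the pointer arithmetic:
a T-typed scratch region of batch_size * vocab_size_ elements, rounded up to a
multiple of 4 elements, followed by the int top-k ids (also padded to a multiple
of 4). The standalone sketch below reproduces that offset arithmetic under those
assumptions; it is not part of the patch, and the helper name round_up4 is
hypothetical, not a TurboMind API.

// Standalone sketch: mirrors the workspace offset arithmetic in
// MedusaHead<T>::top_k above. Assumption (from the patch's pointer math):
// workspace_buf_ = [T scratch, padded to 4 elements][int ids, padded to 4].
#include <cmath>
#include <cstddef>
#include <cstdio>

// Round an element count up to a multiple of 4, as the patch does with
// (int)(ceil(n / 4.)) * 4.
static size_t round_up4(size_t n)
{
    return static_cast<size_t>(std::ceil(n / 4.0)) * 4;
}

int main()
{
    // In the patch, the batch passed to top_k is batch_size * medusa_num_heads_,
    // and k defaults to 1 (temperature 0 keeps only the argmax token).
    const size_t batch_size = 4 * 5;  // e.g. 4 sequences, 5 medusa heads
    const size_t vocab_size = 32000;
    const size_t k          = 1;

    // T-element offset of the int ids region inside workspace_buf_, i.e.
    // int* topk_output_ids = (int*)(((T*)workspace_buf_) + offset);
    const size_t offset = round_up4(batch_size * vocab_size);

    // Number of ints copied back to the host, also padded to a multiple of 4.
    const size_t output_size = round_up4(batch_size * k);

    std::printf("ids start at T-element %zu; copy %zu ints\n", offset, output_size);
    return 0;
}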