From 8cf5a73a2c514fa9d2a25f0d5927652d8122c029 Mon Sep 17 00:00:00 2001 From: zhyncs Date: Sun, 10 Mar 2024 16:47:23 +0800 Subject: [PATCH] feat: update medusa_head_output --- src/turbomind/models/medusa_plugin/medusa_head.cc | 14 +++++++------- src/turbomind/models/medusa_plugin/medusa_head.h | 2 +- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/turbomind/models/medusa_plugin/medusa_head.cc b/src/turbomind/models/medusa_plugin/medusa_head.cc index 91d63aeea..d5c9652e5 100644 --- a/src/turbomind/models/medusa_plugin/medusa_head.cc +++ b/src/turbomind/models/medusa_plugin/medusa_head.cc @@ -11,7 +11,7 @@ namespace turbomind { template MedusaHead::MedusaHead(size_t in_size, - size_t out_size, + size_t vocab_size, size_t medusa_num_heads, cudaStream_t stream, cublasMMWrapper* cublas_wrapper, @@ -19,7 +19,7 @@ MedusaHead::MedusaHead(size_t in_size, NcclParam tensor_para, bool is_free_buffer_after_forward): in_size_(in_size), - out_size_(out_size), + vocab_size_(vocab_size), medusa_num_heads_(medusa_num_heads), stream_(stream), cublas_wrapper_(cublas_wrapper), @@ -36,12 +36,12 @@ void MedusaHead::forward(TensorMap* output_tensors, const TensorMap* input_tensors, const MedusaWeight& medusa_weight) { - const size_t batch_size = input_tensors->at("medusa_head_input").shape[0]; - const T* hidden_states = input_tensors->at("medusa_head_input").getPtr(); - std::vector* medusa_head_logits_vec = output_tensors->at("medusa_head_output").getPtr>(); + const size_t batch_size = input_tensors->at("medusa_head_input").shape[0]; + const T* hidden_states = input_tensors->at("medusa_head_input").getPtr(); + T* medusa_head_logits_ptr = output_tensors->at("medusa_head_output").getPtr(); // TODO parallelize this loop for (int i = 0; i < medusa_num_heads_; i++) { - T* medusa_head_logits = (*medusa_head_logits_vec)[i]; + T* medusa_head_logits = medusa_head_logits_ptr + i * batch_size * vocab_size_; forward(medusa_head_logits, hidden_states, batch_size, medusa_weight, i); } } @@ -60,7 +60,7 @@ void MedusaHead::forward(T* medusa_head_output, if (tensor_para_.world_size_ > 1) { NcclGuard nccl_guard(tensor_para_, stream_); - ftNcclAllReduceSum(medusa_head_output, medusa_head_output, batch_size * out_size_, tensor_para_, stream_); + ftNcclAllReduceSum(medusa_head_output, medusa_head_output, batch_size * vocab_size_, tensor_para_, stream_); sync_check_cuda_error(); } diff --git a/src/turbomind/models/medusa_plugin/medusa_head.h b/src/turbomind/models/medusa_plugin/medusa_head.h index 5fc7c86b2..44d51035e 100644 --- a/src/turbomind/models/medusa_plugin/medusa_head.h +++ b/src/turbomind/models/medusa_plugin/medusa_head.h @@ -41,7 +41,7 @@ class MedusaHead { private: size_t in_size_; - size_t out_size_; + size_t vocab_size_; size_t medusa_num_heads_; std::unique_ptr> resblock_;