feat: update medusa_head_output
zhyncs committed Mar 10, 2024
1 parent e8254ab · commit 8cf5a73
Showing 2 changed files with 8 additions and 8 deletions.
14 changes: 7 additions & 7 deletions src/turbomind/models/medusa_plugin/medusa_head.cc
@@ -11,15 +11,15 @@ namespace turbomind {

 template<typename T>
 MedusaHead<T>::MedusaHead(size_t in_size,
-                          size_t out_size,
+                          size_t vocab_size,
                           size_t medusa_num_heads,
                           cudaStream_t stream,
                           cublasMMWrapper* cublas_wrapper,
                           IAllocator* allocator,
                           NcclParam tensor_para,
                           bool is_free_buffer_after_forward):
     in_size_(in_size),
-    out_size_(out_size),
+    vocab_size_(vocab_size),
     medusa_num_heads_(medusa_num_heads),
     stream_(stream),
     cublas_wrapper_(cublas_wrapper),
@@ -36,12 +36,12 @@ void MedusaHead<T>::forward(TensorMap* output_tensors,
                             const TensorMap* input_tensors,
                             const MedusaWeight<T>& medusa_weight)
 {
-    const size_t batch_size = input_tensors->at("medusa_head_input").shape[0];
-    const T* hidden_states = input_tensors->at("medusa_head_input").getPtr<T>();
-    std::vector<T*>* medusa_head_logits_vec = output_tensors->at("medusa_head_output").getPtr<std::vector<T*>>();
+    const size_t batch_size = input_tensors->at("medusa_head_input").shape[0];
+    const T* hidden_states = input_tensors->at("medusa_head_input").getPtr<T>();
+    T* medusa_head_logits_ptr = output_tensors->at("medusa_head_output").getPtr<T>();
     // TODO parallelize this loop
     for (int i = 0; i < medusa_num_heads_; i++) {
-        T* medusa_head_logits = (*medusa_head_logits_vec)[i];
+        T* medusa_head_logits = medusa_head_logits_ptr + i * batch_size * vocab_size_;
         forward(medusa_head_logits, hidden_states, batch_size, medusa_weight, i);
     }
 }
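
For context, the hunk above replaces the per-head std::vector<T*> output with a single contiguous medusa_head_output buffer laid out as [medusa_num_heads, batch_size, vocab_size]. A minimal standalone sketch of that indexing follows; the element type (float) and the sizes are assumptions for illustration only, not values from the repository.

    // Sketch of the post-commit layout: one contiguous logits buffer for all
    // Medusa heads, indexed as [head, batch, vocab]. All sizes are assumed.
    #include <cstddef>
    #include <vector>

    int main()
    {
        const std::size_t medusa_num_heads = 4;      // assumed
        const std::size_t batch_size       = 2;      // assumed
        const std::size_t vocab_size       = 32000;  // assumed

        // Single allocation backing "medusa_head_output" for every head.
        std::vector<float> medusa_head_output(medusa_num_heads * batch_size * vocab_size);

        // Head i's logits start at the same offset the diff computes:
        //   medusa_head_logits_ptr + i * batch_size * vocab_size_
        for (std::size_t i = 0; i < medusa_num_heads; ++i) {
            float* medusa_head_logits = medusa_head_output.data() + i * batch_size * vocab_size;
            // Each head writes batch_size * vocab_size logits starting here.
            (void)medusa_head_logits;
        }
        return 0;
    }
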
@@ -60,7 +60,7 @@ void MedusaHead<T>::forward(T* medusa_head_output,

     if (tensor_para_.world_size_ > 1) {
         NcclGuard nccl_guard(tensor_para_, stream_);
-        ftNcclAllReduceSum(medusa_head_output, medusa_head_output, batch_size * out_size_, tensor_para_, stream_);
+        ftNcclAllReduceSum(medusa_head_output, medusa_head_output, batch_size * vocab_size_, tensor_para_, stream_);
         sync_check_cuda_error();
     }

2 changes: 1 addition & 1 deletion src/turbomind/models/medusa_plugin/medusa_head.h
@@ -41,7 +41,7 @@ class MedusaHead {

 private:
     size_t in_size_;
-    size_t out_size_;
+    size_t vocab_size_;
     size_t medusa_num_heads_;

     std::unique_ptr<ResBlock<T>> resblock_;
