diff --git a/src/turbomind/kernels/activation_kernels.cu b/src/turbomind/kernels/activation_kernels.cu index 9d93b05566..db66a31c94 100644 --- a/src/turbomind/kernels/activation_kernels.cu +++ b/src/turbomind/kernels/activation_kernels.cu @@ -329,4 +329,82 @@ INSTANTIATE_GENERIC_ACTIVATION(SiluActivation, half, half); INSTANTIATE_GENERIC_ACTIVATION(SiluActivation, __nv_bfloat16, __nv_bfloat16); #endif +template class Activation, typename T, typename BT> +__global__ void fused_bias_residual_activation( + T* out, const BT* __restrict bias, const T* __restrict residual, int m, int n, int tp_num, int tp_offset) +{ + const bool with_bias = bias != nullptr; + const bool with_residual = residual != nullptr; + + for (int64_t id = blockIdx.x * blockDim.x + threadIdx.x; id < 1LL * m * n; id += blockDim.x * gridDim.x) { + T val; + + val = out[id]; + + if (with_bias) { + T bias_val = static_cast(bias[id % n]); + val = add(val, bias_val); + } + + val = cuda_cast(Activation::apply(val)); + + if (with_residual) { + T residual_val = static_cast(residual[id % n + (id - id % n) * tp_num + tp_offset]); + val = add(val, residual_val); + } + + out[id] = val; + } +} + +template class Activation, typename T, typename BT> +void invokeFusedBiasResidualActivation(T* out, + const BT* bias, + const T* residual, + const int m, + const int n, + cudaStream_t stream, + const int tp_num, + const int tp_offset) +{ + TM_LOG_DEBUG(__PRETTY_FUNCTION__); + using PT = typename packed_type::type; + constexpr int packed_elems = num_elems::value; + using PBT = typename packed_as::type; + + dim3 block, grid; + if (n / 4 / packed_elems <= 1024) { + block.x = n / 4 / packed_elems; + grid.x = m; + } + else { + block.x = 1024; + grid.x = ceil(m * n / 1024.); + } + fused_bias_residual_activation<<>>(reinterpret_cast(out), + reinterpret_cast(bias), + reinterpret_cast(residual), + m, + n / packed_elems, + tp_num, + tp_offset / packed_elems); + sync_check_cuda_error(); +} + +#define INSTANTIATE_FUSED_BIAS_RESIDUAL_ACTIVATION(Activation, T, BT) \ + template void invokeFusedBiasResidualActivation(T * out, \ + const BT* bias, \ + const T* residual, \ + const int m, \ + const int n, \ + cudaStream_t stream, \ + const int tp_num, \ + const int tp_offset); + +INSTANTIATE_FUSED_BIAS_RESIDUAL_ACTIVATION(SiluActivation, float, float); +INSTANTIATE_FUSED_BIAS_RESIDUAL_ACTIVATION(SiluActivation, half, half); +#ifdef ENABLE_BF16 +INSTANTIATE_FUSED_BIAS_RESIDUAL_ACTIVATION(SiluActivation, __nv_bfloat16, __nv_bfloat16); +#endif + } // namespace turbomind diff --git a/src/turbomind/kernels/activation_kernels.h b/src/turbomind/kernels/activation_kernels.h index 776b614c9c..5844eaed1f 100644 --- a/src/turbomind/kernels/activation_kernels.h +++ b/src/turbomind/kernels/activation_kernels.h @@ -107,4 +107,14 @@ void invokeAddBiasTanh(T* out, const T* bias, const int m, const int n, cudaStre template void invokeSigmoid(T* data, const int size, const float scale, cudaStream_t stream); +template class Activation, typename T, typename BT> +void invokeFusedBiasResidualActivation(T* out, + const BT* bias, + const T* residual, + const int m, + const int n, + cudaStream_t stream, + const int tp_num, + const int tp_offset); + } // namespace turbomind diff --git a/src/turbomind/models/CMakeLists.txt b/src/turbomind/models/CMakeLists.txt index 37d883f86f..932e23a994 100644 --- a/src/turbomind/models/CMakeLists.txt +++ b/src/turbomind/models/CMakeLists.txt @@ -13,3 +13,4 @@ # limitations under the License. add_subdirectory(llama) +add_subdirectory(medusa_plugin) diff --git a/src/turbomind/models/medusa_plugin/CMakeLists.txt b/src/turbomind/models/medusa_plugin/CMakeLists.txt new file mode 100644 index 0000000000..2ab5eaebcf --- /dev/null +++ b/src/turbomind/models/medusa_plugin/CMakeLists.txt @@ -0,0 +1,12 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +cmake_minimum_required(VERSION 3.8) + +find_package(CUDAToolkit REQUIRED) + +add_library(Medusa STATIC + medusa_weight.cc + res_block.cc + medusa_head.cc) + +set_property(TARGET Medusa PROPERTY POSITION_INDEPENDENT_CODE ON) diff --git a/src/turbomind/models/medusa_plugin/medusa_head.cc b/src/turbomind/models/medusa_plugin/medusa_head.cc new file mode 100644 index 0000000000..91d63aeeaf --- /dev/null +++ b/src/turbomind/models/medusa_plugin/medusa_head.cc @@ -0,0 +1,93 @@ +// Copyright (c) OpenMMLab. All rights reserved. +// Yineng Zhang +// Zhiwei Bao + +#include "src/turbomind/models/medusa_plugin/medusa_head.h" +#include "src/turbomind/models/llama/LlamaNcclGuard.h" +#include "src/turbomind/utils/Tensor.h" +#include "src/turbomind/utils/cublasMMWrapper.h" + +namespace turbomind { + +template +MedusaHead::MedusaHead(size_t in_size, + size_t out_size, + size_t medusa_num_heads, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + NcclParam tensor_para, + bool is_free_buffer_after_forward): + in_size_(in_size), + out_size_(out_size), + medusa_num_heads_(medusa_num_heads), + stream_(stream), + cublas_wrapper_(cublas_wrapper), + allocator_(allocator), + tensor_para_(tensor_para), + is_free_buffer_after_forward_(is_free_buffer_after_forward) +{ + resblock_ = std::make_unique>(in_size_, stream_, cublas_wrapper_, tensor_para_); + linear_ = std::make_unique>(cublas_wrapper_, stream_); +} + +template +void MedusaHead::forward(TensorMap* output_tensors, + const TensorMap* input_tensors, + const MedusaWeight& medusa_weight) +{ + const size_t batch_size = input_tensors->at("medusa_head_input").shape[0]; + const T* hidden_states = input_tensors->at("medusa_head_input").getPtr(); + std::vector* medusa_head_logits_vec = output_tensors->at("medusa_head_output").getPtr>(); + // TODO parallelize this loop + for (int i = 0; i < medusa_num_heads_; i++) { + T* medusa_head_logits = (*medusa_head_logits_vec)[i]; + forward(medusa_head_logits, hidden_states, batch_size, medusa_weight, i); + } +} + +template +void MedusaHead::forward(T* medusa_head_output, + const T* medusa_head_input, + size_t batch_size, + const MedusaWeight& medusa_weight, + int head_id) +{ + allocate_buffer(batch_size); + // TODO support multi medusa_num_layers + resblock_->forward(resblock_buf_, medusa_head_input, batch_size, medusa_weight.get_resblocks_weights()[head_id][0]); + linear_->forward(medusa_head_output, resblock_buf_, batch_size, medusa_weight.get_heads_weights()[head_id]); + + if (tensor_para_.world_size_ > 1) { + NcclGuard nccl_guard(tensor_para_, stream_); + ftNcclAllReduceSum(medusa_head_output, medusa_head_output, batch_size * out_size_, tensor_para_, stream_); + sync_check_cuda_error(); + } + + free_buffer(); +} + +template +void MedusaHead::allocate_buffer(size_t batch_size) +{ + resblock_buf_ = + (T*)allocator_->reMalloc(resblock_buf_, sizeof(T) * batch_size * in_size_ / tensor_para_.world_size_, false); + is_allocated_buffer_ = true; +} + +template +void MedusaHead::free_buffer() +{ + if (is_free_buffer_after_forward_ && is_allocated_buffer_) { + allocator_->free((void**)&resblock_buf_); + is_allocated_buffer_ = false; + } +} + +template class MedusaHead; +template class MedusaHead; +#ifdef ENABLE_BF16 +template class MedusaHead<__nv_bfloat16>; +#endif + +} // namespace turbomind diff --git a/src/turbomind/models/medusa_plugin/medusa_head.h b/src/turbomind/models/medusa_plugin/medusa_head.h new file mode 100644 index 0000000000..5fc7c86b2c --- /dev/null +++ b/src/turbomind/models/medusa_plugin/medusa_head.h @@ -0,0 +1,61 @@ +// Copyright (c) OpenMMLab. All rights reserved. +// Yineng Zhang +// Zhiwei Bao + +#pragma once + +#include "src/turbomind/models/medusa_plugin/medusa_weight.h" +#include "src/turbomind/models/medusa_plugin/res_block.h" +#include "src/turbomind/utils/cublasMMWrapper.h" +#include "src/turbomind/utils/nccl_utils.h" +#include +#include + +namespace turbomind { + +template +class MedusaHead { +public: + MedusaHead(size_t in_size, + size_t out_size, + size_t medusa_num_heads, + cudaStream_t stream, + cublasMMWrapper* cublas_wrapper, + IAllocator* allocator, + NcclParam tensor_para, + bool is_free_buffer_after_forward = false); + ~MedusaHead() = default; + MedusaHead(const MedusaHead&) = delete; + MedusaHead& operator=(const MedusaHead&) = delete; + + void forward(TensorMap* output_tensors, const TensorMap* input_tensors, const MedusaWeight& medusa_weight); + void forward(T* medusa_head_output, + const T* medusa_head_input, + size_t batch_size, + const MedusaWeight& medusa_weight, + int head_id); + +private: + void allocate_buffer(size_t batch_size); + void free_buffer(); + +private: + size_t in_size_; + size_t out_size_; + size_t medusa_num_heads_; + + std::unique_ptr> resblock_; + std::unique_ptr> linear_; + + T* resblock_buf_; + + cudaStream_t stream_; + cublasMMWrapper* cublas_wrapper_; + IAllocator* allocator_; + + NcclParam tensor_para_; + + bool is_allocated_buffer_ = false; + bool is_free_buffer_after_forward_ = false; +}; +} // namespace turbomind diff --git a/src/turbomind/models/medusa_plugin/medusa_weight.cc b/src/turbomind/models/medusa_plugin/medusa_weight.cc new file mode 100644 index 0000000000..04113fe29d --- /dev/null +++ b/src/turbomind/models/medusa_plugin/medusa_weight.cc @@ -0,0 +1,172 @@ +// Copyright (c) OpenMMLab. All rights reserved. +// Yineng Zhang +// Zhiwei Bao + +#include "src/turbomind/models/medusa_plugin/medusa_weight.h" +#include "src/turbomind/utils/memory_utils.h" +#include +#include +#include +#include + +namespace turbomind { + +template +MedusaWeight::MedusaWeight(size_t medusa_num_heads, + size_t medusa_num_layers, + size_t hidden_size, + size_t vocab_size, + WeightType weight_type, + size_t tensor_para_size, + size_t tensor_para_rank): + medusa_num_heads_(medusa_num_heads), + medusa_num_layers_(medusa_num_layers), + hidden_size_(hidden_size), + vocab_size_(vocab_size), + weight_type_(weight_type), + tensor_para_size_(tensor_para_size), + tensor_para_rank_(tensor_para_rank) +{ + heads_weights_.resize(medusa_num_heads_); + std::fill_n( + heads_weights_.begin(), + medusa_num_heads_, + LlamaDenseWeight{hidden_size_ / tensor_para_size_, vocab_size_, nullptr, weight_type_, nullptr, nullptr, 0}); + + resblocks_weights_.resize(medusa_num_heads_); + std::fill_n(resblocks_weights_.begin(), medusa_num_heads_, std::vector>(medusa_num_layers_)); + std::for_each(resblocks_weights_.begin(), resblocks_weights_.end(), [this](auto& resblock_weights) { + std::for_each(resblock_weights.begin(), resblock_weights.end(), [this](auto& resblock_weight) { + resblock_weight.input_dims = hidden_size_; + resblock_weight.output_dims = hidden_size_ / tensor_para_size_; + resblock_weight.type = weight_type_; + }); + }); + + malloc_weight(); +} + +template +MedusaWeight::~MedusaWeight() +{ + free_weight(); +} + +template +void MedusaWeight::malloc_weight(LlamaDenseWeight* weight, bool bias) +{ + if (bias) { + deviceMalloc((T**)&weight->bias, weight->output_dims); + } + const size_t bit_size = getBitSize(weight->type); + if (bit_size >= 16) { + deviceMalloc((T**)&weight->kernel, weight->input_dims * weight->output_dims); + } +} + +template +void MedusaWeight::malloc_weight() +{ + std::for_each(heads_weights_.begin(), heads_weights_.end(), [this](auto& head_weights) { + malloc_weight(&head_weights, false); + }); + std::for_each(resblocks_weights_.begin(), resblocks_weights_.end(), [this](auto& resblock_weights) { + std::for_each(resblock_weights.begin(), resblock_weights.end(), [this](auto& resblock_weight) { + malloc_weight(&resblock_weight, true); + }); + }); +} + +template +void MedusaWeight::free_weight(LlamaDenseWeight* weight) +{ + cudaFree(weight->kernel); + cudaFree(weight->bias); + cudaFree(weight->scales_and_zeros); + + weight->kernel = nullptr; + weight->bias = nullptr; + weight->scales_and_zeros = nullptr; +} + +template +void MedusaWeight::free_weight() +{ + std::for_each( + heads_weights_.begin(), heads_weights_.end(), [this](auto& head_weights) { free_weight(&head_weights); }); + std::for_each(resblocks_weights_.begin(), resblocks_weights_.end(), [this](auto& resblock_weights) { + std::for_each(resblock_weights.begin(), resblock_weights.end(), [this](auto& resblock_weight) { + free_weight(&resblock_weight); + }); + }); +} + +template +void MedusaWeight::load_weight(LlamaDenseWeight* weight, const std::string& path, FtCudaDataType model_file_type) +{ + const size_t bit_size = getBitSize(weight->type); + if (bit_size >= 16) { + loadWeightFromBin((T*)weight->kernel, {weight->input_dims, weight->output_dims}, path, model_file_type); + } +} + +template +void MedusaWeight::load_bias(LlamaDenseWeight* weight, const std::string& path, FtCudaDataType model_file_type) +{ + const size_t bit_size = getBitSize(weight->type); + if (bit_size >= 16) { + loadWeightFromBin((T*)weight->bias, {weight->output_dims}, path, model_file_type); + } +} + +template +void MedusaWeight::load_model(const std::string& dir_path, FtCudaDataType model_file_type) +{ + auto ends_with = [](std::string& text, const std::string& suffix) noexcept { + return suffix.empty() + || (text.size() >= suffix.size() + && std::memcmp(text.data() + (text.size() - suffix.size()), suffix.data(), suffix.size()) == 0); + }; + std::string weight_path = dir_path; + if (!ends_with(weight_path, "/")) { + weight_path.append("/"); + } + std::string prefix = "medusa."; + std::string rank = std::to_string(tensor_para_rank_) + "."; + weight_path.append(prefix); + for (int i = 0; i < medusa_num_heads_; i++) { + for (int j = 0; j < medusa_num_layers_; j++) { + std::stringstream ss; + ss << weight_path << i << "." << j << "." + << "linear." << rank; + std::string common_prefix = ss.str(); + + load_weight(&resblocks_weights_[i][j], common_prefix + "weight", model_file_type); + load_bias(&resblocks_weights_[i][j], common_prefix + "bias", model_file_type); + } + + std::stringstream ss; + ss << weight_path << i << "." << medusa_num_layers_ << "." << rank << "weight"; + load_weight(&heads_weights_[i], ss.str(), model_file_type); + } +} + +template +const std::vector>& MedusaWeight::get_heads_weights() const +{ + return heads_weights_; +} + +template +const std::vector>>& MedusaWeight::get_resblocks_weights() const +{ + return resblocks_weights_; +} + +template class MedusaWeight; +template class MedusaWeight; +#ifdef ENABLE_BF16 +template class MedusaWeight<__nv_bfloat16>; +#endif + +} // namespace turbomind diff --git a/src/turbomind/models/medusa_plugin/medusa_weight.h b/src/turbomind/models/medusa_plugin/medusa_weight.h new file mode 100644 index 0000000000..c44fa0fa18 --- /dev/null +++ b/src/turbomind/models/medusa_plugin/medusa_weight.h @@ -0,0 +1,54 @@ +// Copyright (c) OpenMMLab. All rights reserved. +// Yineng Zhang +// Zhiwei Bao + +#pragma once + +#include "src/turbomind/models/llama/LlamaDenseWeight.h" +#include "src/turbomind/utils/cuda_utils.h" +#include + +namespace turbomind { + +template +class MedusaWeight { +public: + MedusaWeight(size_t medusa_num_heads, + size_t medusa_num_layers, + size_t hidden_size, + size_t vocab_size, + WeightType weight_type, + size_t tensor_para_size, + size_t tensor_para_rank); + ~MedusaWeight(); + MedusaWeight(const MedusaWeight&) = delete; + MedusaWeight& operator=(const MedusaWeight&) = delete; + + const std::vector>& get_heads_weights() const; + const std::vector>>& get_resblocks_weights() const; + + void load_model(const std::string& dir_path, FtCudaDataType model_file_type); + +private: + void malloc_weight(LlamaDenseWeight* weight, bool bias); + void free_weight(LlamaDenseWeight* weight); + void malloc_weight(); + void free_weight(); + void load_weight(LlamaDenseWeight* weight, const std::string& path, FtCudaDataType model_file_type); + void load_bias(LlamaDenseWeight* weight, const std::string& path, FtCudaDataType model_file_type); + +private: + size_t medusa_num_heads_; + size_t medusa_num_layers_; + size_t hidden_size_; + size_t vocab_size_; + WeightType weight_type_; + + size_t tensor_para_size_; + size_t tensor_para_rank_; + + std::vector> heads_weights_; + std::vector>> resblocks_weights_; +}; + +} // namespace turbomind diff --git a/src/turbomind/models/medusa_plugin/res_block.cc b/src/turbomind/models/medusa_plugin/res_block.cc new file mode 100644 index 0000000000..0051907f12 --- /dev/null +++ b/src/turbomind/models/medusa_plugin/res_block.cc @@ -0,0 +1,47 @@ +// Copyright (c) OpenMMLab. All rights reserved. +// Yineng Zhang +// Zhiwei Bao + +#include "src/turbomind/models/medusa_plugin/res_block.h" +#include "src/turbomind/kernels/activation_kernels.h" +#include "src/turbomind/utils/Tensor.h" + +namespace turbomind { + +template +void ResBlock::forward(TensorMap* output_tensors, const TensorMap* input_tensors, const LlamaDenseWeight& weight) +{ + T* resblock_output = output_tensors->at("resblock_output").getPtr(); + const T* resblock_input = input_tensors->at("resblock_input").getPtr(); + const size_t batch_size = input_tensors->at("resblock_input").shape[0]; + + forward(resblock_output, resblock_input, batch_size, weight); +} + +template +void ResBlock::forward(T* resblock_output, + const T* resblock_input, + size_t batch_size, + const LlamaDenseWeight& weight) +{ + linear_->forward(resblock_output, resblock_input, batch_size, weight); + const int tp_num = tensor_para_.world_size_; + const int tp_offset = tensor_para_.rank_ * in_size_ / tp_num; + + invokeFusedBiasResidualActivation(resblock_output, + (const T*)weight.bias, // bias + (const T*)resblock_input, // residual + batch_size, // m + in_size_ / tp_num, // n + stream_, + tp_num, + tp_offset); +} + +template class ResBlock; +template class ResBlock; +#ifdef ENABLE_BF16 +template class ResBlock<__nv_bfloat16>; +#endif + +} // namespace turbomind diff --git a/src/turbomind/models/medusa_plugin/res_block.h b/src/turbomind/models/medusa_plugin/res_block.h new file mode 100644 index 0000000000..ddaa2307c0 --- /dev/null +++ b/src/turbomind/models/medusa_plugin/res_block.h @@ -0,0 +1,36 @@ +// Copyright (c) OpenMMLab. All rights reserved. +// Yineng Zhang +// Zhiwei Bao + +#pragma once + +#include "src/turbomind/models/llama/LlamaDenseWeight.h" +#include "src/turbomind/models/llama/LlamaLinear.h" +#include "src/turbomind/utils/cublasMMWrapper.h" +#include "src/turbomind/utils/nccl_utils.h" + +namespace turbomind { + +template +class ResBlock { +public: + ResBlock(size_t in_size, cudaStream_t stream, cublasMMWrapper* cublas_wrapper, NcclParam tensor_para): + in_size_(in_size), stream_(stream), tensor_para_(tensor_para) + { + linear_ = std::make_unique>(cublas_wrapper, stream); + } + ~ResBlock() = default; + ResBlock(const ResBlock&) = delete; + ResBlock& operator=(const ResBlock&) = delete; + + void forward(TensorMap* output_tensors, const TensorMap* input_tensors, const LlamaDenseWeight& weight); + void forward(T* resblock_output, const T* resblock_input, size_t batch_size, const LlamaDenseWeight& weight); + +private: + size_t in_size_; + + cudaStream_t stream_; + std::unique_ptr> linear_; + NcclParam tensor_para_; +}; +} // namespace turbomind diff --git a/src/turbomind/models/medusa_plugin/tests/READMD.md b/src/turbomind/models/medusa_plugin/tests/READMD.md new file mode 100644 index 0000000000..5f6bc63934 --- /dev/null +++ b/src/turbomind/models/medusa_plugin/tests/READMD.md @@ -0,0 +1,24 @@ +# Usage + +```bash +# https://huggingface.co/FasterDecoding/medusa-vicuna-13b-v1.3 + +# fp16 tp1 +# default medusa pt path: /workdir/medusa-vicuna-13b-v1.3/medusa_lm_head.pt +# default medusa output path: /workdir/medusa_output/fp16/tp1 +# default tp: 1 +# default medusa weight type: fp16 +python3 medusa_converter.py + +# fp16 tp1 +python3 medusa_converter.py --medusa_pt_path=/workdir/medusa-vicuna-13b-v1.3/medusa_lm_head.pt --medusa_output_path=/workdir/medusa_output/fp16/tp1 --tp=1 --medusa_weight_type=fp16 + +# fp16 tp2 +python3 medusa_converter.py --medusa_pt_path=/workdir/medusa-vicuna-13b-v1.3/medusa_lm_head.pt --medusa_output_path=/workdir/medusa_output/fp16/tp2 --tp=2 --medusa_weight_type=fp16 + +# bf16 tp1 +python3 medusa_converter.py --medusa_pt_path=/workdir/medusa-vicuna-13b-v1.3/medusa_lm_head.pt --medusa_output_path=/workdir/medusa_output/bf16/tp1 --tp=1 --medusa_weight_type=bf16 + +# bf16 tp2 +python3 medusa_converter.py --medusa_pt_path=/workdir/medusa-vicuna-13b-v1.3/medusa_lm_head.pt --medusa_output_path=/workdir/medusa_output/bf16/tp2 --tp=2 --medusa_weight_type=bf16 +``` diff --git a/src/turbomind/models/medusa_plugin/tests/medusa_converter.py b/src/turbomind/models/medusa_plugin/tests/medusa_converter.py new file mode 100644 index 0000000000..8783ae9242 --- /dev/null +++ b/src/turbomind/models/medusa_plugin/tests/medusa_converter.py @@ -0,0 +1,111 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +import logging +import os + +import fire +import torch + + +class MedusaConverter(): + + def __init__(self, medusa_pt_path: str, medusa_output_path: str, + medusa_num_heads: int, medusa_num_layers: int, + medusa_weight_type: str, tp: int): + logging.basicConfig(level=logging.INFO) + if not os.path.isfile(medusa_pt_path): + logging.error(f'{medusa_pt_path} not exist') + os._exit(os.EX_IOERR) + self.medusa_pt_path = medusa_pt_path + self.medusa_output_path = medusa_output_path + if not os.path.exists(self.medusa_output_path): + os.makedirs(self.medusa_output_path) + + self.medusa_num_heads = medusa_num_heads + self.medusa_num_layers = medusa_num_layers + + self.medusa_weight_type = medusa_weight_type + self.tp = tp + + self.medusa_weights = torch.load(medusa_pt_path) + + def _tp_split(self, tensor: torch.Tensor, tp: int, + dim: int) -> torch.Tensor: + split_size = tensor.shape[dim] // tp + split_tensors = torch.split(tensor, split_size, dim=dim) + return split_tensors + + def _export(self, tensor: torch.Tensor, save_name: str): + if tensor.dtype == torch.bfloat16: + if self.medusa_weight_type == 'fp16': + tensor = tensor.to(torch.float16) + elif self.medusa_weight_type == 'bf16': + # numpy workaround + tensor = tensor.view(torch.float16) + else: + logging.error(f'{self.medusa_weight_type} not support') + os._exit(os.EX_CONFIG) + + tensor.contiguous().cpu().numpy().tofile( + os.path.join(self.medusa_output_path, save_name)) + logging.info( + f'saved to {os.path.join(self.medusa_output_path, save_name)}') + + def _convert_head(self, medusa_head: int, tp: int): + for medusa_layer in range(self.medusa_num_layers): + w_name = f'{medusa_head}.{medusa_layer}.linear.weight' + b_name = f'{medusa_head}.{medusa_layer}.linear.bias' + + tensor_w = self.medusa_weights[w_name] + tensor_b = self.medusa_weights[b_name] + + tensor_w = tensor_w.t() + split_tensors_w = self._tp_split(tensor_w, tp, -1) + split_tensors_b = self._tp_split(tensor_b, tp, -1) + + for rank, split_tensor in enumerate(split_tensors_w): + w_name_after = f'medusa.{medusa_head}.{medusa_layer}.linear.{rank}.weight' # noqa: E501 + logging.info( + f'{w_name}->{w_name_after}, shape:{self.medusa_weights[w_name].shape}->{split_tensor.shape}' # noqa: E501 + ) + self._export(split_tensor, w_name_after) + + for rank, split_tensor in enumerate(split_tensors_b): + b_name_after = f'medusa.{medusa_head}.{medusa_layer}.linear.{rank}.bias' # noqa: E501 + logging.info( + f'{b_name}->{b_name_after}, shape:{self.medusa_weights[b_name].shape}->{split_tensor.shape}' # noqa: E501 + ) + self._export(split_tensor, b_name_after) + + w_name = f'{medusa_head}.{self.medusa_num_layers}.weight' + tensor_w = self.medusa_weights[w_name] + + tensor_w = tensor_w.t() + split_tensors_w = self._tp_split(tensor_w, tp, 0) + + for rank, split_tensor in enumerate(split_tensors_w): + w_name_after = f'medusa.{medusa_head}.{self.medusa_num_layers}.{rank}.weight' # noqa: E501 + logging.info( + f'{w_name}->{w_name_after}, shape:{self.medusa_weights[w_name].shape}->{split_tensor.shape}' # noqa: E501 + ) + self._export(split_tensor, w_name_after) + + def convert(self): + for i in range(self.medusa_num_heads): + self._convert_head(medusa_head=i, tp=self.tp) + + +def main(medusa_pt_path='/workdir/medusa-vicuna-13b-v1.3/medusa_lm_head.pt', + medusa_output_path='/workdir/medusa_output/fp16/tp1', + medusa_num_heads=5, + medusa_num_layers=1, + medusa_weight_type='fp16', + tp=1): + converter = MedusaConverter(medusa_pt_path, medusa_output_path, + medusa_num_heads, medusa_num_layers, + medusa_weight_type, tp) + converter.convert() + + +if __name__ == '__main__': + fire.Fire(main)