From 3913eadfbb60c2d73832f25415d5ebaef62280a4 Mon Sep 17 00:00:00 2001 From: q yao Date: Wed, 27 Nov 2024 19:17:16 +0800 Subject: [PATCH 01/14] disable prefix-caching for vl model (#2825) --- lmdeploy/api.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/lmdeploy/api.py b/lmdeploy/api.py index e66d73754..2b4204a53 100644 --- a/lmdeploy/api.py +++ b/lmdeploy/api.py @@ -69,7 +69,11 @@ def pipeline(model_path: str, if backend_config is not None else None model_path = get_model(model_path, download_dir, revision) - _, pipeline_class = get_task(model_path) + task, pipeline_class = get_task(model_path) + if task == 'vlm': + if backend_config.enable_prefix_caching: + backend_config.enable_prefix_caching = False + logger.warning('VLM does not support prefix caching.') if type(backend_config) is not PytorchEngineConfig: # set auto backend mode From f88fbc3c31961b1cb159e041dcb657592fb2da21 Mon Sep 17 00:00:00 2001 From: Li Zhang Date: Fri, 29 Nov 2024 10:37:42 +0800 Subject: [PATCH 02/14] Add DeepSeek-V2 support (#2763) * add qwen2-moe * eliminate `inter_size_` from ffn layer * clean up * fix lint * clean up * layer-wise `inter_size` & `expert_num` * add head dim 192 * refactor weight processing * deepseek-v2-lite * deepseek-v2 * fix lint * fix lint * fix ut * Update config.yaml * Update config.yaml * fix mixtral * fix moe gating & config parsing * fix yarn for deepseek-v2 * fix `copy_from` * fix rms norm, rotary embedding & deepseek v2 attention * remove debug code --------- Co-authored-by: zhulinJulia24 <145004780+zhulinJulia24@users.noreply.github.com> --- autotest/config.yaml | 2 + examples/cpp/llama/llama_triton_example.cc | 4 +- lmdeploy/turbomind/deploy/config.py | 23 +- lmdeploy/turbomind/deploy/converter.py | 7 +- lmdeploy/turbomind/deploy/loader.py | 21 ++ lmdeploy/turbomind/deploy/module.py | 82 ++++- .../turbomind/deploy/source_model/__init__.py | 1 + .../deploy/source_model/deepseek2.py | 134 ++++++++ .../turbomind/deploy/source_model/mixtral.py | 2 +- .../turbomind/deploy/source_model/qwen.py | 2 +- .../turbomind/deploy/target_model/base.py | 20 +- lmdeploy/turbomind/supported_models.py | 1 + src/turbomind/kernels/CMakeLists.txt | 1 + .../kernels/attention/CMakeLists.txt | 2 + src/turbomind/kernels/attention/attention.cu | 6 + .../attention/codegen/attention_sm80_192.cu | 16 + .../attention/codegen/decoding_sm80_192.cu | 20 ++ src/turbomind/kernels/attention/decoding.cu | 17 +- .../kernels/attention/decoding_config.h | 12 +- src/turbomind/kernels/attention/impl_16816.h | 61 ++-- src/turbomind/kernels/attention/impl_81616.h | 2 +- src/turbomind/kernels/attention/impl_simt.h | 14 +- .../kernels/attention/kv_cache_utils_v2.cu | 12 +- .../kernels/attention/mainloop_sm80.h | 17 +- src/turbomind/kernels/attention/reduce.cu | 6 +- .../kernels/attention/reduce_kernel.h | 7 +- .../kernels/attention/rotary_embedding.h | 17 + .../kernels/attention/test_attention.cu | 12 +- src/turbomind/kernels/core/array_ops.h | 2 +- src/turbomind/kernels/core/math.h | 8 + src/turbomind/kernels/core/thread_map.h | 3 +- .../flash_attention2/CMakeLists.txt | 4 +- .../flash_fwd_launch_template.h | 2 +- .../flash_attention2/static_switch.h | 12 + src/turbomind/kernels/gemm/context.h | 13 +- src/turbomind/kernels/gemm/convert_v2.cu | 41 ++- src/turbomind/kernels/gemm/moe_utils_v2.cu | 195 +++++++++-- src/turbomind/kernels/gemm/moe_utils_v2.h | 4 + .../kernels/gemm/test/test_moe_utils.cu | 86 +---- src/turbomind/kernels/gemm/test/testbed.h | 4 +- src/turbomind/kernels/gemm/unpack.cu | 34 +- src/turbomind/kernels/norm/CMakeLists.txt | 5 + src/turbomind/kernels/norm/rms_norm.cu | 235 +++++++++++++ src/turbomind/kernels/norm/rms_norm.h | 21 ++ src/turbomind/models/llama/CMakeLists.txt | 4 +- src/turbomind/models/llama/LlamaBatch.cc | 6 +- .../models/llama/LlamaDecoderLayerWeight.cc | 325 ++++++++---------- .../models/llama/LlamaDecoderLayerWeight.h | 39 +-- src/turbomind/models/llama/LlamaDenseWeight.h | 265 +++++++++----- src/turbomind/models/llama/LlamaFfnLayer.cc | 26 +- src/turbomind/models/llama/LlamaFfnLayer.h | 9 +- src/turbomind/models/llama/LlamaV2.cc | 1 - src/turbomind/models/llama/LlamaV2.h | 1 - src/turbomind/models/llama/LlamaWeight.cc | 99 +++--- src/turbomind/models/llama/LlamaWeight.h | 36 +- src/turbomind/models/llama/llama_gemm.cc | 2 +- src/turbomind/models/llama/llama_kernels.h | 2 +- src/turbomind/models/llama/llama_params.h | 65 +++- src/turbomind/models/llama/llama_utils.cu | 73 ++-- src/turbomind/models/llama/mla_utils.cu | 93 +++++ src/turbomind/models/llama/mla_utils.h | 57 +++ src/turbomind/models/llama/moe_ffn_layer.cc | 74 ++-- src/turbomind/models/llama/moe_ffn_layer.h | 20 +- .../models/llama/unified_attention_layer.cc | 150 ++++++-- .../models/llama/unified_attention_layer.h | 7 +- src/turbomind/models/llama/unified_decoder.cc | 89 ++--- src/turbomind/models/llama/unified_decoder.h | 16 +- src/turbomind/models/llama/weight_type.h | 56 +++ src/turbomind/python/bind.cpp | 48 ++- .../triton_backend/llama/LlamaTritonModel.cc | 80 +++-- .../triton_backend/llama/LlamaTritonModel.h | 3 - src/turbomind/utils/allocator.h | 3 +- src/turbomind/utils/cuda_utils.h | 19 + src/turbomind/utils/memory_utils.cu | 108 +++--- src/turbomind/utils/memory_utils.h | 13 +- 75 files changed, 2118 insertions(+), 861 deletions(-) create mode 100644 lmdeploy/turbomind/deploy/source_model/deepseek2.py create mode 100644 src/turbomind/kernels/attention/codegen/attention_sm80_192.cu create mode 100644 src/turbomind/kernels/attention/codegen/decoding_sm80_192.cu create mode 100644 src/turbomind/kernels/norm/CMakeLists.txt create mode 100644 src/turbomind/kernels/norm/rms_norm.cu create mode 100644 src/turbomind/kernels/norm/rms_norm.h create mode 100644 src/turbomind/models/llama/mla_utils.cu create mode 100644 src/turbomind/models/llama/mla_utils.h create mode 100644 src/turbomind/models/llama/weight_type.h diff --git a/autotest/config.yaml b/autotest/config.yaml index e31a40f0d..88ca7c312 100644 --- a/autotest/config.yaml +++ b/autotest/config.yaml @@ -62,6 +62,7 @@ turbomind_chat_model: - liuhaotian/llava-v1.6-vicuna-7b - deepseek-ai/deepseek-vl-1.3b-chat - deepseek-ai/deepseek-coder-1.3b-instruct + - deepseek-ai/DeepSeek-V2-Lite-Chat - codellama/CodeLlama-7b-Instruct-hf - THUDM/glm-4-9b-chat - openbmb/MiniCPM-Llama3-V-2_5 @@ -165,6 +166,7 @@ turbomind_quatization: no_awq: - mistralai/Mistral-7B-Instruct-v0.3 - deepseek-ai/deepseek-coder-1.3b-instruct + - deepseek-ai/DeepSeek-V2-Lite-Chat - codellama/CodeLlama-7b-Instruct-hf gptq: - internlm/internlm2_5-7b-chat diff --git a/examples/cpp/llama/llama_triton_example.cc b/examples/cpp/llama/llama_triton_example.cc index b0e513410..1fb5fa096 100644 --- a/examples/cpp/llama/llama_triton_example.cc +++ b/examples/cpp/llama/llama_triton_example.cc @@ -114,14 +114,14 @@ broadCastRequest(const std::vector& v_start_ids, } else { // conditional case. - ft::deviceMalloc(&d_input_ids, size_1, false); + ft::deviceMalloc(&d_input_ids, size_1, nullptr, false); // ft::deviceMalloc(&d_input_lengths, size_2, false); ft::cudaH2Dcpy(d_input_ids, v_input_ids.data(), size_1); // ft::cudaH2Dcpy(d_input_lengths, v_input_lengths.data(), size_2); } if (!v_input_bad_words.empty()) { - ft::deviceMalloc(&d_input_bad_words, size_bad_words, false); + ft::deviceMalloc(&d_input_bad_words, size_bad_words, nullptr, false); ft::cudaH2Dcpy(d_input_bad_words, v_input_bad_words.data(), size_bad_words); } else { diff --git a/lmdeploy/turbomind/deploy/config.py b/lmdeploy/turbomind/deploy/config.py index c724b085a..e483500e9 100644 --- a/lmdeploy/turbomind/deploy/config.py +++ b/lmdeploy/turbomind/deploy/config.py @@ -2,6 +2,7 @@ import inspect import json from dataclasses import asdict, fields +from typing import List # use pydantic.dataclasses.dataclass to check data type from pydantic.dataclasses import dataclass @@ -43,22 +44,33 @@ class ModelConfig: # of token_embedding embedding_size: int = 0 num_layer: int = None - inter_size: int = None + inter_size: List[int] = None norm_eps: float = None attn_bias: int = 0 start_id: int = None end_id: int = None size_per_head: int = 128 - group_size: int = 0 + group_size: int = 64 weight_type: str = None session_len: int = None tp: int = 1 model_format: str = 'hf' - expert_num: int = 0 + expert_num: List[int] = () expert_inter_size: int = 0 experts_per_token: int = 0 - moe_shared_gate: int = False - moe_norm_topk: int = False + moe_shared_gate: bool = False + norm_topk_prob: bool = False + routed_scale: float = 1.0 + topk_group: int = 1 + topk_method: str = 'greedy' + moe_group_num: int = 1 + # MLA + q_lora_rank: int = 0 + kv_lora_rank: int = 0 + qk_rope_dim: int = 0 + v_head_dim: int = 0 + # tuning + tune_layer_num: int = 1 def verify(self): invalid = {} @@ -72,6 +84,7 @@ def verify(self): class AttentionConfig: rotary_embedding: int = 128 rope_theta: float = 10000.0 + softmax_scale: float = 0 attention_factor: float = None max_position_embeddings: int = 0 original_max_position_embeddings: int = 0 diff --git a/lmdeploy/turbomind/deploy/converter.py b/lmdeploy/turbomind/deploy/converter.py index 1c847ede0..77f0bc8dc 100644 --- a/lmdeploy/turbomind/deploy/converter.py +++ b/lmdeploy/turbomind/deploy/converter.py @@ -241,11 +241,10 @@ def get_tm_model(model_path, engine_config.model_format = quant_method group_size = _group_size - # Compatible to awq models that are quantized by lmdeploy (<=v0.3.0) - if not group_size: - group_size = 128 - if engine_config.model_format in ['awq', 'gptq']: + # Compatible to awq models that are quantized by lmdeploy (<=v0.3.0) + if not group_size: + group_size = 128 assert group_size == 128, \ f'model format is "{engine_config.model_format}" ' \ f'but group_size is {group_size}. Currently, only 128 ' \ diff --git a/lmdeploy/turbomind/deploy/loader.py b/lmdeploy/turbomind/deploy/loader.py index e3d79b164..94e779b6b 100644 --- a/lmdeploy/turbomind/deploy/loader.py +++ b/lmdeploy/turbomind/deploy/loader.py @@ -88,6 +88,27 @@ def items(self): yield (-1, {k: f.get_tensor(k) for k in misc}) assert not params + # def items(self): + # params = defaultdict(dict) + # for shard in self.shards: + # # with safe_open(shard, 'pt') as f: + # with open(shard, 'rb') as f: + # w = safetensors.torch.load(f.read()) + # misc = [] + # for k in w.keys(): + # match = re.findall(self.pattern, k) + # if not match: + # misc.append(k) + # else: + # idx = int(match[0]) + # param = params[idx] + # param[k] = w[k] + # if len(param) == self.item_count[idx]: + # yield (idx, params.pop(idx)) + # if misc: + # yield (-1, {k: w[k] for k in misc}) + # assert not params + class PytorchLoader(BaseLoader): diff --git a/lmdeploy/turbomind/deploy/module.py b/lmdeploy/turbomind/deploy/module.py index 8d998abe2..52497175e 100644 --- a/lmdeploy/turbomind/deploy/module.py +++ b/lmdeploy/turbomind/deploy/module.py @@ -96,10 +96,13 @@ class Ffn(Module): def __init__(self, model: BaseOutputModel): self.model = model self.tp = model.tensor_para_size + # inter_sizes in config are padded and may be different from what's + # in the weights self.inter_size = model.model_config.inter_size self.group_size = max(1, model.model_config.group_size) def _export(self, + inter_size: int, fmt: str, idx: int, w123, @@ -110,11 +113,11 @@ def _export(self, w1, w2, w3 = map(transpose, w123) if not is_lora_a: - w1 = pad_out_dims(w1, self.inter_size) - w3 = pad_out_dims(w3, self.inter_size) + w1 = pad_out_dims(w1, inter_size) + w3 = pad_out_dims(w3, inter_size) if not is_lora_b: group_size = self.group_size if apply_gs else 1 - w2 = pad_in_dims(w2, self.inter_size // group_size) + w2 = pad_in_dims(w2, inter_size // group_size) w1, w2, w3 = map(pack_fn, (w1, w2, w3)) self.model.save_split(w1, @@ -132,7 +135,8 @@ def _export(self, def apply(self, i: int, r: BaseReader): for e in get_params(r.ffn(i, None)): - e(partial(self._export, self._ffn), partial(r.ffn, i), i) + e(partial(self._export, self.inter_size[i], self._ffn), + partial(r.ffn, i), i) class MoeFfn(Ffn): @@ -154,11 +158,13 @@ def __init__(self, model: BaseOutputModel): self.shared_gate = model.model_config.moe_shared_gate def apply(self, i: int, r: BaseReader): + if self.expert_num[i] == 0: + return for p in get_params(r.moe_ffn_expert()): - for e in range(self.expert_num): + for e in range(self.expert_num[i]): fmt = self._moe_ffn_expert.replace('E', str(e)) - p(partial(self._export, fmt), partial(r.moe_ffn_expert, e, i), - i) + p(partial(self._export, self.inter_size, fmt), + partial(r.moe_ffn_expert, e, i), i) gate = transpose(r.moe_ffn_gate(i)) self.model.save_split(gate, self._moe_ffn_gate.format(i)) @@ -218,6 +224,62 @@ def apply(self, i: int, r: BaseReader): e(self._export, partial(r.attn, i), i) +class MLA(Module): + """ + requires: + r.mla(i, kind) + r.mla_norm(i) + """ + + _mla = 'layers.{0}.attention.{1}.{2}' + + def __init__(self, model: BaseOutputModel): + self.model = model + + def _export(self, idx: int, xs, kind: str, pack_fn, **kwargs): + if all(x is None for x in xs): + return + q_a, q_b, q, kv_a, kv_b, o = map(transpose, xs) + + if q is not None: + q_b = q + + cfg = self.model.model_config + + o = o.reshape(cfg.head_num, cfg.v_head_dim, -1) + o = torch.nn.functional.pad( + o, (0, 0, 0, cfg.size_per_head - cfg.v_head_dim, 0, 0)) + o = o.view(cfg.head_num * cfg.size_per_head, cfg.hidden_units) + + if q_a is not None: + self.model.save_split(pack_fn(q_a), + self._mla.format(idx, 'q_a_proj', kind)) + q_b_name = 'q_proj' if q_a is None else 'q_b_proj' + self.model.save_split(pack_fn(q_b), + self._mla.format(idx, q_b_name, kind), + split_dim=-1) + self.model.save_split(pack_fn(kv_a), + self._mla.format(idx, 'kv_a_proj', kind)) + self.model.save_split(pack_fn(kv_b), + self._mla.format(idx, 'kv_b_proj', kind), + split_dim=-1) + self.model.save_split(pack_fn(o), + self._mla.format(idx, 'wo', kind), + split_dim=0) + + _layernorm = 'layers.{0}.attention.{1}_a_layernorm' + + def apply(self, i: int, r: BaseReader): + + for f in get_params(r.attn(i, None), bias=False): + f(self._export, partial(r.mla, i), i) + + q, k = r.mla_norm(i) + if q is not None: + self.model.save_split(q, self._layernorm.format(i, 'q')) + self.model.save_split(k, self._layernorm.format(i, 'kv')) + + class Misc(Module): """ requires: @@ -258,7 +320,11 @@ class Transformer: def __init__(self, model: BaseOutputModel): self.model = model - modules = [Attn, LayerNorm] + modules = [LayerNorm] + if model.model_config.kv_lora_rank: + modules.append(MLA) + else: + modules.append(Attn) if model.model_config.inter_size: modules.append(Ffn) if model.model_config.expert_num: diff --git a/lmdeploy/turbomind/deploy/source_model/__init__.py b/lmdeploy/turbomind/deploy/source_model/__init__.py index de16bdc0a..b9394b124 100644 --- a/lmdeploy/turbomind/deploy/source_model/__init__.py +++ b/lmdeploy/turbomind/deploy/source_model/__init__.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from .baichuan import Baichuan2Model, BaichuanModel # noqa: F401 +from .deepseek2 import DeepSeek2Model # noqa: F401 from .deepseek_vl import DeepSeekVLModel # noqa: F401 from .glm4 import Glm4Model # noqa: F401 from .internlm2 import InternLM2Model # noqa: F401 diff --git a/lmdeploy/turbomind/deploy/source_model/deepseek2.py b/lmdeploy/turbomind/deploy/source_model/deepseek2.py new file mode 100644 index 000000000..0023f650f --- /dev/null +++ b/lmdeploy/turbomind/deploy/source_model/deepseek2.py @@ -0,0 +1,134 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +from .base import INPUT_MODELS +from .llama import LlamaModel, LlamaReader + + +class DeepSeek2Reader(LlamaReader): + + def moe_ffn_gate(self, i): + return self.params.get(f'model.layers.{i}.mlp.gate.weight') + + def moe_ffn_expert(self, e=None, i=None, kind=None): + if not kind: + return self.filter(r'experts') + result = [] + for key in ['gate', 'down', 'up']: + name = f'model.layers.{i}.mlp.experts.{e}.{key}_proj.{kind}' + tensor = self.params.get(name) + tensor = self.transform(tensor, kind) + result.append(tensor) + return (*result, ) + + def _ffn(self, i: int, kind: str): + """Get ffn kind for layer i.""" + if not kind: + return self.filter(r'mlp' if i == 0 else r'shared_expert\.') + result = [] + for key in ['gate', 'down', 'up']: + name = f'model.layers.{i}.mlp.shared_experts.{key}_proj.{kind}' + if i == 0: + name = name.replace('shared_experts.', '') + tensor = self.params.get(name) + tensor = self.transform(tensor, kind) + result.append(tensor) + return (*result, ) + + def mla(self, i: int, kind: str): + if not kind: + return self.filter(r'self_attn.*proj') + result = [] + for key in [ + 'q_a_proj', 'q_b_proj', 'q_proj', 'kv_a_proj_with_mqa', + 'kv_b_proj', 'o_proj' + ]: + tensor = self.params.get( + f'{self.attn_layer_prefix}.{i}.self_attn.{key}.{kind}') + tensor = self.transform(tensor, kind) + result.append(tensor) + return (*result, ) + + def mla_norm(self, i: int): + result = [] + for k in ['q', 'kv']: + name = f'{self.attn_layer_prefix}.{i}.self_attn.{k}_a_layernorm.weight' # noqa: E501 + result.append(self.params.get(name)) + return (*result, ) + + +def get_yarn_params(rope_scaling: dict): + + scaling_factor = float(rope_scaling['factor']) + mscale = rope_scaling['mscale'] + mscale_all_dim = rope_scaling['mscale_all_dim'] + + def yarn_get_mscale(scale=1, mscale=1): + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + _mscale = float( + yarn_get_mscale(scaling_factor, mscale) / + yarn_get_mscale(scaling_factor, mscale_all_dim)) + + softmax_scale = 0 + if mscale_all_dim: + scale = yarn_get_mscale(scaling_factor, mscale_all_dim) + softmax_scale = scale * scale + + return _mscale, softmax_scale + + +@INPUT_MODELS.register_module(name='deepseek2') +class DeepSeek2Model(LlamaModel): + + Reader = DeepSeek2Reader + + def tokenizer_info(self): + n_words = self.model_config['vocab_size'] + bos_id = self.model_config['bos_token_id'] + eos_id = self.model_config['eos_token_id'] + return n_words, bos_id, eos_id + + def model_info(self): + cfg = self.model_config + info = super().model_info() + qk_nope_dim = cfg['qk_nope_head_dim'] + qk_rope_dim = cfg['qk_rope_head_dim'] + num_layer = cfg['num_hidden_layers'] + expert_num = cfg['n_routed_experts'] + expert_num = [expert_num] * num_layer + expert_num[0] = 0 + n_shared_experts = cfg['n_shared_experts'] + expert_inter_size = cfg['moe_intermediate_size'] + experts_per_token = cfg['num_experts_per_tok'] + inter_size = [n_shared_experts * expert_inter_size] * num_layer + inter_size[0] = cfg['intermediate_size'] + norm_topk_prob = cfg['norm_topk_prob'] + size_per_head = qk_rope_dim + qk_nope_dim + info.update(kv_lora_rank=cfg['kv_lora_rank'], + q_lora_rank=cfg['q_lora_rank'] or 0, + qk_rope_dim=qk_rope_dim, + v_head_dim=cfg['v_head_dim'], + size_per_head=size_per_head, + rotary_embedding=qk_rope_dim, + expert_num=expert_num, + expert_inter_size=expert_inter_size, + experts_per_token=experts_per_token, + inter_size=inter_size, + norm_topk_prob=norm_topk_prob, + routed_scale=cfg['routed_scaling_factor'], + topk_method=cfg['topk_method'], + topk_group=cfg['topk_group'], + moe_group_num=cfg['n_group'], + tune_layer_num=2) + rope_scaling = cfg.get('rope_scaling') + if rope_scaling and rope_scaling['type'] == 'yarn': + attention_factor, softmax_scale = get_yarn_params(rope_scaling) + softmax_scale *= size_per_head**(-0.5) + info.update(max_position_embeddings=rope_scaling[ + 'original_max_position_embeddings'], + attention_factor=attention_factor, + softmax_scale=softmax_scale) + return info diff --git a/lmdeploy/turbomind/deploy/source_model/mixtral.py b/lmdeploy/turbomind/deploy/source_model/mixtral.py index ff9df2d40..6ac22a658 100644 --- a/lmdeploy/turbomind/deploy/source_model/mixtral.py +++ b/lmdeploy/turbomind/deploy/source_model/mixtral.py @@ -33,6 +33,6 @@ def model_info(self): info['expert_num'] = cfg['num_local_experts'] info['expert_inter_size'] = cfg['intermediate_size'] info['experts_per_token'] = cfg['num_experts_per_tok'] - info['moe_norm_topk'] = True + info['norm_topk_prob'] = True info['inter_size'] = 0 return info diff --git a/lmdeploy/turbomind/deploy/source_model/qwen.py b/lmdeploy/turbomind/deploy/source_model/qwen.py index 772bd0303..637983e8c 100644 --- a/lmdeploy/turbomind/deploy/source_model/qwen.py +++ b/lmdeploy/turbomind/deploy/source_model/qwen.py @@ -178,6 +178,6 @@ def model_info(self): info['experts_per_token'] = cfg['num_experts_per_tok'] info['inter_size'] = cfg['shared_expert_intermediate_size'] info['moe_shared_gate'] = True - info['moe_norm_topk_prob'] = cfg['norm_topk_prob'] + info['norm_topk_prob'] = cfg['norm_topk_prob'] info['attn_bias'] = 1 return info diff --git a/lmdeploy/turbomind/deploy/target_model/base.py b/lmdeploy/turbomind/deploy/target_model/base.py index 09699ade0..f2c981bb2 100644 --- a/lmdeploy/turbomind/deploy/target_model/base.py +++ b/lmdeploy/turbomind/deploy/target_model/base.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. import os.path as osp from abc import ABC +from collections.abc import Sequence import torch import tqdm @@ -65,13 +66,14 @@ def __init__(self, # get `model_info` and `tokenizer_info` at first, which # will be updated to `self.model_config` and `self.attention_config` self.input_model_info = self.input_model.model_info() + self.input_model_info = self.single_to_list( + self.input_model_info, keys=['inter_size', 'expert_num']) self.input_model_tokenizer_info = self.input_model.tokenizer_info() self.permute_qk = self.input_model_info.get('permute_qk', True) - self.update_model_config() - self.model_config.inter_size = _pad_inter_size( - self.model_config.inter_size, self.model_config.group_size, - self.tensor_para_size) + for i, v in enumerate(self.model_config.inter_size): + self.model_config.inter_size[i] = _pad_inter_size( + v, self.model_config.group_size, self.tensor_para_size) if self.model_config.expert_num: self.model_config.expert_inter_size = _pad_inter_size( self.model_config.expert_inter_size, @@ -79,11 +81,21 @@ def __init__(self, self.model_config.verify() assert self.model_config.kv_head_num % self.tensor_para_size == 0 + # print(self.model_config) + self.update_attention_config() self.update_lora_config() # ! Dependency on `self` self.model = model_cls(self) + def single_to_list(self, config: dict, keys): + num_layer = int(config['num_layer']) + for k in keys: + v = config.get(k, None) + if v is not None and not isinstance(v, Sequence): + config[k] = [v] * num_layer + return config + def update_model_config(self): """Update `self.model_config` according to the input_model's `tokenizer_info` and `model_info`""" diff --git a/lmdeploy/turbomind/supported_models.py b/lmdeploy/turbomind/supported_models.py index e66da22df..11e99edfa 100644 --- a/lmdeploy/turbomind/supported_models.py +++ b/lmdeploy/turbomind/supported_models.py @@ -33,6 +33,7 @@ InternVLChatModel='internvl', # deepseek-vl MultiModalityCausalLM='deepseekvl', + DeepseekV2ForCausalLM='deepseek2', # MiniCPMV MiniCPMV='minicpmv', # mini gemini diff --git a/src/turbomind/kernels/CMakeLists.txt b/src/turbomind/kernels/CMakeLists.txt index febb8692d..40a48402a 100644 --- a/src/turbomind/kernels/CMakeLists.txt +++ b/src/turbomind/kernels/CMakeLists.txt @@ -68,3 +68,4 @@ endif () add_subdirectory(attention) add_subdirectory(gemm) +add_subdirectory(norm) diff --git a/src/turbomind/kernels/attention/CMakeLists.txt b/src/turbomind/kernels/attention/CMakeLists.txt index af9d47e0e..32de38981 100644 --- a/src/turbomind/kernels/attention/CMakeLists.txt +++ b/src/turbomind/kernels/attention/CMakeLists.txt @@ -38,6 +38,8 @@ add_library(attention STATIC codegen/decoding_sm80_64_f16_f16.cu codegen/decoding_sm80_64_f16_u4.cu codegen/decoding_sm80_64_f16_u8.cu + codegen/attention_sm80_192.cu + codegen/decoding_sm80_192.cu ) set_property(TARGET attention PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET attention PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) diff --git a/src/turbomind/kernels/attention/attention.cu b/src/turbomind/kernels/attention/attention.cu index 3f557234b..e7642584c 100644 --- a/src/turbomind/kernels/attention/attention.cu +++ b/src/turbomind/kernels/attention/attention.cu @@ -46,6 +46,12 @@ void dispatchAttention(const AttentionParams& params) else if (params.size_per_head == 128) { return dispatch(std::integral_constant{}); } + + if (params.size_per_head == 192) { + using Config = AttentionConfig; + return invokeAttention(params); + } + FT_CHECK(0); } diff --git a/src/turbomind/kernels/attention/codegen/attention_sm80_192.cu b/src/turbomind/kernels/attention/codegen/attention_sm80_192.cu new file mode 100644 index 000000000..ceeafa7a6 --- /dev/null +++ b/src/turbomind/kernels/attention/codegen/attention_sm80_192.cu @@ -0,0 +1,16 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include "../attention_config.h" +#include "../attention_template.h" + +namespace turbomind { + +using namespace attention; + +template void invokeAttention::Kernel>( + const AttentionParams& params); + +template void invokeAttention::Kernel>( + const AttentionParams& params); + +} // namespace turbomind diff --git a/src/turbomind/kernels/attention/codegen/decoding_sm80_192.cu b/src/turbomind/kernels/attention/codegen/decoding_sm80_192.cu new file mode 100644 index 000000000..214e6748d --- /dev/null +++ b/src/turbomind/kernels/attention/codegen/decoding_sm80_192.cu @@ -0,0 +1,20 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include "../decoding_config.h" +#include "../decoding_template.h" + +namespace turbomind { + +using namespace attention; + +template bool +invokeDecoding>(const AttentionParams& params); + +template bool invokeDecoding>(const AttentionParams& params); + +template bool +invokeDecoding>(const AttentionParams& params); + +template bool invokeDecoding>(const AttentionParams& params); + +} // namespace turbomind diff --git a/src/turbomind/kernels/attention/decoding.cu b/src/turbomind/kernels/attention/decoding.cu index 1b04b7d4e..67bd81e45 100644 --- a/src/turbomind/kernels/attention/decoding.cu +++ b/src/turbomind/kernels/attention/decoding.cu @@ -2,8 +2,8 @@ #include "decoding.h" #include "decoding_config.h" +#include "src/turbomind/kernels/attention/arch.h" #include "src/turbomind/models/llama/llama_utils.h" -// #include "src/turbomind/utils/dispatch.h" #include #include @@ -113,6 +113,21 @@ void dispatchDecoding(const AttentionParams& params) return false; }; + if (params.size_per_head == 192) { + + if (is_kv_int8) { + invokeDecoding>(params); + } + else if (is_kv_int4) { + FT_CHECK_WITH_INFO(!is_kv_int4, "not implemented"); + // invokeDecoding>(params); + } + else { + invokeDecoding>(params); + } + return; + } + auto success = dispatch(); FT_CHECK(success); diff --git a/src/turbomind/kernels/attention/decoding_config.h b/src/turbomind/kernels/attention/decoding_config.h index 7dcb119cf..dfd5e0783 100644 --- a/src/turbomind/kernels/attention/decoding_config.h +++ b/src/turbomind/kernels/attention/decoding_config.h @@ -40,7 +40,7 @@ struct DecodingConfig 2) }; template -struct DecodingConfig { +struct DecodingConfig> { static constexpr int Qh = (Qh_ + 7) / 8 * 8; using Attention = Impl; using CacheIter = GetBlockIterFactory; @@ -76,4 +76,14 @@ struct DecodingConfig { using Kernel = AttentionUniversal, CacheIter, DecodingCtaMap>; }; +template +struct DecodingConfig { + static constexpr int Qh = 1; + static constexpr int HeadDim = 192; + + using Attention = Impl; + using CacheIter = GetBlockIterFactory; + using Kernel = AttentionUniversal, Attention>, CacheIter, DecodingCtaMap>; +}; + } // namespace turbomind::attention diff --git a/src/turbomind/kernels/attention/impl_16816.h b/src/turbomind/kernels/attention/impl_16816.h index 6e8f37f4d..07c7dcb12 100644 --- a/src/turbomind/kernels/attention/impl_16816.h +++ b/src/turbomind/kernels/attention/impl_16816.h @@ -63,26 +63,28 @@ struct Impl>, SmemLayoutV2>>; - using SmemLayoutK = std::conditional_t>, SmemLayoutV2>>; - using SmemLayoutV = std::conditional_t>, SmemLayoutV2>>; using SmemLayoutKVp = void; + static constexpr bool kUseSmemQ = false; + static constexpr bool kUseSmemP = false; + + static_assert(!kUseSmemQ, "current smemQ impl yields inconsistent outputs"); + union SharedStorage { __align__(16) T KV[Stages * (SmemLayoutK::kSize + SmemLayoutV::kSize) / 2]; __align__(16) T Q[SmemLayoutQ::kSize]; }; - static constexpr bool kUseSmemQ = false; - static constexpr bool kUseSmemP = false; - using ThreadMapQ = RakedThreadMap; using ThreadMapKV = RakedThreadMap; @@ -109,22 +111,24 @@ struct Impl sQ{smem_Q}; + SmemAccessor sQ{smem_Q}; - // Load from shared memory using LDSM, rearrange to m16n8k16 atom layout - PRAGMA_UNROLL - for (int m = 0; m < K_M; ++m) { + // Load from shared memory using LDSM, rearrange to m16n8k16 atom layout PRAGMA_UNROLL - for (int k = 0; k < K_K; ++k) { - const int qi = lane_id % 16 * 1 + m * 16 + warp_id * WARP_Q; - const int di = lane_id / 16 * 8 + k * 16; - ldsm_x4((Array&)frag_Q[k][m], cast_smem_ptr_to_uint(&sQ(qi, di))); + for (int m = 0; m < K_M; ++m) { + PRAGMA_UNROLL + for (int k = 0; k < K_K; ++k) { + const int qi = lane_id % 16 * 1 + m * 16 + warp_id * WARP_Q; + const int di = lane_id / 16 * 8 + k * 16; + ldsm_x4((Array&)frag_Q[k][m], cast_smem_ptr_to_uint(&sQ(qi, di))); + } } } - if constexpr (kUseSmemQ) { + if constexpr (0) { __syncthreads(); // Rearrange Q in smem so that swizzling is not needed for later LDSMs @@ -142,20 +146,25 @@ struct Impl smem_K; + T* smem_Q; FragQ frag_Q; FragK frag_K; __device__ StateQK(SharedStorage& storage, FragQ frag_Q_): smem_K{storage.KV} { - static_assert(!kUseSmemQ, "not implemented"); - PRAGMA_UNROLL - for (int k = 0; k < K_K; ++k) { + if constexpr (!kUseSmemQ) { PRAGMA_UNROLL - for (int m = 0; m < K_M; ++m) { - frag_Q[k][m] = frag_Q_[k][m]; + for (int k = 0; k < K_K; ++k) { + PRAGMA_UNROLL + for (int m = 0; m < K_M; ++m) { + frag_Q[k][m] = frag_Q_[k][m]; + } } } + else { + smem_Q = storage.Q; + } } __device__ void Load(int k, int pipe_iter) @@ -166,6 +175,16 @@ struct Impl sQ{smem_Q}; + PRAGMA_UNROLL + for (int m = 0; m < K_M; ++m) { + const int qi = lane_id % 16 * 1 + m * 16 + warp_id * WARP_Q; + const int di = lane_id / 16 * 8 + k * 16; + ldsm_x4((Array&)frag_Q[k][m], cast_smem_ptr_to_uint(&sQ(qi, di))); + } + } PRAGMA_UNROLL for (int n = 0; n < K_N; n += 2) { // Load (s16,d16) tiles const int s = n * 8 + offset_s; diff --git a/src/turbomind/kernels/attention/impl_81616.h b/src/turbomind/kernels/attention/impl_81616.h index 3b90bcdf5..f865f1bc3 100644 --- a/src/turbomind/kernels/attention/impl_81616.h +++ b/src/turbomind/kernels/attention/impl_81616.h @@ -104,7 +104,7 @@ struct Impl) { - return std::conditional_t>, SmemLayoutV2>>{}; } diff --git a/src/turbomind/kernels/attention/impl_simt.h b/src/turbomind/kernels/attention/impl_simt.h index a886185a4..444b67e2c 100644 --- a/src/turbomind/kernels/attention/impl_simt.h +++ b/src/turbomind/kernels/attention/impl_simt.h @@ -2,12 +2,16 @@ #pragma once -#include "src/turbomind/kernels/attention/impl.h" +#include +#include +#include + #include "src/turbomind/kernels/core/array_ops.h" #include "src/turbomind/kernels/core/layout.h" #include "src/turbomind/kernels/core/thread_map.h" -#include -#include + +#include "src/turbomind/kernels/attention/impl.h" +#include "src/turbomind/kernels/attention/quantization.h" namespace turbomind::attention { @@ -51,7 +55,7 @@ struct Impl), K_K); }; struct LinearD { diff --git a/src/turbomind/kernels/attention/kv_cache_utils_v2.cu b/src/turbomind/kernels/attention/kv_cache_utils_v2.cu index 20bb00fde..f2e2faef9 100644 --- a/src/turbomind/kernels/attention/kv_cache_utils_v2.cu +++ b/src/turbomind/kernels/attention/kv_cache_utils_v2.cu @@ -277,11 +277,14 @@ void invokeProcessKV_v2(char** blocks, }; auto dispatch = [&](auto tkv) { - if (head_dim == 128) { + if (head_dim == 64) { + return invoke(tkv, std::integral_constant{}); + } + else if (head_dim == 128) { return invoke(tkv, std::integral_constant{}); } - else if (head_dim == 64) { - return invoke(tkv, std::integral_constant{}); + else if (head_dim == 192) { + return invoke(tkv, std::integral_constant{}); } FT_CHECK(0); }; @@ -545,6 +548,9 @@ void invokeFlattenKV_v2(T* k, else if (head_dim == 128) { return invoke(tkv, std::integral_constant{}); } + else if (head_dim == 192) { + return invoke(tkv, std::integral_constant{}); + } FT_CHECK(0); }; diff --git a/src/turbomind/kernels/attention/mainloop_sm80.h b/src/turbomind/kernels/attention/mainloop_sm80.h index bf0fc1d32..4435400b7 100644 --- a/src/turbomind/kernels/attention/mainloop_sm80.h +++ b/src/turbomind/kernels/attention/mainloop_sm80.h @@ -52,7 +52,7 @@ struct Mainloop, Impl_> { template __device__ void operator()(Args&&... args) { - Run(Sm80_CpAsync{}, ((Args &&) args)...); + Run(Sm80_CpAsync{}, std::integral_constant{}, ((Args &&) args)...); } template @@ -81,8 +81,9 @@ struct Mainloop, Impl_> { } } - template + template __device__ void Run(Sm80_CpAsync, + std::integral_constant, FragQ& frag_Q, CacheIter& cache_iter, FragO& frag_O, @@ -199,9 +200,10 @@ struct Mainloop, Impl_> { __pipeline_wait_prior(0); } -#if 0 + // #if 1 template __device__ void Run(Sm80_CpAsync<2>, + std::integral_constant, FragQ& frag_Q, CacheIter& cache_iter, FragO& frag_O, @@ -234,7 +236,7 @@ struct Mainloop, Impl_> { Wait(); state_QK.Load(0, 0); - constexpr auto _ = [](int){}; + constexpr auto _ = [](int) {}; auto loop = [&](auto is_residue, auto is_mask) { const int offset_K = tile_iter * CTA_S; @@ -292,14 +294,15 @@ struct Mainloop, Impl_> { __pipeline_wait_prior(0); } -#elif 1 + // #elif 1 // Load : K0,K1 | V0,K2,V1,K3 ... // Compute : K0 | K1,V0,K2,V1 ... // - more register consumption // - more interleaved HMMA and FMA // - slight performance gain - template + template __device__ void Run(Sm80_CpAsync<2>, + std::integral_constant, FragQ& frag_Q, CacheIter& cache_iter_, FragO& frag_O, @@ -407,7 +410,7 @@ struct Mainloop, Impl_> { __pipeline_commit(); __pipeline_wait_prior(0); } -#endif + // #endif __device__ void Wait() { diff --git a/src/turbomind/kernels/attention/reduce.cu b/src/turbomind/kernels/attention/reduce.cu index 12f6aff38..c654f40d0 100644 --- a/src/turbomind/kernels/attention/reduce.cu +++ b/src/turbomind/kernels/attention/reduce.cu @@ -66,12 +66,14 @@ void invokeReduce(T* out, float exp_scale, \ cudaStream_t stream); -INSTANTIATE_invokeReduce(128, half); INSTANTIATE_invokeReduce(64, half); +INSTANTIATE_invokeReduce(128, half); +INSTANTIATE_invokeReduce(192, half); #if ENABLE_BF16 +INSTANTIATE_invokeReduce(64, nv_bfloat16); INSTANTIATE_invokeReduce(128, nv_bfloat16); -INSTANTIATE_invokeReduce(64, nv_bfloat16) +INSTANTIATE_invokeReduce(192, nv_bfloat16); #endif } // namespace turbomind::attention diff --git a/src/turbomind/kernels/attention/reduce_kernel.h b/src/turbomind/kernels/attention/reduce_kernel.h index 88a3ab3af..b4c9064cf 100644 --- a/src/turbomind/kernels/attention/reduce_kernel.h +++ b/src/turbomind/kernels/attention/reduce_kernel.h @@ -128,9 +128,12 @@ struct Reduce { __syncthreads(); - constexpr int kVecSize = HeadDim / WARP_SIZE; + // HeadDim / WARP_SIZE + // 128 -> 4 + // 64, 192 -> 2 + constexpr int kVecSize = HeadDim % 128 == 0 ? 4 : 2; - using Map = RakedThreadMap; + using Map = RakedThreadMap; static_assert(Map::kIterS == CTA_H); diff --git a/src/turbomind/kernels/attention/rotary_embedding.h b/src/turbomind/kernels/attention/rotary_embedding.h index 8e09da22c..db836ed18 100644 --- a/src/turbomind/kernels/attention/rotary_embedding.h +++ b/src/turbomind/kernels/attention/rotary_embedding.h @@ -131,6 +131,7 @@ struct FastRoPE { template __device__ void apply(Array& x, float timestep) { +#if 0 PRAGMA_UNROLL for (int i = 0; i < N; i += 2) { float c, s; @@ -144,6 +145,22 @@ struct FastRoPE { x[i + 1] = (T)tmp1; } } +#else + // Most models apply rotary embedding in half precision + PRAGMA_UNROLL + for (int i = 0; i < N; i += 2) { + float c, s; + sincosf(timestep * inv_freq_[i / 2], &s, &c); + s *= attention_scaling_; + c *= attention_scaling_; + T tmp0 = (T)c * x[i] - (T)s * x[i + 1]; + T tmp1 = (T)c * x[i + 1] + (T)s * x[i]; + if (is_valid_) { + x[i] = tmp0; + x[i + 1] = tmp1; + } + } +#endif } }; diff --git a/src/turbomind/kernels/attention/test_attention.cu b/src/turbomind/kernels/attention/test_attention.cu index c6d7b4063..804d4815d 100644 --- a/src/turbomind/kernels/attention/test_attention.cu +++ b/src/turbomind/kernels/attention/test_attention.cu @@ -218,14 +218,14 @@ void TestBlocks(const thrust::universal_vector& k_cache, // [B, H, S, #define KV_INT4 0 -#define DECODING 1 +#define DECODING 0 template int test_attention() { AttentionParams params{}; - constexpr size_t kHeadDim = 128; + constexpr size_t kHeadDim = 192; #if DECODING // constexpr size_t kHeadNum = 32; @@ -239,11 +239,11 @@ int test_attention() // constexpr size_t kSequenceLen = 511; // constexpr size_t kSequenceLen = 2047; // constexpr size_t kSequenceLen = 4095; - // constexpr size_t kSequenceLen = 8191; + constexpr size_t kSequenceLen = 8191; // constexpr size_t kSequenceLen = 32767; // constexpr size_t kSequenceLen = 65535; // constexpr size_t kSequenceLen = 131071; - constexpr size_t kSequenceLen = 200000; + // constexpr size_t kSequenceLen = 200000; // constexpr size_t kSequenceLen = 262143; // constexpr size_t kSequenceLen = (1 << 20) - 1; // 1M // constexpr size_t kSequenceLen = (1 << 22) - 1; // 4M @@ -451,6 +451,10 @@ int test_attention() params.qk = qk_buf.data().get(); params.pr = pr_buf.data().get(); + params.attention_scaling = 1.f; + params.llama3_inv_scaling_factor = 0; + params.yarn_ramp_inv_factor_div_2 = 0; + Reference reference(kDump ? Reference::kUNFUSED : Reference::kFLASH_ATTENTION, {}); // Reference reference(Reference::kUNFUSED, {}); reference.Reshape(kInputLen, kContextLen, kHeadNum, kHeadDim, KvHeadNum, kBatchSize); diff --git a/src/turbomind/kernels/core/array_ops.h b/src/turbomind/kernels/core/array_ops.h index 6b639abc8..ec6e7fb4e 100644 --- a/src/turbomind/kernels/core/array_ops.h +++ b/src/turbomind/kernels/core/array_ops.h @@ -172,7 +172,7 @@ inline __device__ void copy(const Array (&src)[M], Array (&dst)[M]) } template -inline __device__ void Store(T* __restrict__ dst, const Array& src) +inline __device__ void Store(T* dst, const Array& src) { if constexpr (sizeof(Array) == sizeof(uint4)) { *(uint4*)dst = (const uint4&)src; diff --git a/src/turbomind/kernels/core/math.h b/src/turbomind/kernels/core/math.h index a708a3498..054269c27 100644 --- a/src/turbomind/kernels/core/math.h +++ b/src/turbomind/kernels/core/math.h @@ -5,6 +5,7 @@ #include "src/turbomind/kernels/core/common.h" #include #include +#include namespace turbomind { @@ -41,6 +42,13 @@ TM_HOST_DEVICE constexpr T log2(T x) // static_assert(log2(32) == 5); // static_assert(log2(1) == 0); +template +TM_HOST_DEVICE constexpr T lowbit(T x) +{ + const std::make_signed_t s = x; + return static_cast(s & -s); +} + // https://arxiv.org/abs/1902.01961 template struct FastDivMod { diff --git a/src/turbomind/kernels/core/thread_map.h b/src/turbomind/kernels/core/thread_map.h index 66b691832..1271aefcc 100644 --- a/src/turbomind/kernels/core/thread_map.h +++ b/src/turbomind/kernels/core/thread_map.h @@ -3,6 +3,7 @@ #pragma once #include "src/turbomind/kernels/core/common.h" +#include "src/turbomind/kernels/core/math.h" #include @@ -51,7 +52,7 @@ struct ThreadMapQ { } }; -template +template struct RakedThreadMap { static constexpr int kDimC = DimC; static constexpr int kDimS = DimS; diff --git a/src/turbomind/kernels/flash_attention/flash_attention2/CMakeLists.txt b/src/turbomind/kernels/flash_attention/flash_attention2/CMakeLists.txt index d41c391e9..81c975058 100644 --- a/src/turbomind/kernels/flash_attention/flash_attention2/CMakeLists.txt +++ b/src/turbomind/kernels/flash_attention/flash_attention2/CMakeLists.txt @@ -8,9 +8,11 @@ add_library(${PROJECT_NAME} STATIC # flash_fwd_hdim64_fp16_sm80.cu flash_fwd_hdim128_fp16_sm80.cu flash_fwd_hdim128_bf16_sm80.cu - # flash_fwd_hdim256_fp16_sm80.cu + flash_fwd_hdim256_bf16_sm80.cu + flash_fwd_hdim256_fp16_sm80.cu ) target_include_directories(${PROJECT_NAME} PRIVATE ${CUTLASS_DIR} / include) target_link_libraries(${PROJECT_NAME} PRIVATE nvidia::cutlass::cutlass) + set_property(TARGET ${PROJECT_NAME} PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET ${PROJECT_NAME} PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) diff --git a/src/turbomind/kernels/flash_attention/flash_attention2/flash_fwd_launch_template.h b/src/turbomind/kernels/flash_attention/flash_attention2/flash_fwd_launch_template.h index e108a55f2..245649636 100644 --- a/src/turbomind/kernels/flash_attention/flash_attention2/flash_fwd_launch_template.h +++ b/src/turbomind/kernels/flash_attention/flash_attention2/flash_fwd_launch_template.h @@ -147,7 +147,7 @@ void run_mha_fwd_hdim128(Flash_fwd_params& params, cudaStream_t stream) }); } -#if 0 +#if 1 template void run_mha_fwd_hdim256(Flash_fwd_params& params, cudaStream_t stream) { diff --git a/src/turbomind/kernels/flash_attention/flash_attention2/static_switch.h b/src/turbomind/kernels/flash_attention/flash_attention2/static_switch.h index fd19a0ea6..b1df29cb7 100644 --- a/src/turbomind/kernels/flash_attention/flash_attention2/static_switch.h +++ b/src/turbomind/kernels/flash_attention/flash_attention2/static_switch.h @@ -58,6 +58,18 @@ return __VA_ARGS__(); \ } \ }() +#elif 1 +#define FWD_HEADDIM_SWITCH(HEADDIM, ...) \ + [&] { \ + if (HEADDIM <= 128) { \ + constexpr static int kHeadDim = 128; \ + return __VA_ARGS__(); \ + } \ + else if (HEADDIM <= 256) { \ + constexpr static int kHeadDim = 256; \ + return __VA_ARGS__(); \ + } \ + }() #else #define FWD_HEADDIM_SWITCH(HEADDIM, ...) \ [&] { \ diff --git a/src/turbomind/kernels/gemm/context.h b/src/turbomind/kernels/gemm/context.h index 4fec5b732..bd03917b8 100644 --- a/src/turbomind/kernels/gemm/context.h +++ b/src/turbomind/kernels/gemm/context.h @@ -113,12 +113,7 @@ class DynamicGemmContext: public StaticGemmContext { class MoeGemmContext: public Context { public: - MoeGemmContext(int experts, - int experts_per_token, - // int output_dims, - // int input_dims, - const cudaDeviceProp& prop, - cudaStream_t stream); + MoeGemmContext(int experts, int experts_per_token, const cudaDeviceProp& prop, cudaStream_t stream); ~MoeGemmContext() override; @@ -156,9 +151,11 @@ class MoeGemmContext: public Context { Tape Schedule(const LaunchSpec&) override; - void set_offsets(const int* offsets) + void update(int expert_num, int experts_per_token, const int* offsets) { - offsets_ = offsets; + expert_num_ = expert_num; + experts_per_token_ = experts_per_token; + offsets_ = offsets; } protected: diff --git a/src/turbomind/kernels/gemm/convert_v2.cu b/src/turbomind/kernels/gemm/convert_v2.cu index ed8b2ee2f..e58bfc9b9 100644 --- a/src/turbomind/kernels/gemm/convert_v2.cu +++ b/src/turbomind/kernels/gemm/convert_v2.cu @@ -279,17 +279,44 @@ get_weight_and_scales_layout(DataType dtype, bool is_fused_moe, int sm, bool for return {}; } -void* make_blocked_ptrs(const std::vector>& ptrs, cudaStream_t stream) +namespace { + +template +struct Param { + StridedPtr data[N]; + StridedPtr* ptr; + int n; +}; + +template +__global__ void fill_strided_ptrs(Param param) { - std::vector tmp; - for (const auto& [p, s] : ptrs) { - tmp.push_back({p, s}); + const int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx < param.n) { + param.ptr[idx] = param.data[idx]; } +} + +} // namespace + +void* make_blocked_ptrs(const std::vector>& ptrs, cudaStream_t stream) +{ + constexpr int N = 64; + Param param{}; + static_assert(sizeof(param) <= 4096); // max parameter size for cuda11 StridedPtr* ptr{}; cudaMallocAsync(&ptr, sizeof(StridedPtr) * ptrs.size(), stream); - cudaMemcpyAsync(ptr, tmp.data(), sizeof(StridedPtr) * ptrs.size(), cudaMemcpyDefault, stream); - // Sync before tmp can be destructed - cudaStreamSynchronize(stream); + param.ptr = ptr; + for (int i = 0; i < (int)ptrs.size(); i += N) { + const int n = std::min(ptrs.size() - i, N); + for (int j = 0; j < n; ++j) { + auto& [p, s] = ptrs[i + j]; + param.data[j] = StridedPtr{p, s}; + } + param.n = n; + fill_strided_ptrs<<<1, N, 0, stream>>>(param); + param.ptr += N; + } return ptr; } diff --git a/src/turbomind/kernels/gemm/moe_utils_v2.cu b/src/turbomind/kernels/gemm/moe_utils_v2.cu index 5912c60a8..a9e4f7da5 100644 --- a/src/turbomind/kernels/gemm/moe_utils_v2.cu +++ b/src/turbomind/kernels/gemm/moe_utils_v2.cu @@ -264,7 +264,8 @@ __global__ void MoeGateKernel_v8(float* scales, // [e,n] int token_num_padded, int expert_num, int top_k, - bool norm_topk) + bool norm_topk, + float routed_scale) { constexpr int max_tiles = kMoeGateMaxTiles; constexpr int threads_per_token = max_expert_num / items_per_thread; // 8 @@ -286,8 +287,8 @@ __global__ void MoeGateKernel_v8(float* scales, // [e,n] const int warp_ti = threadIdx.x % WARP_SIZE / threads_per_token; - const int warp_offset = thread_idx / WARP_SIZE * WARP_SIZE / threads_per_token; - const int block_offset = thread_idx / block_dim * block_dim / threads_per_token; + // const int warp_offset = thread_idx / WARP_SIZE * WARP_SIZE / threads_per_token; + // const int block_offset = thread_idx / block_dim * block_dim / threads_per_token; float data[items_per_thread]; int idxs[items_per_thread]; @@ -413,7 +414,13 @@ __global__ void MoeGateKernel_v8(float* scales, // [e,n] #endif - constexpr float kLog2e = 1.4426950408889634074; + // constexpr float kLog2e = 1.4426950408889634074; + // if (k == 0) { + // PRAGMA_UNROLL + // for (int i = 0; i < items_per_thread; ++i) { + // data[i] *= kLog2e; + // } + // } unsigned mask = (unsigned)-1; float max_logit; @@ -437,13 +444,6 @@ __global__ void MoeGateKernel_v8(float* scales, // [e,n] asm("shl.b32 %0, %1, 1;\n" : "=r"(bit) : "r"(bit)); } - if (k == 0) { - PRAGMA_UNROLL - for (int i = 0; i < items_per_thread; ++i) { - data[i] *= kLog2e; - } - } - int g_max_ei = ei; float g_max_val = max_val; if constexpr (threads_per_token > 1) { @@ -486,7 +486,7 @@ __global__ void MoeGateKernel_v8(float* scales, // [e,n] PRAGMA_UNROLL for (int i = 0; i < items_per_thread; ++i) { if (!norm_topk || used[i]) { - data[i] = exp2f(data[i] - max_logit); + data[i] = expf(data[i] - max_logit); sum_prob += data[i]; } } @@ -515,9 +515,11 @@ __global__ void MoeGateKernel_v8(float* scales, // [e,n] PRAGMA_UNROLL for (int i = 0; i < max_tiles * max_expert_num; i += block_dim) { - int e = (i + threadIdx.x) % max_expert_num; - int t = (i + threadIdx.x) / max_expert_num; - smem.shared_accum[t][e] = 0; + int e = (i + threadIdx.x) % max_expert_num; + int t = (i + threadIdx.x) / max_expert_num; + if (t < max_tiles) { + smem.shared_accum[t][e] = 0; + } } __syncthreads(); @@ -536,10 +538,8 @@ __global__ void MoeGateKernel_v8(float* scales, // [e,n] if (ti2 < token_num && idx < top_k) { masks[expert_id * token_num_padded + ti2] = idx; - scales[idx * token_num + ti2] = scale; + scales[idx * token_num + ti2] = scale * routed_scale; atomicAdd(&smem.shared_accum[ti2 >> log_tile][expert_id], 1); - - // printf("%d %d %f\n", idx, expert_id, scale); } } @@ -569,6 +569,7 @@ void invokeMoeGate_V2(int* f2n, // [e*n] -> n int experts, // E int experts_per_token, bool norm_topk, + float routed_scale, cudaStream_t st) { constexpr int base_log_tile = 9; @@ -581,14 +582,14 @@ void invokeMoeGate_V2(int* f2n, // [e*n] -> n // std::cout << log_tile << " " << tiles << "\n"; - auto invoke = [&](auto max_expert_num, auto top_k, auto items_per_thread) { + auto invoke = [&](auto max_expert_num, auto top_k, auto items_per_thread, auto vec_size) { constexpr int thrs_per_tok = max_expert_num.value / items_per_thread.value; constexpr int threads = 256; const int blocks = ceil_div(tokens, threads / thrs_per_tok); cudaMemsetAsync(masks, -1, sizeof(int8_t) * experts * tokens_padded, st); - MoeGateKernel_v8 + MoeGateKernel_v8 <<>>( // scales, (int8_t*)masks, @@ -600,28 +601,49 @@ void invokeMoeGate_V2(int* f2n, // [e*n] -> n tokens_padded, experts, experts_per_token, - norm_topk); + norm_topk, + routed_scale); }; auto fail = [&] { - std::cerr << "unsupported moe config: expert_num=" << experts << ", top_k=" << experts_per_token << "\n"; + std::cerr << __FILE__ << "(" << __LINE__ << "): unsupported moe config: expert_num=" << experts + << ", top_k=" << experts_per_token << "\n"; std::abort(); }; if (experts <= 8) { if (experts_per_token <= 2) { - invoke(_Int<8>, _Int<2>, _Int<8>); + // MoeGateKernel_V2<2, 128><<>>(scales, + // (int8_t*)masks, + // accum, + // logits, + // log_tile, + // tiles, + // tokens, + // tokens_padded, + // experts); + + // std::cout << tokens << " " << experts << " " << experts_per_token << " " << tokens_padded << "\n"; + invoke(_Int<8>, _Int<2>, _Int<8>, _Int<4>); } else { - invoke(_Int<8>, _Int<8>, _Int<8>); + invoke(_Int<8>, _Int<8>, _Int<8>, _Int<4>); } } else if (experts <= 64) { if (experts_per_token <= 4) { - invoke(_Int<64>, _Int<4>, _Int<16>); + invoke(_Int<64>, _Int<4>, _Int<16>, _Int<4>); } else if (experts_per_token <= 8) { - invoke(_Int<64>, _Int<8>, _Int<16>); + invoke(_Int<64>, _Int<8>, _Int<16>, _Int<4>); + } + else { + fail(); + } + } + else if (experts <= 160) { + if (experts_per_token <= 8) { + invoke(_Int<160>, _Int<8>, _Int<10>, _Int<2>); } else { fail(); @@ -687,7 +709,8 @@ __global__ void MoeReduceKernel(T* dst, // [ n, d] const int* en2f, // [ e, n] :: (e,n) -> e*n const float* dst_scales, // [n] int dims, - int tokens) + int tokens, + float dst_scale) { using Vec = Array; @@ -695,7 +718,6 @@ __global__ void MoeReduceKernel(T* dst, // [ n, d] auto dst_ptr = (Vec*)dst + dims * ti; - float dst_scale = 0; if (dst_scales) { dst_scale = dst_scales[ti]; dst_scale = fdividef(1.f, 1.f + expf(-dst_scale)); @@ -711,8 +733,9 @@ __global__ void MoeReduceKernel(T* dst, // [ n, d] } for (int i = threadIdx.x; i < dims; i += block_dim) { +#if 1 Array accum{}; - if (dst_scales) { + if (dst_scale) { Vec v; Ldg(v, dst_ptr[i].data()); using namespace ops; @@ -727,6 +750,24 @@ __global__ void MoeReduceKernel(T* dst, // [ n, d] accum = accum + x; } Store(dst_ptr[i].data(), cast(accum)); +#else + Array accum{}; + if (dst_scale) { + Vec v; + Ldg(v, dst_ptr[i].data()); + using namespace ops; + accum = v * (T)dst_scale; + } + PRAGMA_UNROLL + for (int e = 0; e < exp_k; ++e) { + Vec v; + Ldg(v, src_ptr[e][i].data()); + using namespace ops; + const auto x = v * (T)scale[e]; + accum = accum + x; + } + Store(dst_ptr[i].data(), accum); +#endif } } @@ -739,6 +780,7 @@ void invokeMoeReduce(T* dst, int tokens, int experts_per_token, int dims, + float dst_scale, cudaStream_t st) { // std::cout << __PRETTY_FUNCTION__ << std::endl; @@ -754,7 +796,8 @@ void invokeMoeReduce(T* dst, en2f, dst_scales, dims / vec_size, - tokens); + tokens, + dst_scale); }; switch (experts_per_token) { @@ -774,10 +817,11 @@ void invokeMoeReduce(T* dst, } } -template void invokeMoeReduce(half*, const half*, const float*, const int*, const float*, int, int, int, cudaStream_t); -#ifdef ENABLE_BF16 template void -invokeMoeReduce(nv_bfloat16*, const nv_bfloat16*, const float*, const int*, const float*, int, int, int, cudaStream_t); +invokeMoeReduce(half*, const half*, const float*, const int*, const float*, int, int, int, float, cudaStream_t); +#ifdef ENABLE_BF16 +template void invokeMoeReduce( + nv_bfloat16*, const nv_bfloat16*, const float*, const int*, const float*, int, int, int, float, cudaStream_t); #endif std::vector SampleUniform(int token_num, int expert_num, int exp_per_tok, std::mt19937& g) @@ -833,4 +877,89 @@ std::vector SampleBalanced(int token_num, int expert_num, int exp_per_tok, return ret; } +template +__global__ void MoeMaskTopKGroups(float* logits, int token_num, int expert_num, int top_k) +{ + constexpr int threads_per_token = max_expert_num / items_per_thread; + + static_assert((threads_per_token & (threads_per_token - 1)) == 0); + static_assert(items_per_thread % access_size == 0); + + const int thread_idx = threadIdx.x + blockIdx.x * blockDim.x; + + const int ti = thread_idx / threads_per_token; + const int ei = thread_idx % threads_per_token; + + float data[items_per_thread]; + PRAGMA_UNROLL + for (int i = 0; i < items_per_thread; ++i) { + data[i] = -std::numeric_limits::infinity(); + } + float max_val = -std::numeric_limits::infinity(); + if (ti < token_num) { + PRAGMA_UNROLL + for (int i = 0; i < items_per_thread; i += access_size) { + const int e = ei * items_per_thread + i; + if (e < expert_num) { + Ldg((Array&)data[i], &logits[ti * expert_num + e]); + PRAGMA_UNROLL + for (int c = 0; c < access_size; ++c) { + max_val = fmaxf(max_val, data[i + c]); + } + } + } + } + + const int warp_ti = threadIdx.x % WARP_SIZE / threads_per_token; + const int warp_ti_offset = warp_ti * threads_per_token; + + bool alive = false; + + for (int k = 0; k < top_k; ++k) { + int g_max_ei = ei; + float g_max_val = max_val; + PRAGMA_UNROLL + for (int m = threads_per_token / 2; m >= 1; m /= 2) { + g_max_val = fmaxf(g_max_val, __shfl_xor_sync((uint32_t)-1, g_max_val, m)); + } + // tie breaking + const auto active = __ballot_sync((uint32_t)-1, max_val == g_max_val); + g_max_ei = __ffs(active >> (unsigned)warp_ti_offset) - 1; + if (ei == g_max_ei) { + alive = true; + max_val = -std::numeric_limits::infinity(); + } + } + + if (!alive && ti < token_num) { + Array vec; + fill(vec, -std::numeric_limits::infinity()); + PRAGMA_UNROLL + for (int i = 0; i < items_per_thread; i += access_size) { + const int e = ei * items_per_thread + i; + if (e < expert_num) { + Store(&logits[ti * expert_num + e], vec); + } + } + } +} + +void invokeMaskMoeTopKGroups(float* logits, int token_num, int expert_num, int group_size, int top_k, cudaStream_t st) +{ + auto invoke = [&](auto max_expert_num, auto items_per_thread, auto vec_size) { + constexpr int thrs_per_tok = max_expert_num.value / items_per_thread.value; + constexpr int threads = 256; + const int blocks = ceil_div(token_num, threads / thrs_per_tok); + MoeMaskTopKGroups + <<>>(logits, token_num, expert_num, top_k); + }; + if (expert_num == 160 && group_size == 20) { + return invoke(_Int<160>, _Int<20>, _Int<4>); + } + + std::cerr << __FILE__ << "(" << __LINE__ << "): unsupported moe config: expert_num=" << expert_num + << ", group_size=" << group_size << "\n"; + std::abort(); +} + } // namespace turbomind diff --git a/src/turbomind/kernels/gemm/moe_utils_v2.h b/src/turbomind/kernels/gemm/moe_utils_v2.h index 0e4c36af0..d53de1354 100644 --- a/src/turbomind/kernels/gemm/moe_utils_v2.h +++ b/src/turbomind/kernels/gemm/moe_utils_v2.h @@ -22,6 +22,7 @@ void invokeMoeGate_V2(int* f2n, int experts, int exp_per_tok, bool norm_topk, + float routed_scale, cudaStream_t st); template @@ -54,8 +55,11 @@ void invokeMoeReduce(T* dst, int tokens, int experts_per_token, int dims, + float dst_scale, cudaStream_t st); +void invokeMaskMoeTopKGroups(float* logits, int token_num, int expert_num, int group_size, int top_k, cudaStream_t st); + // Sample `e` from `E` experts uniformly for every token std::vector SampleUniform(int token_num, int expert_num, int exp_per_tok, std::mt19937& g); diff --git a/src/turbomind/kernels/gemm/test/test_moe_utils.cu b/src/turbomind/kernels/gemm/test/test_moe_utils.cu index 47e3bfdb1..4b2ea6a83 100644 --- a/src/turbomind/kernels/gemm/test/test_moe_utils.cu +++ b/src/turbomind/kernels/gemm/test/test_moe_utils.cu @@ -45,72 +45,6 @@ void diff_vecs(const T* data, const T* refs, int m, int k, std::string msg) } } -#if 0 -void func() -{ - using thrust::universal_vector; - - // clang-format off - std::vector h_logits{ - 8, 5, 1, 4, 3, 6, 2, 7, - 50, 60, 90, 20, 70, 71, 72, 73, - 0, 1, 0, 0, 0, 1, 0, 1, - 0, 0, 0, 1, 0, 0, 0, 2}; - // clang-format on - - h_logits.resize(8); - - // auto tmp = h_logits; - // for (int i = 0; i < 127; ++i) { - // h_logits.insert(h_logits.end(), tmp.begin(), tmp.end()); - // } - - universal_vector logits(h_logits.begin(), h_logits.end()); - - const int E = 8; - const int n = h_logits.size() / E; - const int e = 2; - - const int n_padded = (n + kMoeGateVecSize - 1) / kMoeGateVecSize * kMoeGateVecSize; - - universal_vector f2n(e * n); - universal_vector en2f(e * n); - universal_vector offsets(E + 1); - universal_vector accum(E * kMoeGateMaxTiles); - universal_vector scales(n * e); - universal_vector masks(E * n_padded); - - for (int i = 0; i < 10; ++i) { - gemm::CacheFlushing::flush(0); - cudaMemset(accum.data().get(), 0, sizeof(int) * accum.size()); - invokeMoeGate_V2(f2n.data().get(), - en2f.data().get(), - offsets.data().get(), - scales.data().get(), - masks.data().get(), - accum.data().get(), - logits.data().get(), - n, - n_padded, - E, - e, - 0); - } - - auto err = cudaDeviceSynchronize(); - if (err) { - std::cerr << cudaGetErrorString(err) << "\n"; - } - - print_vecs(scales.data().get(), e, n, "scales", 12); - print_vecs(masks.data().get(), E, n_padded, "tmp"); - print_vecs(accum.data().get(), E, 1, "accum"); - print_vecs(offsets.data().get(), 1, E + 1, "offsets"); - print_vecs(f2n.data().get(), n * e, 1, "f2n"); - print_vecs(en2f.data().get(), e, n, "en2f"); -} -#endif - RNG& gRNG() { static RNG inst{}; @@ -271,6 +205,8 @@ bool test_moe_gate(int tokens, // cudaMemPrefetchAsync(scales.data().get(), sizeof(float) * scales.size(), 0); cudaMemPrefetchAsync(logits.data().get(), sizeof(float) * logits.size(), 0); + // invokeMaskMoeTopKGroups(logits.data().get(), tokens, expert_num, expert_num / 8, 3, nullptr); + for (int i = 0; i < 1; ++i) { gemm::CacheFlushing::flush(); cudaMemset(accum.data().get(), 0, sizeof(int) * accum.size()); @@ -286,8 +222,9 @@ bool test_moe_gate(int tokens, // tokens_padded, expert_num, experts_per_token, - true, - 0); + false, + 1.f, + nullptr); } // invokeMoeTiling(coords.data().get(), offsets.data().get(), expert_num, coords.size(), &tiling, 1, 0); @@ -334,6 +271,8 @@ bool test_moe_gate(int tokens, // success = false; } + // print_vecs(logits.data().get(), tokens, expert_num, "logits", 12); + if (!success && 1) { diff_vecs(eids.data().get(), eids_ref.data().get(), experts_per_token, tokens, "eids"); @@ -353,6 +292,15 @@ bool test_moe_gate(int tokens, // print_vecs(scales_ref.data().get(), experts_per_token, tokens, "scales_ref", 12); print_vecs(scales.data().get(), experts_per_token, tokens, "scales", 12); + for (int i = 0; i < tokens; ++i) { + float sum = 0; + for (int j = 0; j < experts_per_token; ++j) { + sum += scales[j * tokens + i]; + } + std::cout << sum << " "; + } + std::cout << "\n"; + // print_vecs(accum.data().get(), expert_num, 1, "accum"); // print_vecs(coords.data().get(), 1, max_coords, "coords"); @@ -393,7 +341,7 @@ int main() // test_moe_gate(32768, 64, 8, tape, tiling); // test_moe_gate(8, 60, 4, tape, tiling); - test_moe_gate(65536, 8, 2, tape, tiling); + test_moe_gate(16, 160, 6, tape, tiling); return 0; for (int i = 1; i < 16384; ++i) { diff --git a/src/turbomind/kernels/gemm/test/testbed.h b/src/turbomind/kernels/gemm/test/testbed.h index 7a089fbdf..4747644f9 100644 --- a/src/turbomind/kernels/gemm/test/testbed.h +++ b/src/turbomind/kernels/gemm/test/testbed.h @@ -357,7 +357,7 @@ class Testbed { } } - ((MoeGemmContext*)ctx_.get())->set_offsets(moe_m_offsets_.data().get()); + ((MoeGemmContext*)ctx_.get())->update(experts_, exp_per_tok_, moe_m_offsets_.data().get()); CHECK(batch_dim == 0); CHECK(a_desc_.order == kRowMajor); @@ -518,6 +518,7 @@ class Testbed { batch_size_, expert_ids_.size() / batch_size_, output_dims_, + 0.f, stream_); invokeMoeReduce(c_ref_.data().get(), @@ -528,6 +529,7 @@ class Testbed { batch_size_, expert_ids_.size() / batch_size_, output_dims_, + 0.f, stream_); cudaDeviceSynchronize(); diff --git a/src/turbomind/kernels/gemm/unpack.cu b/src/turbomind/kernels/gemm/unpack.cu index 92f468d82..39e6a2e1a 100644 --- a/src/turbomind/kernels/gemm/unpack.cu +++ b/src/turbomind/kernels/gemm/unpack.cu @@ -71,14 +71,44 @@ void unpack_awq_gemm(uint4_t* dst, const uint4_t* src, int rows, int cols, cudaS permute_u4<0, 1, 3, 2><<<512, 512, 0, st>>>((uint*)dst, (const uint*)src, shape); } +__global__ void transpose_u4_kernel(uint4_t* dst, const uint4_t* src, int s, int c) +{ + const int idx_c = 8 * (threadIdx.x + blockIdx.x * blockDim.x); + const int idx_s = 8 * (threadIdx.y + blockIdx.y * blockDim.y); + if (idx_c >= c || idx_s >= s) { + return; + } + uint32_t ivec[8]; + PRAGMA_UNROLL + for (int i = 0; i < 8; ++i) { + ivec[i] = ((const uint32_t*)src)[((idx_s + i) * c + idx_c) / 8]; + } + uint32_t ovec[8]{}; + PRAGMA_UNROLL + for (int i = 0; i < 8; ++i) { + PRAGMA_UNROLL + for (int j = 0; j < 8; ++j) { + ovec[i] |= (((ivec[j] >> (i * 4)) & 0xfu) << (j * 4)); + } + } + PRAGMA_UNROLL + for (int i = 0; i < 8; ++i) { + ((uint32_t*)dst)[((idx_c + i) * s + idx_s) / 8] = ovec[i]; + } +} + void transpose_u4(uint4_t* dst, const uint4_t* src, int s, int c, cudaStream_t st) { if (s % 8 || c % 8) { std::cerr << "transpose_u4: invalid shape (" << s << "," << c << "), must be multiple of 8" << std::endl; return; } - Array shape{s, c}; - permute_u4<1, 0><<<512, 512, 0, st>>>((uint*)dst, (const uint*)src, shape); + // Array shape{s, c}; + // permute_u4<1, 0><<<512, 512, 0, st>>>((uint*)dst, (const uint*)src, shape); + + const dim3 block(16, 16); + const dim3 grid((c + 15) / 16, (s + 15) / 16); + transpose_u4_kernel<<>>(dst, src, s, c); } // load -> unpack -> extend_to_u8 -> manipulation -> compat_to_u4 -> store diff --git a/src/turbomind/kernels/norm/CMakeLists.txt b/src/turbomind/kernels/norm/CMakeLists.txt new file mode 100644 index 000000000..bc1569c40 --- /dev/null +++ b/src/turbomind/kernels/norm/CMakeLists.txt @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +add_library(rms_norm rms_norm.cu) +set_property(TARGET rms_norm PROPERTY POSITION_INDEPENDENT_CODE ON) +set_property(TARGET rms_norm PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) diff --git a/src/turbomind/kernels/norm/rms_norm.cu b/src/turbomind/kernels/norm/rms_norm.cu new file mode 100644 index 000000000..22fd69f52 --- /dev/null +++ b/src/turbomind/kernels/norm/rms_norm.cu @@ -0,0 +1,235 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include "cub/block/block_reduce.cuh" + +#include "src/turbomind/kernels/core/array_ops.h" +#include "src/turbomind/kernels/core/common.h" + +namespace turbomind { + +template +__global__ void RMSNormKernel(T* dst, + int dst_ld, + const T* src, + int src_ld, + const T* __restrict__ weights, + int dims, + int num, + float eps, + float inv_dims) +{ + const int ti = blockIdx.x; + const int di = threadIdx.x * vec_size; + + if (ti >= num) { + return; + } + + src += src_ld * ti; + + Array accum{}; + Array vec; + + for (int i = di; i < dims; i += block_dim * vec_size) { + Load(vec, &src[i]); + Array tmp = cast(vec); + using namespace ops; + accum = accum + tmp * tmp; + } + + float sum{}; + PRAGMA_UNROLL + for (int i = 0; i < vec_size; ++i) { + sum += accum[i]; + } + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + + sum = BlockReduce{temp_storage}.Sum(sum); + + __shared__ float shared_sum; + + if (threadIdx.x == 0) { + shared_sum = rsqrtf(sum * inv_dims + eps); + } + + __syncthreads(); + + sum = shared_sum; + + dst += dst_ld * ti; + + Array sv; + for (int i = di; i < dims; i += block_dim * vec_size) { + Load(vec, &src[i]); + Ldg(sv, &weights[i]); + PRAGMA_UNROLL + for (int c = 0; c < vec_size; ++c) { + vec[c] = (T)((float)vec[c] * sum) * sv[c]; + // vec[c] = (T)((float)vec[c] * sum * (float)sv[c]); + } + Store(&dst[i], vec); + } +} + +template +void invokeRMSNorm( + T* dst, int dst_ld, const T* src, int src_ld, const T* weights, int dims, int num, float eps, cudaStream_t st) +{ + constexpr int vec_size = 16 / sizeof(T); + + constexpr int threads = 512; + const int blocks = num; + + RMSNormKernel<<>>(dst, // + dst_ld, + src, + src_ld, + weights, + dims, + num, + eps, + 1.f / dims); +} + +template void invokeRMSNorm(half* dst, + int dst_ld, + const half* src, + int src_ld, + const half* weights, + int dims, + int num, + float eps, + cudaStream_t st); +#if ENABLE_BF16 +template void invokeRMSNorm(nv_bfloat16* dst, + int dst_ld, + const nv_bfloat16* src, + int src_ld, + const nv_bfloat16* weights, + int dims, + int num, + float eps, + cudaStream_t st); +#endif + +// r' <- r + (h + b) +// h' <- norm(r') * w +template +__global__ void BiasResidualRMSNormKernel(T* __restrict__ residual, + T* __restrict__ hidden_states, + const T* __restrict__ weights, + const T* __restrict__ bias, + int dims, + int num, + float eps, + float inv_dims) +{ + const int ti = blockIdx.x; + const int di = threadIdx.x * vec_size; + + if (ti >= num) { + return; + } + + residual += dims * ti; + hidden_states += dims * ti; + + Array accum{}; + + Array r_vec; + Array h_vec; + Array b_vec; + + for (int i = di; i < dims; i += block_dim * vec_size) { + Load(r_vec, &residual[i]); + Load(h_vec, &hidden_states[i]); + + using namespace ops; + r_vec = r_vec + h_vec; + + if (bias) { + Ldg(b_vec, &bias[i]); + r_vec = r_vec + b_vec; + } + + Store(&residual[i], r_vec); + + Array tmp = cast(r_vec); + + accum = accum + tmp * tmp; + } + + float sum{}; + PRAGMA_UNROLL + for (int i = 0; i < vec_size; ++i) { + sum += accum[i]; + } + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + + sum = BlockReduce{temp_storage}.Sum(sum); + + __shared__ float shared_sum; + + if (threadIdx.x == 0) { + shared_sum = rsqrtf(sum * inv_dims + eps); + } + + __syncthreads(); + + sum = shared_sum; + + Array w_vec; + for (int i = di; i < dims; i += block_dim * vec_size) { + Load(r_vec, &residual[i]); + Ldg(w_vec, &weights[i]); + PRAGMA_UNROLL + for (int c = 0; c < vec_size; ++c) { + r_vec[c] = (T)((float)r_vec[c] * sum) * w_vec[c]; + } + Store(&hidden_states[i], r_vec); + } +} + +template +void invokeBiasResidualRMSNorm( + T* residual, T* hidden_states, const T* weights, const T* bias, int dims, int num, float eps, cudaStream_t st) +{ + constexpr int vec_size = 16 / sizeof(T); + constexpr int threads = 512; + const int blocks = num; + + BiasResidualRMSNormKernel<<>>(residual, // + hidden_states, + weights, + bias, + dims, + num, + eps, + 1.f / dims); +} + +template void invokeBiasResidualRMSNorm(half* residual, + half* hidden_states, + const half* weights, + const half* bias, + int dims, + int num, + float eps, + cudaStream_t st); + +#if ENABLE_BF16 +template void invokeBiasResidualRMSNorm(nv_bfloat16* residual, + nv_bfloat16* hidden_states, + const nv_bfloat16* weights, + const nv_bfloat16* bias, + int dims, + int num, + float eps, + cudaStream_t st); +#endif + +} // namespace turbomind diff --git a/src/turbomind/kernels/norm/rms_norm.h b/src/turbomind/kernels/norm/rms_norm.h new file mode 100644 index 000000000..83fa0f826 --- /dev/null +++ b/src/turbomind/kernels/norm/rms_norm.h @@ -0,0 +1,21 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include + +namespace turbomind { + +template +void invokeRMSNorm( + T* dst, int dst_ld, const T* src, int src_ld, const T* weights, int dims, int num, float eps, cudaStream_t st); + +template +void invokeRMSNorm(T* dst, const T* src, const T* weights, int dims, int num, float eps, cudaStream_t st) +{ + invokeRMSNorm(dst, dims, src, dims, weights, dims, num, eps, st); +} + +template +void invokeBiasResidualRMSNorm( + T* residual, T* hidden_states, const T* weights, const T* bias, int dims, int num, float eps, cudaStream_t st); + +} // namespace turbomind diff --git a/src/turbomind/models/llama/CMakeLists.txt b/src/turbomind/models/llama/CMakeLists.txt index 285fcea31..3c714bd23 100644 --- a/src/turbomind/models/llama/CMakeLists.txt +++ b/src/turbomind/models/llama/CMakeLists.txt @@ -20,11 +20,13 @@ add_library(Llama STATIC unified_attention_layer.cc llama_kernels.cu llama_decoder_kernels.cu - llama_utils.cu) + llama_utils.cu + mla_utils.cu) set_property(TARGET Llama PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET Llama PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) target_link_libraries(Llama PUBLIC CUDA::cudart gemm2 + rms_norm cublasMMWrapper DynamicDecodeLayer activation_kernels diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc index 4138174e5..ea321d06a 100644 --- a/src/turbomind/models/llama/LlamaBatch.cc +++ b/src/turbomind/models/llama/LlamaBatch.cc @@ -20,6 +20,7 @@ #include "src/turbomind/utils/cuda_utils.h" #include "src/turbomind/utils/debug_utils.h" #include "src/turbomind/utils/logger.h" +#include "src/turbomind/utils/nccl_utils.h" #include #include #include @@ -1041,6 +1042,9 @@ LlamaBatch::LlamaBatch(const EngineParam& param, AllocateBuffer(max_batch_size_, session_len_, cache_block_seq_len); AllocatePersistantBuffer(max_batch_size_, cache_block_seq_len); + + // Wait for allocations + check_cuda_error(cudaStreamSynchronize(stream_)); } template @@ -1990,7 +1994,7 @@ void LlamaBatch::tune() nullptr, nullptr); // implicit barrier for TP - check_cuda_error(cudaStreamSynchronize(stream_)); + ftNcclStreamSynchronize(model_->tensor_para_, {}, stream_); } auto tock = std::chrono::steady_clock::now(); diff --git a/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc b/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc index f6f9ab0ef..0a2a3be17 100644 --- a/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc +++ b/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc @@ -52,28 +52,21 @@ static bool is_fuse_silu_act() } template -LlamaDecoderLayerWeight::LlamaDecoderLayerWeight(int layer_idx, - size_t head_num, - size_t kv_head_num, - size_t size_per_head, - size_t hidden_units, - size_t inter_size, - WeightType weight_type, - int group_size, - LoraParam lora_param, - bool attn_bias, - MoeParam moe_param, - size_t tensor_para_size, - size_t tensor_para_rank): - head_num_(head_num), - kv_head_num_(kv_head_num), - size_per_head_(size_per_head), - hidden_units_(hidden_units), - inter_size_(inter_size), - weight_type_(weight_type), - attn_bias_(attn_bias), - tensor_para_size_(tensor_para_size), - tensor_para_rank_(tensor_para_rank) +LlamaDecoderLayerWeight::LlamaDecoderLayerWeight(int layer_id, + const ModelParam& model, + const LoraParam& lora_param, + const MoeParam& moe_param, + size_t tp_size, + size_t tp_rank): + head_num_(model.head_num), + kv_head_num_(model.kv_head_num), + size_per_head_(model.head_dim), + hidden_units_(model.hidden_units), + inter_size_(model.inter_size.at(layer_id)), + weight_type_(model.weight_type), + attn_bias_(model.attn_bias), + tensor_para_size_(tp_size), + tensor_para_rank_(tp_rank) { if (lora_param.policy == LoraPolicy::kPlora) { std::vector keys = { @@ -88,7 +81,7 @@ LlamaDecoderLayerWeight::LlamaDecoderLayerWeight(int layer_idx, auto& weight = *weights[i]; int rank = lora_param.r; float scale = lora_param.scale; - std::string full_name = "layers." + std::to_string(layer_idx) + "." + name; + std::string full_name = "layers." + std::to_string(layer_id) + "." + name; for (const auto& [re, pr] : lora_param.rank_pattern) { if (std::regex_search(full_name, pr.first)) { @@ -114,36 +107,44 @@ LlamaDecoderLayerWeight::LlamaDecoderLayerWeight(int layer_idx, fused_up_and_gate_ = ffn_weights.gating.lora.policy != LoraPolicy::kPlora; - self_attn_weights.qkv.input_dims = hidden_units_; - self_attn_weights.qkv.output_dims = (head_num + 2 * kv_head_num) * size_per_head / tensor_para_size_; - self_attn_weights.qkv.type = weight_type; - self_attn_weights.qkv.group_size = group_size; - - self_attn_weights.output.input_dims = (head_num * size_per_head) / tensor_para_size_; - self_attn_weights.output.output_dims = hidden_units_; - self_attn_weights.output.type = weight_type; - self_attn_weights.output.group_size = group_size; + self_attn_weights = LlamaAttentionWeight{hidden_units_, + size_per_head_, + head_num_, + kv_head_num_, + model.mla, + attn_bias_, + tensor_para_size_, + weight_type_, + model.group_size}; ffn_weights = LlamaFfnWeight{ hidden_units_, inter_size_, tensor_para_size_, weight_type_, - group_size, + model.group_size, weight_type_ == WeightType::kINT4 && is_fuse_silu_act(), }; - moe_weights = MoeFfnWeight{hidden_units_, - moe_param.inter_size, - moe_param.expert_num, - moe_param.method, - moe_param.shared_gate, - tensor_para_size_, - weight_type, - group_size, - is_fuse_silu_act()}; - - mallocWeights(); + moe_weights = MoeFfnWeight{ + layer_id, moe_param, hidden_units_, weight_type_, model.group_size, tensor_para_size_, is_fuse_silu_act()}; +} + +template +void LlamaDecoderLayerWeight::malloc(cudaStream_t st) +{ + deviceMalloc((T**)&self_attn_norm_weights, hidden_units_, st); + deviceMalloc((T**)&ffn_norm_weights, hidden_units_, st); + + self_attn_weights.malloc(st); + + if (inter_size_) { + ffn_weights.malloc(st); + } + + if (!moe_weights.experts.empty()) { + moe_weights.malloc(st); + } } template @@ -168,52 +169,6 @@ size_t LlamaDecoderLayerWeight::workspace_size() const noexcept return size * sizeof(uint16_t); } -template -void freeWeights(LlamaDenseWeight& weights) -{ - cudaFree(weights.kernel); - cudaFree(weights.bias); - cudaFree(weights.scales); - cudaFree(weights.zeros); - - weights.kernel = nullptr; - weights.bias = nullptr; - weights.scales = nullptr; - weights.zeros = nullptr; - - { - cudaFree(weights.lora.a); - cudaFree(weights.lora.b); - weights.lora.a = nullptr; - weights.lora.b = nullptr; - } -} - -template -void LlamaDecoderLayerWeight::mallocWeights(LlamaDenseWeight& weights, bool bias) -{ - if (bias) { - deviceMalloc((T**)&weights.bias, weights.output_dims); - } - const size_t bit_size = getBitSize(weights.type); - if (bit_size >= 16) { // fp16, fp32 - deviceMalloc((T**)&weights.kernel, weights.input_dims * weights.output_dims); - } - else { // int8, int4 - const int factor = sizeof(float) * 8 / bit_size; - FT_CHECK(weights.input_dims % factor == 0); - deviceMalloc((int**)&weights.kernel, weights.input_dims * weights.output_dims / factor); - deviceMemSetZero((int*)weights.kernel, weights.input_dims * weights.output_dims / factor); - deviceMalloc((T**)&weights.scales, weights.input_dims / weights.group_size * weights.output_dims); - deviceMalloc((T**)&weights.zeros, weights.input_dims / weights.group_size * weights.output_dims); - } - - if (weights.lora.r > 0) { - deviceMalloc((T**)&weights.lora.a, weights.input_dims * weights.lora.r); - deviceMalloc((T**)&weights.lora.b, weights.lora.r * weights.output_dims); - } -} - template std::string concat(FirstArg&& first, Args&&... args) { @@ -342,64 +297,24 @@ void loadWeights(LlamaDenseWeight& w, std::string prefix, FtCudaDataType mode } template -void LlamaDecoderLayerWeight::mallocWeights() +void LlamaDecoderLayerWeight::free(cudaStream_t st) { - deviceMalloc((T**)&self_attn_norm_weights, hidden_units_); - deviceMalloc((T**)&ffn_norm_weights, hidden_units_); + deviceFree(self_attn_norm_weights, st); + deviceFree(ffn_norm_weights, st); - mallocWeights(self_attn_weights.qkv, attn_bias_); - mallocWeights(self_attn_weights.output, attn_bias_); + self_attn_weights.free(st); if (inter_size_) { - mallocWeights(ffn_weights.gating, false); - mallocWeights(ffn_weights.intermediate, false); - mallocWeights(ffn_weights.output, false); + ffn_weights.free(st); } if (!moe_weights.experts.empty()) { - mallocWeights(moe_weights.gate, false); - for (auto& e : moe_weights.experts) { - mallocWeights(e.gating, false); - mallocWeights(e.intermediate, false); - mallocWeights(e.output, false); - } - if (moe_weights.shared_gate.output_dims) { - mallocWeights(moe_weights.shared_gate, false); - } + moe_weights.free(st); } } template -LlamaDecoderLayerWeight::~LlamaDecoderLayerWeight() -{ - cudaFree((void*)self_attn_norm_weights); - cudaFree((void*)ffn_norm_weights); - self_attn_norm_weights = nullptr; - ffn_norm_weights = nullptr; - - freeWeights(self_attn_weights.qkv); - freeWeights(self_attn_weights.output); - - if (inter_size_) { - freeWeights(ffn_weights.fused_gating_intermediate); - freeWeights(ffn_weights.gating); - freeWeights(ffn_weights.intermediate); - freeWeights(ffn_weights.output); - } - - if (!moe_weights.experts.empty()) { - freeWeights(moe_weights.gate); - for (auto& e : moe_weights.experts) { - freeWeights(e.fused_gating_intermediate); - freeWeights(e.gating); - freeWeights(e.intermediate); - freeWeights(e.output); - } - if (moe_weights.shared_gate.kernel) { - freeWeights(moe_weights.shared_gate); - } - } -} +LlamaDecoderLayerWeight::~LlamaDecoderLayerWeight() = default; template void LlamaDecoderLayerWeight::loadModel(std::string dir_path, FtCudaDataType model_file_type) @@ -432,6 +347,24 @@ void LlamaDecoderLayerWeight::loadModel(std::string dir_path, FtCudaDataType } } +template +void getMLATensor(LlamaAttentionWeight& w, const std::string& p, TensorMap& m, int tp_rank) +{ + if (w.q_proj.output_dims) { + getWeightTensor(w.q_proj, false, concat(p, "attention.q_proj", tp_rank), m); + } + else { + getWeightTensor(w.q_a_proj, false, concat(p, "attention.q_a_proj"), m); + getWeightTensor(w.q_b_proj, false, concat(p, "attention.q_b_proj", tp_rank), m); + m.insert(concat(p, "attention.q_a_layernorm"), + Tensor{MEMORY_GPU, getTensorType(), {sizeof(T) * w.q_b_proj.input_dims}, w.q_a_layernorm}); + } + getWeightTensor(w.kv_a_proj, false, concat(p, "attention.kv_a_proj"), m); + getWeightTensor(w.kv_b_proj, false, concat(p, "attention.kv_b_proj", tp_rank), m); + m.insert(concat(p, "attention.kv_a_layernorm"), + Tensor{MEMORY_GPU, getTensorType(), {sizeof(T) * w.kv_b_proj.input_dims}, w.kv_a_layernorm}); +} + template TensorMap LlamaDecoderLayerWeight::getParams(std::string prefix) { @@ -445,7 +378,12 @@ TensorMap LlamaDecoderLayerWeight::getParams(std::string prefix) auto get_prefix = [=](std::string_view name) { return concat(prefix, name, tensor_para_rank_); }; - getWeightTensor(self_attn_weights.qkv, attn_bias_, get_prefix("attention.w_qkv"), output); + if (self_attn_weights.qkv.output_dims) { + getWeightTensor(self_attn_weights.qkv, attn_bias_, get_prefix("attention.w_qkv"), output); + } + else { + getMLATensor(self_attn_weights, prefix, output, tensor_para_rank_); + } getWeightTensor(self_attn_weights.output, attn_bias_, get_prefix("attention.wo"), output); if (inter_size_) { @@ -478,7 +416,8 @@ TensorMap LlamaDecoderLayerWeight::getParams(std::string prefix) } // template -static void convert_u4(LlamaDenseWeight& weight, bool is_fused_moe, void* workspace, size_t size, bool use_simt) +static void convert_u4( + LlamaDenseWeight& weight, bool is_fused_moe, void* workspace, size_t size, bool use_simt, cudaStream_t st) { FT_CHECK(weight.type == WeightType::kINT4); @@ -488,11 +427,11 @@ static void convert_u4(LlamaDenseWeight& weight, bool is_fused_moe, void* get_weight_and_scales_layout(gemm::DataType::U4, is_fused_moe, getSMVersion(), use_simt); if (order_b == kColMajor) { - transpose_u4((uint4_t*)workspace, (const uint4_t*)weight.kernel, weight.input_dims, weight.output_dims); - cudaMemcpy(weight.kernel, workspace, weight.input_dims * weight.output_dims / 2, cudaMemcpyDefault); + transpose_u4((uint4_t*)workspace, (const uint4_t*)weight.kernel, weight.input_dims, weight.output_dims, st); + cudaMemcpyAsync(weight.kernel, workspace, weight.input_dims * weight.output_dims / 2, cudaMemcpyDefault, st); } - extend_to_u16((uint16_t*)workspace, (const uint4_t*)weight.kernel, weight.input_dims * weight.output_dims); + extend_to_u16((uint16_t*)workspace, (const uint4_t*)weight.kernel, weight.input_dims * weight.output_dims, st); sync_check_cuda_error(); MatrixLayout w_desc{ @@ -507,25 +446,22 @@ static void convert_u4(LlamaDenseWeight& weight, bool is_fused_moe, void* k_desc.type = gemm::DataType::U4; k_desc.pack = pack_b; - cudaMemset(weight.kernel, 0, weight.input_dims * weight.output_dims / 2); + cudaMemsetAsync(weight.kernel, 0, weight.input_dims * weight.output_dims / 2, st); - FT_CHECK(Convert(workspace, w_desc, weight.kernel, k_desc, 0) == 0); + FT_CHECK(Convert(workspace, w_desc, weight.kernel, k_desc, st) == 0); sync_check_cuda_error(); const int scale_count = (weight.input_dims / weight.group_size) * weight.output_dims; // std::cout << "fuse_scales_and_zeros\n"; - fuse_scales_and_zeros((half*)workspace, weight.scales, weight.zeros, scale_count); + fuse_scales_and_zeros((half*)workspace, weight.scales, weight.zeros, scale_count, st); // cudaMemset((T*)workspace, 0, sizeof(T) * scale_count * 2); sync_check_cuda_error(); - cudaDeviceSynchronize(); - - cudaFree(weight.scales); - cudaFree(weight.zeros); - weight.scales = weight.zeros = nullptr; + deviceFree(weight.scales, st); + deviceFree(weight.zeros, st); - deviceMalloc((half**)&weight.scales_zeros, scale_count * 2); + deviceMalloc((half**)&weight.scales_zeros, scale_count * 2, st); MatrixLayout s_desc{ gemm::DataType::U32, @@ -538,7 +474,7 @@ static void convert_u4(LlamaDenseWeight& weight, bool is_fused_moe, void* MatrixLayout q_desc = s_desc; q_desc.pack = pack_v; - FT_CHECK(Convert(workspace, s_desc, weight.scales_zeros, q_desc, 0) == 0); + FT_CHECK(Convert(workspace, s_desc, weight.scales_zeros, q_desc, st) == 0); sync_check_cuda_error(); weight.k_desc = k_desc; @@ -548,7 +484,8 @@ static void convert_u4(LlamaDenseWeight& weight, bool is_fused_moe, void* } template -static void convert_fp(LlamaDenseWeight& weight, bool is_fused_moe, void* workspace, size_t size, bool use_simt) +static void +convert_fp(LlamaDenseWeight& weight, bool is_fused_moe, void* workspace, size_t size, bool use_simt, cudaStream_t st) { using namespace gemm; @@ -563,12 +500,13 @@ static void convert_fp(LlamaDenseWeight& weight, bool is_fused_moe, void* wor const int output_dim = weight.output_dims; if (order_b == kColMajor) { - invokeTransposeAxis01((uint16_t*)workspace, (uint16_t*)weight.kernel, input_dim, output_dim, 1, nullptr); + invokeTransposeAxis01((uint16_t*)workspace, (uint16_t*)weight.kernel, input_dim, output_dim, 1, st); sync_check_cuda_error(); // FT_CHECK(0); } else { - check_cuda_error(cudaMemcpy(workspace, weight.kernel, sizeof(T) * input_dim * output_dim, cudaMemcpyDefault)); + check_cuda_error( + cudaMemcpyAsync(workspace, weight.kernel, sizeof(T) * input_dim * output_dim, cudaMemcpyDefault, st)); } MatrixLayout src{ @@ -583,35 +521,42 @@ static void convert_fp(LlamaDenseWeight& weight, bool is_fused_moe, void* wor dst.pack = pack_b; if (pack_b) { - FT_CHECK(Convert(workspace, src, weight.kernel, dst, nullptr) == 0); + FT_CHECK(Convert(workspace, src, weight.kernel, dst, st) == 0); sync_check_cuda_error(); // FT_CHECK(0); } else { - check_cuda_error(cudaMemcpy(weight.kernel, workspace, sizeof(T) * input_dim * output_dim, cudaMemcpyDefault)); + check_cuda_error( + cudaMemcpyAsync(weight.kernel, workspace, sizeof(T) * input_dim * output_dim, cudaMemcpyDefault, st)); } weight.k_desc = dst; } template -static void convert(LlamaDenseWeight& weight, bool is_fused_moe, void* workspace, size_t size, bool use_simt) +static void +convert(LlamaDenseWeight& weight, bool is_fused_moe, void* workspace, size_t size, bool use_simt, cudaStream_t st) { if (weight.type == WeightType::kINT4) { if constexpr (std::is_same_v) { - convert_u4(weight, is_fused_moe, workspace, size, use_simt); + convert_u4(weight, is_fused_moe, workspace, size, use_simt, st); } else { FT_CHECK(0); } } else { - convert_fp(weight, is_fused_moe, workspace, size, use_simt); + convert_fp(weight, is_fused_moe, workspace, size, use_simt, st); } } template -void interleave(LlamaDenseWeight& c, LlamaDenseWeight& a, LlamaDenseWeight& b, void* workspace, size_t size) +void interleave(LlamaDenseWeight& c, + LlamaDenseWeight& a, + LlamaDenseWeight& b, + void* workspace, + size_t size, + cudaStream_t st) { FT_CHECK(c.input_dims == a.input_dims); FT_CHECK(c.input_dims == b.input_dims); @@ -628,18 +573,18 @@ void interleave(LlamaDenseWeight& c, LlamaDenseWeight& a, LlamaDenseWeight const auto sentinel = tmp_c + c.output_dims * c.input_dims; FT_CHECK(sentinel <= (uint8_t*)workspace + size); - extend_to_u8(tmp_a, (const uint4_t*)a.kernel, a.output_dims * a.input_dims); - extend_to_u8(tmp_b, (const uint4_t*)b.kernel, b.output_dims * b.input_dims); + extend_to_u8(tmp_a, (const uint4_t*)a.kernel, a.output_dims * a.input_dims, st); + extend_to_u8(tmp_b, (const uint4_t*)b.kernel, b.output_dims * b.input_dims, st); - interleave_output_dims(tmp_c, tmp_a, tmp_b, a.output_dims, a.input_dims, 0); + interleave_output_dims(tmp_c, tmp_a, tmp_b, a.output_dims, a.input_dims, st); - compact_to_u4((uint4_t*)c.kernel, tmp_c, c.output_dims * c.input_dims); + compact_to_u4((uint4_t*)c.kernel, tmp_c, c.output_dims * c.input_dims, st); - interleave_output_dims(c.scales, a.scales, b.scales, a.output_dims, a.input_dims / a.group_size, 0); - interleave_output_dims(c.zeros, a.zeros, b.zeros, a.output_dims, a.input_dims / a.group_size, 0); + interleave_output_dims(c.scales, a.scales, b.scales, a.output_dims, a.input_dims / a.group_size, st); + interleave_output_dims(c.zeros, a.zeros, b.zeros, a.output_dims, a.input_dims / a.group_size, st); } else { - interleave_output_dims((T*)c.kernel, (const T*)a.kernel, (const T*)b.kernel, a.output_dims, a.input_dims, 0); + interleave_output_dims((T*)c.kernel, (const T*)a.kernel, (const T*)b.kernel, a.output_dims, a.input_dims, st); } // Check at function level @@ -647,7 +592,7 @@ void interleave(LlamaDenseWeight& c, LlamaDenseWeight& a, LlamaDenseWeight } template -void chunk(LlamaDenseWeight& c, LlamaDenseWeight& a, LlamaDenseWeight& b, void*, size_t) +void chunk(LlamaDenseWeight& c, LlamaDenseWeight& a, LlamaDenseWeight& b, void*, size_t, cudaStream_t st) { FT_CHECK(c.input_dims == a.input_dims); FT_CHECK(c.input_dims == b.input_dims); @@ -656,9 +601,11 @@ void chunk(LlamaDenseWeight& c, LlamaDenseWeight& a, LlamaDenseWeight& FT_CHECK(c.group_size == a.group_size); FT_CHECK(c.group_size == b.group_size); - auto _chunks = [](auto c, auto a, auto b, int height, int width) { - check_cuda_error(cudaMemcpy2D((char*)c + 0x000, width * 2, a, width, width, height, cudaMemcpyDefault)); - check_cuda_error(cudaMemcpy2D((char*)c + width, width * 2, b, width, width, height, cudaMemcpyDefault)); + auto _chunks = [&](auto c, auto a, auto b, int height, int width) { + check_cuda_error( + cudaMemcpy2DAsync((char*)c + 0x000, width * 2, a, width, width, height, cudaMemcpyDefault, st)); + check_cuda_error( + cudaMemcpy2DAsync((char*)c + width, width * 2, b, width, width, height, cudaMemcpyDefault, st)); }; if (c.type == WeightType::kINT4) { @@ -675,37 +622,37 @@ void chunk(LlamaDenseWeight& c, LlamaDenseWeight& a, LlamaDenseWeight& } template -void LlamaDecoderLayerWeight::prepare(void* workspace, size_t size, const cudaDeviceProp& prop) +void LlamaDecoderLayerWeight::prepare(void* workspace, size_t size, const cudaDeviceProp& prop, cudaStream_t st) { const bool is_16xx = is_16xx_series(prop.name); - convert(self_attn_weights.qkv, false, workspace, size, is_16xx); - convert(self_attn_weights.output, false, workspace, size, is_16xx); + convert(self_attn_weights.qkv, false, workspace, size, is_16xx, st); + convert(self_attn_weights.output, false, workspace, size, is_16xx, st); auto process_ffn = [&](LlamaFfnWeight& ffn, bool is_fused_moe) { if (fused_up_and_gate_) { auto& fused_up_and_gate = ffn.fused_gating_intermediate; - mallocWeights(fused_up_and_gate, false); + fused_up_and_gate.malloc(st); if (ffn.is_fused_silu) { - interleave(fused_up_and_gate, ffn.gating, ffn.intermediate, workspace, size); + interleave(fused_up_and_gate, ffn.gating, ffn.intermediate, workspace, size, st); } else { - chunk(fused_up_and_gate, ffn.gating, ffn.intermediate, workspace, size); + chunk(fused_up_and_gate, ffn.gating, ffn.intermediate, workspace, size, st); } - convert(ffn.fused_gating_intermediate, is_fused_moe, workspace, size, is_16xx); + convert(ffn.fused_gating_intermediate, is_fused_moe, workspace, size, is_16xx, st); - freeWeights(ffn.gating); - freeWeights(ffn.intermediate); + ffn.gating.free(st); + ffn.intermediate.free(st); } else { - convert(ffn.gating, is_fused_moe, workspace, size, is_16xx); - convert(ffn.intermediate, is_fused_moe, workspace, size, is_16xx); + convert(ffn.gating, is_fused_moe, workspace, size, is_16xx, st); + convert(ffn.intermediate, is_fused_moe, workspace, size, is_16xx, st); } - convert(ffn.output, is_fused_moe, workspace, size, is_16xx); + convert(ffn.output, is_fused_moe, workspace, size, is_16xx, st); }; if (inter_size_) { @@ -722,7 +669,7 @@ void LlamaDecoderLayerWeight::prepare(void* workspace, size_t size, const cud for (auto& e : moe_weights.experts) { - process_ffn(e, moe_weights.method); + process_ffn(e, moe_weights.method == MoeParam::kFused); const auto& fused = e.fused_gating_intermediate; const auto& output = e.output; @@ -743,12 +690,12 @@ void LlamaDecoderLayerWeight::prepare(void* workspace, size_t size, const cud auto& output = moe_weights.block.output; // TODO: free these ptrs - fused.kernel = gemm::make_blocked_ptrs(fused_ptrs, nullptr); - output.kernel = gemm::make_blocked_ptrs(output_ptrs, nullptr); + fused.kernel = gemm::make_blocked_ptrs(fused_ptrs, st); + output.kernel = gemm::make_blocked_ptrs(output_ptrs, st); if (!fused_param_ptrs.empty()) { - fused.scales_zeros = (T*)gemm::make_blocked_ptrs(fused_param_ptrs, nullptr); - output.scales_zeros = (T*)gemm::make_blocked_ptrs(output_param_ptrs, nullptr); + fused.scales_zeros = (T*)gemm::make_blocked_ptrs(fused_param_ptrs, st); + output.scales_zeros = (T*)gemm::make_blocked_ptrs(output_param_ptrs, st); } fused.k_desc.ld = output.k_desc.ld = 0; diff --git a/src/turbomind/models/llama/LlamaDecoderLayerWeight.h b/src/turbomind/models/llama/LlamaDecoderLayerWeight.h index f68a103dd..9b204ed0d 100644 --- a/src/turbomind/models/llama/LlamaDecoderLayerWeight.h +++ b/src/turbomind/models/llama/LlamaDecoderLayerWeight.h @@ -30,19 +30,14 @@ template struct LlamaDecoderLayerWeight { public: LlamaDecoderLayerWeight() = delete; - LlamaDecoderLayerWeight(int layer_idx, - size_t head_num, - size_t kv_head_num, - size_t size_per_head, - size_t hidden_units, - size_t inter_size, - WeightType weight_type, - int group_size, - LoraParam lora_param, - bool attn_bias, - MoeParam moe_param, - size_t tensor_para_size, - size_t tensor_para_rank); + + LlamaDecoderLayerWeight(int layer_id, + const ModelParam& model, + const LoraParam& lora_param, + const MoeParam& moe_param, + size_t tp_size, + size_t tp_rank); + ~LlamaDecoderLayerWeight(); LlamaDecoderLayerWeight(const LlamaDecoderLayerWeight& other) = delete; LlamaDecoderLayerWeight& operator=(const LlamaDecoderLayerWeight& other) = delete; @@ -51,17 +46,21 @@ struct LlamaDecoderLayerWeight { TensorMap getParams(std::string prefix); - void prepare(void* workspace, size_t size, const cudaDeviceProp& prop); + void prepare(void* workspace, size_t size, const cudaDeviceProp& prop, cudaStream_t st); size_t workspace_size() const noexcept; - void mallocWeights(LlamaDenseWeight& weights, bool bias); + void malloc(cudaStream_t st); + + void free(cudaStream_t st); + + T* self_attn_norm_weights{}; + T* ffn_norm_weights{}; - T* self_attn_norm_weights{}; - T* ffn_norm_weights{}; LlamaAttentionWeight self_attn_weights{}; - LlamaFfnWeight ffn_weights{}; - MoeFfnWeight moe_weights{}; + + LlamaFfnWeight ffn_weights{}; + MoeFfnWeight moe_weights{}; private: size_t head_num_; @@ -76,8 +75,6 @@ struct LlamaDecoderLayerWeight { size_t tensor_para_rank_; bool is_maintain_buffer_ = false; bool fused_up_and_gate_; - - void mallocWeights(); }; } // namespace turbomind diff --git a/src/turbomind/models/llama/LlamaDenseWeight.h b/src/turbomind/models/llama/LlamaDenseWeight.h index 169fb53bc..944781bf5 100644 --- a/src/turbomind/models/llama/LlamaDenseWeight.h +++ b/src/turbomind/models/llama/LlamaDenseWeight.h @@ -20,64 +20,14 @@ #pragma once #include "src/turbomind/kernels/gemm/types.h" +#include "src/turbomind/models/llama/llama_params.h" +#include "src/turbomind/models/llama/weight_type.h" #include "src/turbomind/utils/cuda_utils.h" +#include "src/turbomind/utils/memory_utils.h" #include namespace turbomind { -enum class WeightType : int -{ - kFP32, - kFP16, - kFP8, // not supported yet - kBF16, - kINT8, - kINT4 -}; - -template -constexpr WeightType get_default_weight_type() -{ - if constexpr (std::is_same_v) { - return WeightType::kFP16; - } - else if constexpr (std::is_same_v) { - return WeightType::kBF16; - } - else if constexpr (std::is_same_v) { - return WeightType::kFP32; - } - else { - static_assert(sizeof(T) != sizeof(T), "not implemented"); - return {}; - } -} - -inline size_t getBitSize(WeightType type) -{ - switch (type) { - case WeightType::kFP32: - return 32; - case WeightType::kFP16: - return 16; - case WeightType::kFP8: - return 8; - case WeightType::kBF16: - return 16; - case WeightType::kINT8: - return 8; - case WeightType::kINT4: - return 4; - } - return 0; -} - -enum class LoraPolicy : int -{ - kNull, - kPlora, -}; - inline LoraPolicy getLoraPolicy(const std::string& policy) { if (policy == "plora") { @@ -96,20 +46,31 @@ struct LoraWeight { template struct LlamaDenseWeight { - size_t input_dims; - size_t output_dims; - void* kernel; + size_t input_dims = 0; + size_t output_dims = 0; + WeightType type; // uninitialized + void* kernel = nullptr; + T* bias = nullptr; + T* scales = nullptr; + T* zeros = nullptr; + T* scales_zeros = nullptr; + int group_size = 1; + LoraWeight lora; - WeightType type; - T* bias; - T* scales; - T* zeros; - T* scales_zeros; - int group_size; gemm::MatrixLayout k_desc; gemm::MatrixLayout q_desc; + LlamaDenseWeight(): type{}, lora{}, k_desc{}, q_desc{} {} + + LlamaDenseWeight(size_t input_dim, size_t output_dim, WeightType type, int group_size): LlamaDenseWeight{} + { + this->input_dims = input_dim; + this->output_dims = output_dim; + this->type = type; + this->group_size = group_size; + } + size_t kernel_size() const noexcept { return getBitSize(type) * input_dims * output_dims / 8; @@ -129,12 +90,121 @@ struct LlamaDenseWeight { { return {sizeof(T) * input_dims * lora.r, sizeof(T) * lora.r * output_dims}; } + + void malloc(cudaStream_t st, bool with_bias = false) + { + if (with_bias) { + deviceMalloc((T**)&bias, output_dims, st); + } + const size_t bit_size = getBitSize(type); + if (bit_size >= 16) { // fp16, fp32 + deviceMalloc((T**)&kernel, input_dims * output_dims, st); + } + else { // int8, int4 + const int factor = sizeof(float) * 8 / bit_size; + FT_CHECK(input_dims % factor == 0); + deviceMalloc((int**)&kernel, input_dims * output_dims / factor, st); + deviceMalloc((T**)&scales, input_dims / group_size * output_dims, st); + deviceMalloc((T**)&zeros, input_dims / group_size * output_dims, st); + } + + if (lora.r > 0) { + deviceMalloc((T**)&lora.a, input_dims * lora.r, st); + deviceMalloc((T**)&lora.b, lora.r * output_dims, st); + } + } + + void free(cudaStream_t st) + { + deviceFree(kernel, st); + deviceFree(bias, st); + deviceFree(scales, st); + deviceFree(zeros, st); + deviceFree(lora.a, st); + deviceFree(lora.b, st); + } }; template struct LlamaAttentionWeight { + + LlamaAttentionWeight() = default; + + LlamaAttentionWeight(size_t hidden_dim, + size_t head_dim, + size_t head_num, + size_t kv_head_num, + MLAParam mla, + bool bias, + size_t tp, + WeightType weight_type, + int group_size) + { + this->bias = bias; + if (mla.kv_lora_rank == 0) { + qkv = {hidden_dim, (head_num + 2 * kv_head_num) * head_dim / tp, weight_type, group_size}; + } + else { + const int qk_nope_dim = head_dim - mla.qk_rope_dim; + if (mla.q_lora_rank) { + q_a_proj = {hidden_dim, mla.q_lora_rank, weight_type, group_size}; + q_b_proj = {mla.q_lora_rank, head_num * head_dim / tp, weight_type, group_size}; + } + else { + q_proj = {hidden_dim, head_num * head_dim / tp, weight_type, group_size}; + } + kv_a_proj = {hidden_dim, mla.kv_lora_rank + mla.qk_rope_dim, weight_type, group_size}; + kv_b_proj = {mla.kv_lora_rank, head_num * (qk_nope_dim + mla.v_head_dim) / tp, weight_type, group_size}; + } + output = {(head_num * head_dim) / tp, hidden_dim, weight_type, group_size}; + } + + void malloc(cudaStream_t st) + { + if (qkv.output_dims) { + qkv.malloc(st, bias); + } + else { + if (q_proj.output_dims) { + q_proj.malloc(st); + } + else { + q_a_proj.malloc(st); + q_b_proj.malloc(st); + deviceMalloc((T**)&q_a_layernorm, q_b_proj.input_dims, st); + } + kv_a_proj.malloc(st); + kv_b_proj.malloc(st); + deviceMalloc((T**)&kv_a_layernorm, kv_b_proj.input_dims, st); + } + output.malloc(st, bias); + } + + void free(cudaStream_t st) + { + qkv.free(st); + q_proj.free(st); + q_a_proj.free(st); + q_b_proj.free(st); + kv_a_proj.free(st); + kv_b_proj.free(st); + output.free(st); + deviceFree(q_a_layernorm, st); + deviceFree(kv_a_layernorm, st); + } + LlamaDenseWeight qkv; LlamaDenseWeight output; + bool bias{}; + + LlamaDenseWeight q_proj; + LlamaDenseWeight q_a_proj; + LlamaDenseWeight q_b_proj; + LlamaDenseWeight kv_a_proj; + LlamaDenseWeight kv_b_proj; + + T* q_a_layernorm{}; + T* kv_a_layernorm{}; }; template @@ -172,6 +242,21 @@ struct LlamaFfnWeight { output.group_size = group_size; } + void malloc(cudaStream_t st) + { + gating.malloc(st); + intermediate.malloc(st); + output.malloc(st); + } + + void free(cudaStream_t st) + { + gating.free(st); + intermediate.free(st); + output.free(st); + fused_gating_intermediate.free(st); + } + LlamaDenseWeight gating; LlamaDenseWeight intermediate; LlamaDenseWeight output; @@ -186,23 +271,27 @@ struct MoeFfnWeight { MoeFfnWeight() = default; - MoeFfnWeight(size_t hidden_dim, - int inter_size, - int expert_num, - int method, - bool has_shared_gate, - size_t tp, - WeightType weight_type, - int group_size, - bool fuse_silu_act) + MoeFfnWeight(int layer_id, + const MoeParam& param, + size_t hidden_dim, + WeightType weight_type, + int group_size, + size_t tp, + bool fuse_silu_act) { - // printf("%d %d %d\n", (int)hidden_dim, (int)inter_size, (int)expert_num); + if (param.expert_num.size() <= layer_id) { + return; + } + + const int expert_num = param.expert_num[layer_id]; if (expert_num == 0) { return; } + // printf("%d %d %d\n", (int)hidden_dim, (int)param.inter_size, (int)expert_num); + gate.input_dims = hidden_dim; gate.output_dims = expert_num; gate.type = get_default_weight_type(); @@ -210,15 +299,15 @@ struct MoeFfnWeight { experts.resize(expert_num); - this->method = method; - fuse_silu_act = fuse_silu_act && method; + method = param.method; + fuse_silu_act = fuse_silu_act && method == MoeParam::kFused; for (auto& e : experts) { // inter size is divided by tp in `FfnWeight` - e = LlamaFfnWeight{hidden_dim, (size_t)inter_size, tp, weight_type, group_size, fuse_silu_act}; + e = LlamaFfnWeight{hidden_dim, (size_t)param.inter_size, tp, weight_type, group_size, fuse_silu_act}; } - if (has_shared_gate) { + if (param.shared_gate) { shared_gate.input_dims = hidden_dim; shared_gate.output_dims = 1; shared_gate.type = get_default_weight_type(); @@ -229,14 +318,36 @@ struct MoeFfnWeight { } } + void malloc(cudaStream_t st) + { + gate.malloc(st); + if (shared_gate.output_dims) { + shared_gate.malloc(st); + } + for (auto& e : experts) { + e.malloc(st); + } + } + + void free(cudaStream_t st) + { + gate.free(st); + shared_gate.free(st); + for (auto& e : experts) { + e.free(st); + } + block.free(st); + } + LlamaDenseWeight gate; std::vector> experts; LlamaDenseWeight shared_gate; + // reference into `experts` LlamaFfnWeight block; - int method{}; + MoeParam::Method method{}; }; } // namespace turbomind diff --git a/src/turbomind/models/llama/LlamaFfnLayer.cc b/src/turbomind/models/llama/LlamaFfnLayer.cc index 8cce20720..907467341 100644 --- a/src/turbomind/models/llama/LlamaFfnLayer.cc +++ b/src/turbomind/models/llama/LlamaFfnLayer.cc @@ -27,21 +27,20 @@ namespace turbomind { template -void LlamaFfnLayer::allocateBuffer(size_t token_num, - int inter_size, - const LlamaDenseWeight* gating, - const LlamaDenseWeight* inter) +void LlamaFfnLayer::allocateBuffer( + size_t token_num, int inter_size, size_t inter_buf_factor, size_t gating_lora_r, size_t inter_lora_r) { const size_t sz = token_num * inter_size; - const size_t sz_gate = token_num * gating->lora.r; - const size_t sz_inter = token_num * inter->lora.r; + const size_t sz_gate = token_num * gating_lora_r; + const size_t sz_inter = token_num * inter_lora_r; - gating_buf_ = (T*)allocator_->reMalloc(gating_buf_, sizeof(T) * (sz * 2 + sz_gate + sz_inter), false); - inter_buf_ = gating_buf_ + sz; + gating_buf_ = + (T*)allocator_->reMalloc(gating_buf_, sizeof(T) * (sz * inter_buf_factor + sz_gate + sz_inter), false); + inter_buf_ = gating_buf_ + sz; // gate & inter is not fused when lora is enabled - if (gating->lora.r) { + if (gating_lora_r) { inter_buf_ += sz_gate; } @@ -93,12 +92,16 @@ void LlamaFfnLayer::forward(TensorMap* output_tensors, const int layer_id = input_tensors->getVal("layer_id"); const int inter_size = weights->inter_size; - allocateBuffer(token_num, inter_size, &weights->gating, &weights->intermediate); + const bool is_fused_silu = weights->fused_gating_intermediate.kernel && weights->is_fused_silu; + + allocateBuffer(token_num, inter_size, is_fused_silu ? 1 : 2, weights->gating.lora.r, weights->intermediate.lora.r); const T* ffn_input_data = input_tensors->at("ffn_input").getPtr(); T* ffn_output_data = output_tensors->at("ffn_output").getPtr(); int* lora_mask = input_tensors->at("lora_mask", Tensor{MEMORY_GPU, TYPE_INVALID, {}, nullptr}).getPtr(); + const bool all_reduce = input_tensors->getVal("all_reduce", false); + if (weights->fused_gating_intermediate.kernel) { NvtxScope scope("fused_silu_ffn"); @@ -145,7 +148,8 @@ void LlamaFfnLayer::forward(TensorMap* output_tensors, count_and_fix(ffn_output_data, token_num * weights->output.output_dims, Concat("w2", layer_id), 3); - if (all_reduce_ && tensor_para_.world_size_ > 1) { + if (all_reduce && tensor_para_.world_size_ > 1) { + // std::cout << "ffn all reduce " << layer_id << "\n"; NcclGuard nccl_guard(tensor_para_, stream_); ftNcclAllReduceSum(ffn_output_data, ffn_output_data, token_num * hidden_units_, tensor_para_, stream_); sync_check_cuda_error(); diff --git a/src/turbomind/models/llama/LlamaFfnLayer.h b/src/turbomind/models/llama/LlamaFfnLayer.h index 2daca2cc9..a72a24701 100644 --- a/src/turbomind/models/llama/LlamaFfnLayer.h +++ b/src/turbomind/models/llama/LlamaFfnLayer.h @@ -30,13 +30,12 @@ namespace turbomind { template class LlamaFfnLayer { public: - LlamaFfnLayer(const ModelParam& model, const NcclParam& tp, const Context& ctx, bool all_reduce): + LlamaFfnLayer(const ModelParam& model, const NcclParam& tp, const Context& ctx): hidden_units_(model.hidden_units), tensor_para_(tp), stream_(ctx.stream), linear_(ctx.linear.get()), - allocator_(ctx.allocator.get()), - all_reduce_(all_reduce) + allocator_(ctx.allocator.get()) { } @@ -48,7 +47,8 @@ class LlamaFfnLayer { void forward(TensorMap* output_tensors, const TensorMap* input_tensors, const LlamaFfnWeight* weights); private: - void allocateBuffer(size_t token_num, int inter_size, const LlamaDenseWeight*, const LlamaDenseWeight*); + void allocateBuffer( + size_t token_num, int inter_size, size_t inter_buf_factor, size_t gating_lora_r, size_t inter_lora_r); void freeBuffer(); @@ -59,7 +59,6 @@ class LlamaFfnLayer { cudaStream_t const stream_; LlamaLinear* const linear_; IAllocator* const allocator_; - const bool all_reduce_; bool is_free_buffer_after_forward_{}; T* gating_buf_{}; diff --git a/src/turbomind/models/llama/LlamaV2.cc b/src/turbomind/models/llama/LlamaV2.cc index 3d50910ad..05b22deed 100644 --- a/src/turbomind/models/llama/LlamaV2.cc +++ b/src/turbomind/models/llama/LlamaV2.cc @@ -72,7 +72,6 @@ LlamaV2::LlamaV2(const ModelParam& model, lora_param_(lora), head_num_(model.head_num), size_per_head_(model.head_dim), - inter_size_(model.inter_size), hidden_units_(model.hidden_units), layer_num_(model.layer_num), vocab_size_(model.vocab_size), diff --git a/src/turbomind/models/llama/LlamaV2.h b/src/turbomind/models/llama/LlamaV2.h index 6321d09d7..658282f5e 100644 --- a/src/turbomind/models/llama/LlamaV2.h +++ b/src/turbomind/models/llama/LlamaV2.h @@ -113,7 +113,6 @@ class LlamaV2 { const size_t head_num_; const size_t size_per_head_; const size_t hidden_units_; - const size_t inter_size_; const size_t layer_num_; const size_t vocab_size_; const size_t vocab_size_padded_; diff --git a/src/turbomind/models/llama/LlamaWeight.cc b/src/turbomind/models/llama/LlamaWeight.cc index 9d62042d6..bcee15097 100644 --- a/src/turbomind/models/llama/LlamaWeight.cc +++ b/src/turbomind/models/llama/LlamaWeight.cc @@ -20,36 +20,24 @@ #include "src/turbomind/models/llama/LlamaWeight.h" #include "src/turbomind/models/llama/llama_params.h" +#include "src/turbomind/utils/cuda_utils.h" #include "src/turbomind/utils/memory_utils.h" #include namespace turbomind { template -LlamaWeight::LlamaWeight(size_t head_num, - size_t kv_head_num, - size_t size_per_head, - size_t hidden_units, - size_t inter_size, - size_t vocab_size, - size_t embedding_size, - size_t num_layer, - bool attn_bias, - WeightType weight_type, - int group_size, - LoraParam lora_param, - MoeParam moe_param, - size_t tensor_para_size, - size_t tensor_para_rank): - hidden_units_(hidden_units), - inter_size_(inter_size), - vocab_size_(vocab_size), - vocab_size_padded_(vocab_size), - embedding_size_(embedding_size), - num_layer_(num_layer), - weight_type_(weight_type), - tensor_para_size_(tensor_para_size), - tensor_para_rank_(tensor_para_rank) +LlamaWeight::LlamaWeight( + const ModelParam& model, const LoraParam& lora_param, const MoeParam& moe_param, size_t tp_size, size_t tp_rank): + hidden_units_(model.hidden_units), + inter_size_(model.inter_size), + vocab_size_(model.vocab_size), + vocab_size_padded_(model.vocab_size), + embedding_size_(model.embedding_size), + num_layer_(model.layer_num), + weight_type_(model.weight_type), + tensor_para_size_(tp_size), + tensor_para_rank_(tp_rank) { if (vocab_size_padded_ % tensor_para_size_ != 0) { vocab_size_padded_ = (vocab_size_ + tensor_para_size_ - 1) / tensor_para_size_ * tensor_para_size_; @@ -61,49 +49,42 @@ LlamaWeight::LlamaWeight(size_t head_num, } FT_CHECK(hidden_units_ % tensor_para_size_ == 0); + check_cuda_error(cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking)); + decoder_layer_weights.reserve(num_layer_); for (unsigned l = 0; l < num_layer_; ++l) { - decoder_layer_weights.push_back(new LlamaDecoderLayerWeight(l, - head_num, - kv_head_num, - size_per_head, - hidden_units_, - inter_size_, - weight_type_, - group_size, - lora_param, - attn_bias, - moe_param, - tensor_para_size_, - tensor_para_rank_)); + decoder_layer_weights.emplace_back( + new LlamaDecoderLayerWeight(l, model, lora_param, moe_param, tp_size, tp_rank)); + decoder_layer_weights.back()->malloc(stream_); } - mallocWeights(); + FT_CHECK(vocab_size_padded_ % tensor_para_size_ == 0); + deviceMalloc((T**)&pre_decoder_embedding_table, embedding_size_ * hidden_units_ / tensor_para_size_, stream_); + deviceMalloc((T**)&output_norm_weight, hidden_units_, stream_); + deviceMalloc((T**)&post_decoder_embedding_kernel, hidden_units_ * vocab_size_padded_ / tensor_para_size_, stream_); + + // Wait for allocations + check_cuda_error(cudaStreamSynchronize(stream_)); } template LlamaWeight::~LlamaWeight() { - cudaFree((void*)pre_decoder_embedding_table); - cudaFree((void*)output_norm_weight); - cudaFree((void*)post_decoder_embedding_kernel); - - pre_decoder_embedding_table = nullptr; - output_norm_weight = nullptr; - post_decoder_embedding_kernel = nullptr; + deviceFree(pre_decoder_embedding_table, stream_); + deviceFree(output_norm_weight, stream_); + deviceFree(post_decoder_embedding_kernel, stream_); for (auto& p : decoder_layer_weights) { + p->free(stream_); delete p; } -} -template -void LlamaWeight::mallocWeights() -{ - FT_CHECK(vocab_size_padded_ % tensor_para_size_ == 0); - deviceMalloc((T**)&pre_decoder_embedding_table, embedding_size_ * hidden_units_ / tensor_para_size_); - deviceMalloc((T**)&output_norm_weight, hidden_units_); - deviceMalloc((T**)&post_decoder_embedding_kernel, hidden_units_ * vocab_size_padded_ / tensor_para_size_); + decoder_layer_weights.clear(); + + // Wait for deallocations + check_cuda_error(cudaStreamSynchronize(stream_)); + check_cuda_error(cudaStreamDestroy(stream_)); + stream_ = {}; } template @@ -179,13 +160,19 @@ void LlamaWeight::prepare(const cudaDeviceProp& prop) TM_LOG_INFO("[LlamaWeight::prepare] workspace size: %d\n", workspace_size); + // Wait for the weights to be filled externally + check_cuda_error(cudaDeviceSynchronize()); + if (workspace_size) { - deviceMalloc((char**)&workspace, workspace_size); + deviceMalloc((char**)&workspace, workspace_size, stream_); } for (auto& layer : decoder_layer_weights) { - layer->prepare(workspace, workspace_size, prop); + layer->prepare(workspace, workspace_size, prop, stream_); } - deviceFree(workspace); + + deviceFree(workspace, stream_); + + check_cuda_error(cudaStreamSynchronize(stream_)); } #ifdef ENABLE_FP32 diff --git a/src/turbomind/models/llama/LlamaWeight.h b/src/turbomind/models/llama/LlamaWeight.h index c30e75356..629cd5612 100644 --- a/src/turbomind/models/llama/LlamaWeight.h +++ b/src/turbomind/models/llama/LlamaWeight.h @@ -22,28 +22,18 @@ #include "src/turbomind/models/llama/LlamaDecoderLayerWeight.h" #include "src/turbomind/models/llama/llama_params.h" -#include "src/turbomind/utils/memory_utils.h" namespace turbomind { template struct LlamaWeight { LlamaWeight() = default; - LlamaWeight(size_t head_num, - size_t kv_head_num, - size_t size_per_head, - size_t hidden_units, - size_t inter_size, - size_t vocab_size, - size_t embedding_size, - size_t num_layer, - bool attn_bias, - WeightType weight_type, - int group_size, - LoraParam lora_param, - MoeParam moe_param, - size_t tensor_para_size, - size_t tensor_para_rank); + + LlamaWeight(const ModelParam& model_param, + const LoraParam& lora_param, + const MoeParam& moe_param, + size_t tp_size, + size_t tp_rank); ~LlamaWeight(); @@ -57,15 +47,13 @@ struct LlamaWeight { void prepare(const cudaDeviceProp& prop); std::vector*> decoder_layer_weights; - const T* pre_decoder_embedding_table{}; - const T* output_norm_weight{}; - const T* post_decoder_embedding_kernel{}; -private: - void mallocWeights(); + T* pre_decoder_embedding_table{}; + T* output_norm_weight{}; + T* post_decoder_embedding_kernel{}; +private: size_t hidden_units_; - size_t inter_size_; size_t vocab_size_; size_t vocab_size_padded_; size_t embedding_size_; @@ -73,6 +61,10 @@ struct LlamaWeight { WeightType weight_type_; size_t tensor_para_size_; size_t tensor_para_rank_; + + std::vector inter_size_; + + cudaStream_t stream_; }; } // namespace turbomind diff --git a/src/turbomind/models/llama/llama_gemm.cc b/src/turbomind/models/llama/llama_gemm.cc index 62952cd71..f9a0191e4 100644 --- a/src/turbomind/models/llama/llama_gemm.cc +++ b/src/turbomind/models/llama/llama_gemm.cc @@ -84,7 +84,7 @@ int main(int argc, char* argv[]) return -1; } else { - ft::deviceMalloc(reinterpret_cast(&gemm_test_buf), buf_size_in_byte, false); + ft::deviceMalloc(reinterpret_cast(&gemm_test_buf), buf_size_in_byte, nullptr, false); } if (0) {} diff --git a/src/turbomind/models/llama/llama_kernels.h b/src/turbomind/models/llama/llama_kernels.h index 3b01dee60..aaade1a51 100644 --- a/src/turbomind/models/llama/llama_kernels.h +++ b/src/turbomind/models/llama/llama_kernels.h @@ -154,7 +154,7 @@ template struct TempBuffer { TempBuffer(size_t size) { - deviceMalloc(&data, size, false); + cudaMalloc(&data, size); } T* data; }; diff --git a/src/turbomind/models/llama/llama_params.h b/src/turbomind/models/llama/llama_params.h index e6b9d690a..0a505b11a 100644 --- a/src/turbomind/models/llama/llama_params.h +++ b/src/turbomind/models/llama/llama_params.h @@ -2,28 +2,41 @@ #pragma once -#include "src/turbomind/models/llama/LlamaDenseWeight.h" #include #include #include #include +#include "src/turbomind/models/llama/weight_type.h" + namespace turbomind { +struct MLAParam { + size_t q_lora_rank; + size_t kv_lora_rank; + size_t qk_rope_dim; + size_t v_head_dim; +}; + struct ModelParam { - size_t head_num; - size_t head_dim; - size_t kv_head_num; - size_t hidden_units; - size_t layer_num; - size_t inter_size; - size_t vocab_size; - size_t embedding_size; - float norm_eps; - int quant_policy; - // - int start_id; - int end_id; + size_t head_num; + size_t head_dim; + size_t kv_head_num; + size_t hidden_units; + size_t layer_num; + size_t vocab_size; + size_t embedding_size; + float norm_eps; + int quant_policy; + bool attn_bias; + WeightType weight_type; + int group_size; + int start_id; + int end_id; + MLAParam mla; + int tune_layer_num; + + std::vector inter_size; }; struct MoeParam { @@ -32,17 +45,25 @@ struct MoeParam { kNaive, kFused } method; - int expert_num; - int experts_per_token; - int inter_size; - bool norm_topk; - bool shared_gate; + + int experts_per_token; + int inter_size; + bool norm_topk_prob; + bool shared_gate; + float routed_scale; + + int topk_group; + std::string topk_method; + int n_group; + + std::vector expert_num; }; struct AttentionParam { int rotary_embedding_dim; float rotary_embedding_base; int max_position_embeddings; + float softmax_scale; std::string rope_scaling_type; int original_max_position_embeddings; float rope_scaling_factor; @@ -74,6 +95,12 @@ struct EngineParam { int max_prefill_iters; }; +enum class LoraPolicy : int +{ + kNull, + kPlora, +}; + struct LoraParam { int r; float scale; diff --git a/src/turbomind/models/llama/llama_utils.cu b/src/turbomind/models/llama/llama_utils.cu index 925c6b883..eaa450ae2 100644 --- a/src/turbomind/models/llama/llama_utils.cu +++ b/src/turbomind/models/llama/llama_utils.cu @@ -1,47 +1,25 @@ // Copyright (c) OpenMMLab. All rights reserved. -#include "src/turbomind/kernels/reduce_kernel_utils.cuh" -#include "src/turbomind/models/llama/llama_utils.h" -#include "src/turbomind/utils/cuda_utils.h" #include #include #include #include +#include +#include + #include #include #include #include #include -#include + +#include "src/turbomind/models/llama/llama_utils.h" +#include "src/turbomind/utils/cuda_utils.h" namespace turbomind { CmpMode compare_mode = kCmpRead; - -template -struct abs_diff_t { - using type = T; -}; - -template<> -struct abs_diff_t { - using type = float; -}; - -template<> -struct abs_diff_t<__nv_bfloat16> { - using type = float; -}; - -template -struct abs_diff: public thrust::unary_function, typename abs_diff_t::type> { - __host__ __device__ float operator()(thrust::tuple x) const - { - using R = typename abs_diff_t::type; - auto r = R(thrust::get<0>(x)) - R(thrust::get<1>(x)); - return r < R(0) ? -r : r; - } -}; +// CmpMode compare_mode = kCmpWrite; template void CheckNan(const T* ptr, size_t size, std::string key, cudaStream_t stream) @@ -63,10 +41,8 @@ void CheckNan(const T* ptr, size_t size, std::string key, cudaStream_t stream) template void CmpRead(T* ptr, size_t size, std::string key, cudaStream_t stream) { - // wait for b - check_cuda_error(cudaStreamSynchronize(stream)); // read a from file - thrust::host_vector h_a(size); + std::vector h_a(size); { const auto filename = "tmp/" + key + ".cmp"; std::ifstream ifs(filename, std::ios::binary); @@ -85,15 +61,30 @@ void CmpRead(T* ptr, size_t size, std::string key, cudaStream_t stream) } ifs.read((char*)h_a.data(), sizeof(T) * h_a.size()); } - // copy a to device - thrust::device_vector a = h_a; - // create abs(a - b) iterator - thrust::device_ptr dev_ptr(ptr); - auto zip_iter = thrust::make_zip_iterator(thrust::make_tuple(a.begin(), dev_ptr)); - auto transform_iter = thrust::make_transform_iterator(zip_iter, abs_diff{}); - // sum(abs(a - b)) - auto asum = thrust::reduce(thrust::device, transform_iter, transform_iter + size); - std::cerr << key << ": " << asum << " " << asum / size << "\n"; + std::vector h_b(size); + check_cuda_error(cudaMemcpyAsync(h_b.data(), ptr, sizeof(T) * size, cudaMemcpyDefault, stream)); + check_cuda_error(cudaStreamSynchronize(stream)); + + using Tacc = std::conditional_t, int64_t, float>; + constexpr Tacc eps = std::is_integral_v ? 1 : 1e-8f; + + Tacc asum{}; + Tacc rsum{}; + Tacc amean{}; + for (size_t i = 0; i < size; ++i) { + Tacc x = (Tacc)h_b[i]; + Tacc r = (Tacc)h_a[i]; + Tacc abs_diff = std::abs(x - r); + Tacc rel_diff = abs_diff / std::max(std::max(std::abs(r), std::abs(x)), eps); + asum += abs_diff; + rsum += rel_diff; + amean += std::abs(r); + } + + std::cerr << key << ": " << amean / size << " " << asum << " " << asum / size << " " << rsum / size << "\n"; + + check_cuda_error(cudaMemcpyAsync(ptr, h_a.data(), sizeof(T) * h_a.size(), cudaMemcpyDefault, stream)); + check_cuda_error(cudaStreamSynchronize(stream)); } template diff --git a/src/turbomind/models/llama/mla_utils.cu b/src/turbomind/models/llama/mla_utils.cu new file mode 100644 index 000000000..2f9e786f2 --- /dev/null +++ b/src/turbomind/models/llama/mla_utils.cu @@ -0,0 +1,93 @@ +// Copyright (c) OpenMMLab. All rights reserved. +#include "src/turbomind/kernels/core/array_ops.h" + +namespace turbomind { + +template +__global__ void mla_copy_qkv_kernel(T* qkv, + const T* q, // [h, head_dim] + const T* kv_a, // [kv_lora_rank, rope_dim] + const T* kv_b, // [h, nope_dim + v_head_dim] + int head_num, + int head_dim, + int nope_dim, + int rope_dim, + int kv_lora_rank, + int v_head_dim) +{ + const int type = blockIdx.y; + + const int64_t ti = blockIdx.x; + const int di = threadIdx.x; + + const int kv_b_dim = nope_dim + v_head_dim; + + // for (int hi = threadIdx.y; hi < head_num; hi += blockDim.y) { + const int hi = threadIdx.y; + Array data{}; + if (type == 0) { // Q + if (di * vec_size < rope_dim) { + Ldg(data, &q[ti * head_num * head_dim + hi * head_dim + nope_dim + di * vec_size]); + } + else { + Ldg(data, &q[ti * head_num * head_dim + hi * head_dim + di * vec_size - rope_dim]); + } + } + else if (type == 1) { // K + if (di * vec_size < rope_dim) { + Ldg(data, &kv_a[ti * (kv_lora_rank + rope_dim) + kv_lora_rank + di * vec_size]); + } + else { + Ldg(data, &kv_b[ti * head_num * kv_b_dim + hi * kv_b_dim + di * vec_size - rope_dim]); + } + } + else { // V + if (di * vec_size < v_head_dim) { + Ldg(data, &kv_b[ti * head_num * kv_b_dim + hi * kv_b_dim + nope_dim + di * vec_size]); + } + } + const int stride = 3 * head_num * head_dim; + Store(&qkv[ti * stride + type * head_num * head_dim + hi * head_dim + di * vec_size], data); + // } +} + +template +void invokeMLACopyQKV(T* qkv, + const T* q, + const T* kv_a, + const T* kv_b, + int token_num, + int head_num, + int nope_dim, + int rope_dim, + int kv_lora_rank, + int v_head_dim, + cudaStream_t stream) +{ + constexpr int vec_size = 16 / sizeof(T); + const int head_dim = nope_dim + rope_dim; + + dim3 block(head_dim / vec_size, head_num); + // make sure block size <= 1024 + while (block.x * block.y > 1024) { + block.y /= 2; + } + const dim3 grid(token_num, 3); + + mla_copy_qkv_kernel<<>>( + qkv, q, kv_a, kv_b, head_num, head_dim, nope_dim, rope_dim, kv_lora_rank, v_head_dim); +} + +template void invokeMLACopyQKV(uint16_t* qkv, + const uint16_t* q, + const uint16_t* kv_a, + const uint16_t* kv_b, + int token_num, + int head_num, + int nope_dim, + int rope_dim, + int kv_lora_rank, + int v_head_dim, + cudaStream_t stream); + +} // namespace turbomind diff --git a/src/turbomind/models/llama/mla_utils.h b/src/turbomind/models/llama/mla_utils.h new file mode 100644 index 000000000..bc06a352f --- /dev/null +++ b/src/turbomind/models/llama/mla_utils.h @@ -0,0 +1,57 @@ +// Copyright (c) OpenMMLab. All rights reserved. +#pragma once + +#include +#include + +#include "src/turbomind/utils/cuda_utils.h" + +namespace turbomind { + +template +void invokeMLACopyQKV(T* qkv, + const T* q, + const T* kv_a, + const T* kv_b, + int token_num, + int head_num, + int nope_dim, + int rope_dim, + int kv_lora_rank, + int v_head_dim, + cudaStream_t stream); + +template +void dispatchMLACopyQKV(T* qkv, + const T* q, + const T* kv_a, + const T* kv_b, + int token_num, + int head_num, + int nope_dim, + int rope_dim, + int kv_lora_rank, + int v_head_dim, + cudaStream_t stream) +{ + auto invoke = [&](auto x) { + using type = decltype(x); + invokeMLACopyQKV((type*)qkv, + (const type*)q, + (const type*)kv_a, + (const type*)kv_b, + token_num, + head_num, + nope_dim, + rope_dim, + kv_lora_rank, + v_head_dim, + stream); + }; + if constexpr (sizeof(T) == 2) { + return invoke(uint16_t{}); + } + FT_CHECK(0); +} + +} // namespace turbomind diff --git a/src/turbomind/models/llama/moe_ffn_layer.cc b/src/turbomind/models/llama/moe_ffn_layer.cc index 1ad76839d..390d14754 100644 --- a/src/turbomind/models/llama/moe_ffn_layer.cc +++ b/src/turbomind/models/llama/moe_ffn_layer.cc @@ -11,22 +11,21 @@ #include "src/turbomind/utils/nvtx_utils.h" #include "src/turbomind/utils/string_utils.h" #include -#include #include namespace turbomind { template -void MoeFfnLayer::AllocateBuffer(size_t tokens, size_t padded) +void MoeFfnLayer::AllocateBuffer(size_t tokens, size_t padded, size_t expert_num, size_t inter_buf_factor) { char* base = 0; auto allocate = [&](void* base) { Monotonic alloc{base}; alloc(&inout_buf_, tokens * param_.experts_per_token * hidden_dim_); - alloc(&inter_buf_, tokens * param_.experts_per_token * inter_size_ * 2); - alloc(&logits_, tokens * param_.expert_num); - alloc(&masks_, param_.expert_num * padded); + alloc(&inter_buf_, tokens * param_.experts_per_token * inter_size_ * inter_buf_factor); + alloc(&logits_, tokens * expert_num); + alloc(&masks_, expert_num * padded); alloc(&f2n_, param_.experts_per_token * tokens); alloc(&en2f_, param_.experts_per_token * tokens); alloc(&scales_, param_.experts_per_token * tokens); @@ -80,18 +79,42 @@ void MoeFfnLayer::gate(float* logits, const T* input, int tokens, const Llama template void MoeFfnLayer::forward(T* output, const T* input, int tokens, int layer_id, const MoeFfnWeight& moe) { - const size_t padded = (tokens + kMoeGateVecSize - 1) / kMoeGateVecSize * kMoeGateVecSize; + const size_t padded = (tokens + kMoeGateVecSize - 1) / kMoeGateVecSize * kMoeGateVecSize; + const int expert_num = moe.experts.size(); - AllocateBuffer(tokens, padded); + FT_CHECK(expert_num); + + const size_t inter_buf_factor = [&] { + if (param_.method == MoeParam::kNaive) { + return 0; // managed by ffn + } + else if (moe.block.is_fused_silu) { + return 1; + } + else { + return 2; + } + }(); + + AllocateBuffer(tokens, padded, expert_num, inter_buf_factor); gate(logits_, input, tokens, moe.gate); sync_check_cuda_error(); - check_cuda_error(cudaMemsetAsync(accum_, 0, sizeof(int) * param_.expert_num * kMoeGateMaxTiles, stream_)); - sync_check_cuda_error(); + // if (tensor_para_.rank_ == 0) { + // Compare(logits_, tokens * expert_num, Concat("logit", layer_id), compare_mode, stream_); + // } + + check_cuda_error(cudaMemsetAsync(accum_, 0, sizeof(int) * expert_num * kMoeGateMaxTiles, stream_)); + check_cuda_error(cudaMemsetAsync(masks_, -1, sizeof(int8_t) * expert_num * padded, stream_)); // dump_logits(tokens, layer_id); + if (param_.topk_method == "group_limited_greedy") { + invokeMaskMoeTopKGroups(logits_, tokens, expert_num, expert_num / param_.n_group, param_.topk_group, stream_); + sync_check_cuda_error(); + } + /// TODO: fix illegal memory access even if NaN are present in logits invokeMoeGate_V2(f2n_, en2f_, @@ -102,25 +125,26 @@ void MoeFfnLayer::forward(T* output, const T* input, int tokens, int layer_id logits_, tokens, padded, - param_.expert_num, + expert_num, param_.experts_per_token, - param_.norm_topk, + param_.norm_topk_prob, + param_.routed_scale, stream_); sync_check_cuda_error(); if (isTuning()) { std::mt19937 g; - const auto expert_ids = SampleUniform(tokens, param_.expert_num, param_.experts_per_token, g); - std::vector cnt(param_.expert_num); + const auto expert_ids = SampleUniform(tokens, expert_num, param_.experts_per_token, g); + std::vector cnt(expert_num); for (const auto& x : expert_ids) { ++cnt[x]; } h_offsets_[0] = 0; - for (int i = 0; i < param_.expert_num; ++i) { + for (int i = 0; i < expert_num; ++i) { h_offsets_[i + 1] = h_offsets_[i] + cnt[i]; } check_cuda_error( - cudaMemcpyAsync(offsets_, h_offsets_, sizeof(int) * (param_.expert_num + 1), cudaMemcpyDefault, stream_)); + cudaMemcpyAsync(offsets_, h_offsets_, sizeof(int) * (expert_num + 1), cudaMemcpyDefault, stream_)); } if (param_.method == MoeParam::kNaive) { @@ -129,15 +153,15 @@ void MoeFfnLayer::forward(T* output, const T* input, int tokens, int layer_id sync_check_cuda_error(); check_cuda_error( - cudaMemcpyAsync(h_offsets_, offsets_, sizeof(int) * (param_.expert_num + 1), cudaMemcpyDefault, stream_)); + cudaMemcpyAsync(h_offsets_, offsets_, sizeof(int) * (expert_num + 1), cudaMemcpyDefault, stream_)); check_cuda_error(cudaStreamSynchronize(stream_)); - if (h_offsets_[param_.expert_num] != tokens * param_.experts_per_token) { - FT_CHECK_WITH_INFO(0, fmtstr("%d vs %d", h_offsets_[param_.expert_num], tokens * param_.experts_per_token)); + if (h_offsets_[expert_num] != tokens * param_.experts_per_token) { + FT_CHECK_WITH_INFO(0, fmtstr("%d vs %d", h_offsets_[expert_num], tokens * param_.experts_per_token)); } - for (int i = 0; i < param_.expert_num; ++i) { + for (int i = 0; i < expert_num; ++i) { FT_CHECK(moe.experts[i].is_fused_silu == false); @@ -153,7 +177,7 @@ void MoeFfnLayer::forward(T* output, const T* input, int tokens, int layer_id } } else { - context_->set_offsets(offsets_); + context_->update(expert_num, param_.experts_per_token, offsets_); auto& block = moe.block; @@ -217,7 +241,7 @@ void MoeFfnLayer::forward(T* output, const T* input, int tokens, int layer_id } template -void MoeFfnLayer::reduce(T* output, int tokens, const MoeFfnWeight& moe) +void MoeFfnLayer::reduce(T* output, int tokens, float output_scale, int layer_id, const MoeFfnWeight& moe) { invokeMoeReduce(output, inout_buf_, @@ -227,19 +251,21 @@ void MoeFfnLayer::reduce(T* output, int tokens, const MoeFfnWeight& moe) tokens, param_.experts_per_token, hidden_dim_, + output_scale, stream_); sync_check_cuda_error(); if (tensor_para_.world_size_ > 1) { + // std::cout << "moe all reduce " << layer_id << "\n"; ftNcclAllReduceSum(output, output, tokens * hidden_dim_, tensor_para_, stream_); sync_check_cuda_error(); } } template -void MoeFfnLayer::dump_logits(int token_num, int layer_id) +void MoeFfnLayer::dump_logits(int token_num, int layer_id, int expert_num) { - std::vector logits(token_num * param_.expert_num); + std::vector logits(token_num * expert_num); check_cuda_error( cudaMemcpyAsync(logits.data(), logits_, sizeof(float) * logits.size(), cudaMemcpyDefault, stream_)); check_cuda_error(cudaStreamSynchronize(stream_)); @@ -247,7 +273,7 @@ void MoeFfnLayer::dump_logits(int token_num, int layer_id) auto ptr = logits.data(); std::cout << "layer_id: " << layer_id << std::endl; for (int i = 0; i < token_num; ++i) { - for (int e = 0; e < param_.expert_num; ++e) { + for (int e = 0; e < expert_num; ++e) { std::cout << *ptr++ << " "; } std::cout << std::endl; diff --git a/src/turbomind/models/llama/moe_ffn_layer.h b/src/turbomind/models/llama/moe_ffn_layer.h index 0f1713f7b..74c62d004 100644 --- a/src/turbomind/models/llama/moe_ffn_layer.h +++ b/src/turbomind/models/llama/moe_ffn_layer.h @@ -9,6 +9,7 @@ #include "src/turbomind/models/llama/llama_params.h" #include "src/turbomind/utils/cublasMMWrapper.h" #include "src/turbomind/utils/nccl_utils.h" +#include namespace turbomind { @@ -26,23 +27,24 @@ class MoeFfnLayer { linear_(ctx.linear.get()), allocator_(ctx.allocator.get()) { - model.inter_size = param.inter_size; + FT_CHECK(!param.expert_num.empty()); + const int max_expert_num = *std::max_element(param.expert_num.begin(), param.expert_num.end()); if (param_.method == MoeParam::kFused) { context_ = std::make_unique( - param.expert_num, param.experts_per_token, ctx.cuda_device_prop, stream_); + max_expert_num, param.experts_per_token, ctx.cuda_device_prop, stream_); } else { - expert_ffn_ = std::make_unique>(model, tp, ctx, false); + expert_ffn_ = std::make_unique>(model, tp, ctx); } - h_offsets_ = (int*)allocator_->malloc(sizeof(int) * (param_.expert_num + 1), false, true); + h_offsets_ = (int*)allocator_->malloc(sizeof(int) * (max_expert_num + 1), false, true); - offsets_ = (int*)allocator_->malloc(sizeof(int) * (param_.expert_num + 1)); - accum_ = (int*)allocator_->malloc(sizeof(int) * param_.expert_num * kMoeGateMaxTiles); + offsets_ = (int*)allocator_->malloc(sizeof(int) * (max_expert_num + 1)); + accum_ = (int*)allocator_->malloc(sizeof(int) * max_expert_num * kMoeGateMaxTiles); } - void AllocateBuffer(size_t tokens, size_t padded); + void AllocateBuffer(size_t tokens, size_t padded, size_t expert_num, size_t inter_buf_factor); void FreeBuffer(); @@ -53,11 +55,11 @@ class MoeFfnLayer { void forward(T* output, const T* input, int tokens, int layer_id, const MoeFfnWeight& moe); - void reduce(T* output, int tokens, const MoeFfnWeight& moe); + void reduce(T* output, int tokens, float output_scale, int layer_id, const MoeFfnWeight& moe); void gate(float* logits, const T* input, int tokens, const LlamaDenseWeight& weight); - void dump_logits(int token_num, int layer_id); + void dump_logits(int token_num, int layer_id, int expert_num); private: const size_t inter_size_; diff --git a/src/turbomind/models/llama/unified_attention_layer.cc b/src/turbomind/models/llama/unified_attention_layer.cc index 2f99b0c2c..7a6eddc4b 100644 --- a/src/turbomind/models/llama/unified_attention_layer.cc +++ b/src/turbomind/models/llama/unified_attention_layer.cc @@ -19,21 +19,24 @@ // Modified from // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/layers/attention_layers/GptContextAttentionLayer.cc -#include "src/turbomind/models/llama/unified_attention_layer.h" +#include +#include + #include "src/turbomind/kernels/attention/attention.h" #include "src/turbomind/kernels/attention/decoding.h" #include "src/turbomind/kernels/attention/kv_cache_utils_v2.h" +#include "src/turbomind/kernels/norm/rms_norm.h" #include "src/turbomind/macro.h" #include "src/turbomind/models/llama/LlamaNcclGuard.h" #include "src/turbomind/models/llama/llama_kernels.h" #include "src/turbomind/models/llama/llama_utils.h" +#include "src/turbomind/models/llama/mla_utils.h" +#include "src/turbomind/models/llama/unified_attention_layer.h" #include "src/turbomind/utils/Tensor.h" #include "src/turbomind/utils/anomaly_handler.h" #include "src/turbomind/utils/cuda_utils.h" -#include "src/turbomind/utils/debug_utils.h" #include "src/turbomind/utils/logger.h" -#include -#include +#include "src/turbomind/utils/memory_utils.h" namespace turbomind { @@ -72,17 +75,14 @@ UnifiedAttentionLayer::UnifiedAttentionLayer(const ModelParam& model, } template -void UnifiedAttentionLayer::allocateBuffer(size_t q_count, - size_t k_count, - size_t batch_size, - const WeightType* weights) +void UnifiedAttentionLayer::allocateBuffer(size_t q_count, size_t k_count, size_t batch_size, size_t qkv_lora_rank) { TM_LOG_DEBUG(__PRETTY_FUNCTION__); const int local_q_kv_head_num = local_head_num_ + 2 * local_kv_head_num_; - if (weights->qkv.lora.r) { - size_t sz = sizeof(T) * q_count * (local_q_kv_head_num * size_per_head_ + weights->qkv.lora.r); + if (qkv_lora_rank) { + size_t sz = sizeof(T) * q_count * (local_q_kv_head_num * size_per_head_ + qkv_lora_rank); qkv_buf_ = (T*)allocator_->reMalloc(qkv_buf_, sz, false); } else { @@ -198,28 +198,38 @@ inline void UnifiedAttentionLayer::forward(TensorMap* outputs, const TensorMa allocateBuffer(token_num, // shared h_cu_k_len[batch_size] - h_cu_k_len[dc_batch_size], // prefill batch_size, - weights); + weights->qkv.lora.r); // [L, 2, H, s, D] const size_t layer_offset = layer_id * 2 * local_kv_head_num_ * param_.cache_block_seq_len * size_per_head_; - static int count = 0; + // static int count = 0; - // if (layer_id == 0 && count == 0) { - // Compare(attention_input, token_num * weights->qkv.input_dims, "qkv_input", compare_mode, stream_); + // if (tensor_para_.rank_ == 0) { + // Compare(attention_input, token_num * hidden_units_, Concat("qkv_input", layer_id), compare_mode, stream_); // } int* lora_mask = inputs->at("lora_mask", Tensor{MEMORY_GPU, TYPE_INVALID, {}, nullptr}).getPtr(); - ////////////////////////////////////////////// - /// qkv gemm - // [token_num, hidden_dim] -> [token_num, 3, local_hidden_dim] - linear_->forward(qkv_buf_, attention_input, token_num, weights->qkv, LlamaLinear::kGemm, lora_mask); - sync_check_cuda_error(); + + if (weights->qkv.output_dims) { + ////////////////////////////////////////////// + /// qkv gemm + // [token_num, hidden_dim] -> [token_num, 3, local_hidden_dim] + linear_->forward(qkv_buf_, attention_input, token_num, weights->qkv, LlamaLinear::kGemm, lora_mask); + sync_check_cuda_error(); + } + else { + forward_mla(attention_input, token_num, *weights); + } + + // std::cerr << layer_id << " " << count << " " << tensor_para_.rank_ << "\n"; count_and_fix(qkv_buf_, token_num * weights->qkv.output_dims, Concat("qkv", layer_id), 3); - // if (layer_id == 0 && count == 0) { - // Compare(qkv_buf_, token_num * weights->qkv.output_dims, "qkv_buf", compare_mode, stream_); + // std::cerr << "token num: " << token_num << "\n"; + + // if (layer_id == 0 && count == 0 && tensor_para_.rank_ == 0) { + // Compare(qkv_buf_, token_num * (3 * local_head_num_ * size_per_head_), "qkv_buf", CMP_MODE, stream_); // } if constexpr (0) { @@ -290,8 +300,15 @@ inline void UnifiedAttentionLayer::forward(TensorMap* outputs, const TensorMa params.num_heads = local_head_num_; params.num_kv_heads = local_kv_head_num_; params.size_per_head = size_per_head_; + // MSVC does not have M_LOG2E - params.inv_sqrt_dh = (float)std::log2(expf(1.)) / std::sqrt((float)params.size_per_head); + params.inv_sqrt_dh = (float)std::log2(expf(1.)); + if (param_.softmax_scale) { // model predefined softmax scale + params.inv_sqrt_dh *= param_.softmax_scale; + } + else { // default value + params.inv_sqrt_dh /= std::sqrt((float)params.size_per_head); + } params.rotary_embedding_dim = param_.rotary_embedding_dim; params.rotary_embedding_base = param_.rotary_embedding_base; @@ -324,8 +341,9 @@ inline void UnifiedAttentionLayer::forward(TensorMap* outputs, const TensorMa }; float low, high; find_correction_range(param_.beta_fast, param_.beta_slow, low, high); + // https://github.com/huggingface/transformers/blob/6c3f168b36882f0beebaa9121eafa1928ba29633/src/transformers/modeling_rope_utils.py#L216 if (low == high) { - high += 0.01f; + high += 0.001f; } params.yarn_ramp_inv_factor_div_2 = 1.0 / (high - low) / 2.0; params.yarn_ramp_inv_factor_mul_min = 1.0 / (high - low) * low; @@ -415,8 +433,6 @@ inline void UnifiedAttentionLayer::forward(TensorMap* outputs, const TensorMa linear_->forward(attention_out, qkv_buf_3_, token_num, weights->output, LlamaLinear::kGemm, lora_mask); sync_check_cuda_error(); - // ++count; - count_and_fix(attention_out, token_num * weights->output.output_dims, Concat("wo", layer_id), 3); if (tensor_para_.world_size_ > 1) { @@ -425,10 +441,94 @@ inline void UnifiedAttentionLayer::forward(TensorMap* outputs, const TensorMa sync_check_cuda_error(); } + // if (tensor_para_.rank_ == 0) { + // Compare(attention_out, token_num * hidden_units_, Concat("attn_out", layer_id), compare_mode, stream_); + // // dump(qkv_buf_3_, num_token * weights->output.input_dims, stream_, "qkv_buf_3"); + // } + if (is_free_buffer_after_forward_ == true) { freeBuffer(); } sync_check_cuda_error(); + + // ++count; +} + +template +void UnifiedAttentionLayer::forward_mla(const T* inputs, int token_num, const WeightType& w) +{ + const int q_lora_rank = w.q_a_proj.output_dims; + const int kv_lora_rank = w.kv_b_proj.input_dims; + const int qk_rope_dim = w.kv_a_proj.output_dims - kv_lora_rank; + const int qk_nope_dim = std::max(w.q_b_proj.output_dims, w.q_proj.output_dims) / local_head_num_ - qk_rope_dim; + const int v_head_dim = w.kv_b_proj.output_dims / local_head_num_ - qk_nope_dim; + + T* q{}; + + if (w.q_proj.kernel) { + deviceMalloc((T**)&q, (size_t)token_num * w.q_proj.output_dims, stream_); + linear_->forward(q, inputs, token_num, w.q_proj); + sync_check_cuda_error(); + } + else { + T* q_a{}; + deviceMalloc((T**)&q_a, (size_t)token_num * q_lora_rank, stream_); + + linear_->forward(q_a, inputs, token_num, w.q_a_proj); + sync_check_cuda_error(); + + invokeRMSNorm(q_a, + q_lora_rank, + q_a, + q_lora_rank, + w.q_a_layernorm, + q_lora_rank, + token_num, + model_param_.norm_eps, + stream_); + sync_check_cuda_error(); + + deviceMalloc((T**)&q, (size_t)token_num * w.q_b_proj.output_dims, stream_); + linear_->forward(q, q_a, token_num, w.q_b_proj); + sync_check_cuda_error(); + + deviceFree(q_a, stream_); + } + + T* kv_a{}; + const int kv_a_dim = w.kv_a_proj.output_dims; + deviceMalloc((T**)&kv_a, (size_t)token_num * kv_a_dim, stream_); + + linear_->forward(kv_a, inputs, token_num, w.kv_a_proj); + sync_check_cuda_error(); + + invokeRMSNorm( + kv_a, kv_a_dim, kv_a, kv_a_dim, w.kv_a_layernorm, kv_lora_rank, token_num, model_param_.norm_eps, stream_); + sync_check_cuda_error(); + + T* kv_b{}; + deviceMalloc((T**)&kv_b, (size_t)token_num * w.kv_b_proj.output_dims, stream_); + sync_check_cuda_error(); + + linear_->forward(kv_b, {kv_a, kv_a_dim}, token_num, w.kv_b_proj); + sync_check_cuda_error(); + + dispatchMLACopyQKV(qkv_buf_, + q, + kv_a, + kv_b, + token_num, + local_head_num_, + qk_nope_dim, + qk_rope_dim, + kv_lora_rank, + v_head_dim, + stream_); + sync_check_cuda_error(); + + deviceFree(q, stream_); + deviceFree(kv_a, stream_); + deviceFree(kv_b, stream_); } #ifdef ENABLE_FP32 diff --git a/src/turbomind/models/llama/unified_attention_layer.h b/src/turbomind/models/llama/unified_attention_layer.h index da0c0e6fc..7d331b0e4 100644 --- a/src/turbomind/models/llama/unified_attention_layer.h +++ b/src/turbomind/models/llama/unified_attention_layer.h @@ -42,7 +42,7 @@ class UnifiedAttentionLayer { static constexpr int kMaxWorkspaceTokens = 4096; void freeBuffer(); - void allocateBuffer(size_t q_count, size_t k_count, size_t batch_size, const WeightType* weights); + void allocateBuffer(size_t q_count, size_t k_count, size_t batch_size, size_t qkv_lora_rank); void allocateWorkspace(); void freeWorkspace(); @@ -70,7 +70,7 @@ class UnifiedAttentionLayer { const NcclParam& tp, const Context& context); - void forward(TensorMap* outputs, const TensorMap* inputs, const LlamaAttentionWeight* weights); + void forward(TensorMap* outputs, const TensorMap* inputs, const WeightType* weights); void prefill(T* output, T* tmp_kv_buffer, @@ -107,6 +107,9 @@ class UnifiedAttentionLayer { int max_split_k, const WeightType* weights); +private: + void forward_mla(const T* inputs, int token_num, const WeightType& weights); + private: const size_t head_num_; const size_t kv_head_num_; diff --git a/src/turbomind/models/llama/unified_decoder.cc b/src/turbomind/models/llama/unified_decoder.cc index 28e8b5f64..ec0e75b7e 100644 --- a/src/turbomind/models/llama/unified_decoder.cc +++ b/src/turbomind/models/llama/unified_decoder.cc @@ -1,13 +1,17 @@ -#include "src/turbomind/models/llama/unified_decoder.h" + +#include + +#include "src/turbomind/kernels/norm/rms_norm.h" #include "src/turbomind/models/llama/llama_decoder_kernels.h" #include "src/turbomind/models/llama/llama_kernels.h" #include "src/turbomind/models/llama/llama_utils.h" #include "src/turbomind/models/llama/moe_ffn_layer.h" #include "src/turbomind/models/llama/unified_attention_layer.h" +#include "src/turbomind/models/llama/unified_decoder.h" +#include "src/turbomind/utils/Tensor.h" #include "src/turbomind/utils/anomaly_handler.h" #include "src/turbomind/utils/cuda_utils.h" -#include namespace turbomind { @@ -23,17 +27,19 @@ UnifiedDecoder::UnifiedDecoder(const ModelParam& model, rmsnorm_eps_(model.norm_eps), stream_(ctx.stream), allocator_(ctx.allocator.get()), - dtype_(getTensorType()) + tp_(tp), + dtype_(getTensorType()), + tune_layer_num_(model.tune_layer_num) { attn_layer_ = std::make_unique>(model, attn, lora, tp, ctx); - if (moe.expert_num) { + if (std::accumulate(moe.expert_num.begin(), moe.expert_num.end(), 0LL)) { moe_ffn_layer_ = std::make_unique>(model, moe, tp, ctx); } - if (model.inter_size) { - ffn_layer_ = std::make_unique>(model, tp, ctx, !moe_ffn_layer_); + if (std::accumulate(model.inter_size.begin(), model.inter_size.end(), 0LL)) { + ffn_layer_ = std::make_unique>(model, tp, ctx); } check_cuda_error(cudaEventCreateWithFlags(&ev_h_cu_x_, cudaEventDisableTiming)); @@ -65,13 +71,13 @@ void UnifiedDecoder::freeBuffer() } template -void UnifiedDecoder::forwardSelfAttn(T* attn_io, - TensorMap* _outputs, - const TensorMap* _inputs, - size_t token_num, - size_t batch_size, - int layer_id, - const LlamaAttentionWeight* weight) +void UnifiedDecoder::forwardSelfAttn(T* attn_io, + TensorMap* _outputs, + const TensorMap* _inputs, + size_t token_num, + size_t batch_size, + int layer_id, + const WeightType* weight) { TensorMap inputs(*_inputs); inputs.insert("input_query", {MEMORY_GPU, dtype_, {token_num, hidden_units_}, attn_io}); @@ -84,7 +90,7 @@ void UnifiedDecoder::forwardSelfAttn(T* attn_io, TensorMap outputs(*_outputs); outputs.insert("hidden_features", {MEMORY_GPU, dtype_, {token_num, hidden_units_}, attn_io}); - attn_layer_->forward(&outputs, &inputs, weight); + attn_layer_->forward(&outputs, &inputs, &weight->self_attn_weights); } template @@ -141,19 +147,15 @@ void UnifiedDecoder::forward(TensorMap* outputs, const TensorMap* inputs, con const int pf_offset = dc_batch_size; - // Compare(decoder_input_output, token_num * hidden_units_, "decoder_input", kCmpRead, stream_); - - // printf("%d %f\n", (int)token_num, rmsnorm_eps_); - ///////////////////////////////////////////// /// RMSNorm - invokeRootMeanSquareNorm(decoder_output, - decoder_input_output, - weights->at(0)->self_attn_norm_weights, - rmsnorm_eps_, - token_num, - hidden_units_, - stream_); + invokeRMSNorm(decoder_output, + decoder_input_output, + weights->at(0)->self_attn_norm_weights, + hidden_units_, + token_num, + rmsnorm_eps_, + stream_); sync_check_cuda_error(); count_and_fix(decoder_output, token_num * hidden_units_, Concat("norm0", 0), 2); @@ -161,12 +163,10 @@ void UnifiedDecoder::forward(TensorMap* outputs, const TensorMap* inputs, con for (size_t layer = 0; layer < layer_num_; ++layer) { /// TODO: do not skip the layers when they are heterogeneous - if (isTuning() && layer != 0) { + if (isTuning() && layer >= tune_layer_num_) { continue; } - // Compare(decoder_output, token_num * hidden_units_, "attn_input", kCmpRead, stream_); - ///////////////////////////////////////////// /// self-attention forwardSelfAttn(decoder_output, // @@ -175,18 +175,18 @@ void UnifiedDecoder::forward(TensorMap* outputs, const TensorMap* inputs, con token_num, batch_size, layer, - &weights->at(layer)->self_attn_weights); + weights->at(layer)); count_and_fix(decoder_output, token_num * hidden_units_, Concat("attn_block", layer), 2); - invokeFusedAddBiasResidualRMSNorm(decoder_input_output, - decoder_output, - weights->at(layer)->self_attn_weights.output.bias, - weights->at(layer)->ffn_norm_weights, - rmsnorm_eps_, - token_num, - hidden_units_, - stream_); + invokeBiasResidualRMSNorm(decoder_input_output, + decoder_output, + weights->at(layer)->ffn_norm_weights, + weights->at(layer)->self_attn_weights.output.bias, + hidden_units_, + token_num, + rmsnorm_eps_, + stream_); sync_check_cuda_error(); count_and_fix(decoder_input_output, token_num * hidden_units_, Concat("residual0", layer), 2); @@ -195,14 +195,17 @@ void UnifiedDecoder::forward(TensorMap* outputs, const TensorMap* inputs, con //////////////////////////////////////////// /// feed-forward network - if (!weights->at(layer)->moe_weights.experts.empty()) { + const bool is_moe = !weights->at(layer)->moe_weights.experts.empty(); + if (is_moe) { moe_ffn_layer_->forward(nullptr, decoder_output, token_num, layer, weights->at(layer)->moe_weights); } - if (ffn_layer_) { - int layer_id = layer; // int is needed + if (weights->at(layer)->ffn_weights.output.kernel) { + int layer_id = layer; // int is needed + bool all_reduce = !is_moe; TensorMap ffn_inputs{{"ffn_input", {MEMORY_GPU, dtype_, {token_num, hidden_units_}, decoder_output}}, - {"layer_id", {MEMORY_CPU, TYPE_INT32, {1}, &layer_id}}}; + {"layer_id", {MEMORY_CPU, TYPE_INT32, {1}, &layer_id}}, + {"all_reduce", {MEMORY_CPU, TYPE_BOOL, {1}, &all_reduce}}}; TensorMap ffn_outputs{{"ffn_output", {MEMORY_GPU, dtype_, {token_num, hidden_units_}, decoder_output}}}; if (inputs->isExist("lora_mask")) { ffn_inputs.insert({"lora_mask", inputs->at("lora_mask")}); @@ -210,8 +213,8 @@ void UnifiedDecoder::forward(TensorMap* outputs, const TensorMap* inputs, con ffn_layer_->forward(&ffn_outputs, &ffn_inputs, &weights->at(layer)->ffn_weights); } - if (!weights->at(layer)->moe_weights.experts.empty()) { - moe_ffn_layer_->reduce(decoder_output, token_num, weights->at(layer)->moe_weights); + if (is_moe) { + moe_ffn_layer_->reduce(decoder_output, token_num, (bool)ffn_layer_, layer, weights->at(layer)->moe_weights); } count_and_fix(decoder_output, token_num * hidden_units_, Concat("ffn_block", layer), 2); diff --git a/src/turbomind/models/llama/unified_decoder.h b/src/turbomind/models/llama/unified_decoder.h index f13b4ba84..e08567136 100644 --- a/src/turbomind/models/llama/unified_decoder.h +++ b/src/turbomind/models/llama/unified_decoder.h @@ -22,7 +22,9 @@ class UnifiedDecoder { const float rmsnorm_eps_; cudaStream_t const stream_; IAllocator* const allocator_; + const NcclParam tp_; const DataType dtype_; + const int tune_layer_num_; bool is_free_buffer_after_forward_{}; int* cu_q_len_{}; @@ -39,13 +41,13 @@ class UnifiedDecoder { using WeightType = LlamaDecoderLayerWeight; - void forwardSelfAttn(T* attn_io, - TensorMap* _outputs, - const TensorMap* _inputs, - size_t token_num, - size_t batch_size, - int layer_id, - const LlamaAttentionWeight* weight); + void forwardSelfAttn(T* attn_io, + TensorMap* _outputs, + const TensorMap* _inputs, + size_t token_num, + size_t batch_size, + int layer_id, + const WeightType* weight); public: UnifiedDecoder(const ModelParam& model, diff --git a/src/turbomind/models/llama/weight_type.h b/src/turbomind/models/llama/weight_type.h new file mode 100644 index 000000000..bc2f49a08 --- /dev/null +++ b/src/turbomind/models/llama/weight_type.h @@ -0,0 +1,56 @@ +#pragma once + +#include +#include +#include + +namespace turbomind { + +enum class WeightType : int +{ + kFP32, + kFP16, + kFP8, // not supported yet + kBF16, + kINT8, + kINT4 +}; + +template +constexpr WeightType get_default_weight_type() +{ + if constexpr (std::is_same_v) { + return WeightType::kFP16; + } + else if constexpr (std::is_same_v) { + return WeightType::kBF16; + } + else if constexpr (std::is_same_v) { + return WeightType::kFP32; + } + else { + static_assert(sizeof(T) != sizeof(T), "not implemented"); + return {}; + } +} + +inline size_t getBitSize(WeightType type) +{ + switch (type) { + case WeightType::kFP32: + return 32; + case WeightType::kFP16: + return 16; + case WeightType::kFP8: + return 8; + case WeightType::kBF16: + return 16; + case WeightType::kINT8: + return 8; + case WeightType::kINT4: + return 4; + } + return 0; +} + +} // namespace turbomind diff --git a/src/turbomind/python/bind.cpp b/src/turbomind/python/bind.cpp index 4eb34249f..5a344d954 100644 --- a/src/turbomind/python/bind.cpp +++ b/src/turbomind/python/bind.cpp @@ -215,6 +215,51 @@ DLTensor GetDLTensor(py::object obj) return dlmt->dl_tensor; } +static void safe_memcpy(void* dst, const void* src, size_t size) +{ + cudaPointerAttributes dat{}; + cudaPointerAttributes sat{}; + ft::check_cuda_error(cudaPointerGetAttributes(&dat, dst)); + ft::check_cuda_error(cudaPointerGetAttributes(&sat, src)); + try { + if (dat.devicePointer && sat.devicePointer) { + // Both can be accessed from current context + ft::check_cuda_error(cudaMemcpy(dst, src, size, cudaMemcpyDefault)); + } + else if (dat.type == cudaMemoryTypeDevice && sat.type == cudaMemoryTypeDevice) { + if (dat.device != sat.device) { + // On different devices, try peer memcpy + ft::check_cuda_error(cudaMemcpyPeer(dst, dat.device, src, sat.device, size)); + } + else { + // Same device, switch to the device first (this is unlikely) + ft::CudaDeviceGuard guard(dat.device); + ft::check_cuda_error(cudaMemcpy(dst, src, size, cudaMemcpyDefault)); + } + } + else { + // Unknown case, give it a try anyway + ft::check_cuda_error(cudaMemcpy(dst, src, size, cudaMemcpyDefault)); + } + } + catch (...) { + int device_id{-1}; + cudaGetDevice(&device_id); + TM_LOG_ERROR("cudaMemcpy failed: dst=(%d, %d, %p, %p), src=(%d, %d, %p, %p), size=%s, device=%d", + (int)dat.type, + dat.device, + dat.devicePointer, + dat.hostPointer, + (int)sat.type, + sat.device, + sat.devicePointer, + sat.hostPointer, + std::to_string(size).c_str(), + device_id); + throw; + } +} + PYBIND11_MODULE(_turbomind, m) { // nccl param @@ -293,8 +338,7 @@ PYBIND11_MODULE(_turbomind, m) std::accumulate(src->shape.begin(), src->shape.end(), 1LL, std::multiplies()); auto num_bytes = num_element * dlmt->dl_tensor.dtype.bits / 8; ft::FT_CHECK(self->shape.size() == 1 && num_bytes == self->shape[0]); - cudaMemcpy( - const_cast(self->data), const_cast(src->data), num_bytes, cudaMemcpyDefault); + safe_memcpy(const_cast(self->data), src->data, num_bytes); break; } default: diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc index 2deca4638..1c7c5eb46 100644 --- a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc +++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc @@ -256,22 +256,30 @@ LlamaTritonModel::LlamaTritonModel(size_t tensor_para_size, model_param_.kv_head_num = model_reader["kv_head_num"].as(0); model_param_.hidden_units = model_reader["hidden_units"].as(); model_param_.layer_num = model_reader["num_layer"].as(); - model_param_.inter_size = model_reader["inter_size"].as(); model_param_.vocab_size = model_reader["vocab_size"].as(); model_param_.embedding_size = model_reader["embedding_size"].as(); model_param_.norm_eps = model_reader["norm_eps"].as(); model_param_.start_id = model_reader["start_id"].as(); model_param_.end_id = model_reader["end_id"].as(); + model_param_.tune_layer_num = model_reader["tune_layer_num"].as(1); + model_param_.mla.q_lora_rank = model_reader["q_lora_rank"].as(); + model_param_.mla.kv_lora_rank = model_reader["kv_lora_rank"].as(); + model_param_.mla.qk_rope_dim = model_reader["qk_rope_dim"].as(); + model_param_.mla.v_head_dim = model_reader["v_head_dim"].as(); attn_param_.cache_block_seq_len = attention_reader["cache_block_seq_len"].as(0); model_param_.quant_policy = engine_reader["quant_policy"].as(0); - + YAML::Node inter_size = model_reader["inter_size"]; + for (auto it = inter_size.begin(); it != inter_size.end(); ++it) { + model_param_.inter_size.push_back(it->as()); + } // Only weight classes need these - attn_bias_ = model_reader["attn_bias"].as(0); - group_size_ = model_reader["group_size"].as(0); + model_param_.attn_bias = model_reader["attn_bias"].as(0); + model_param_.group_size = model_reader["group_size"].as(0); // rotary embedding parameters attn_param_.rotary_embedding_dim = attention_reader["rotary_embedding"].as(); attn_param_.rotary_embedding_base = attention_reader["rope_theta"].as(10000.0f); + attn_param_.softmax_scale = attention_reader["softmax_scale"].as(0); attn_param_.attention_factor = attention_reader["attention_factor"].as(-1.f); attn_param_.beta_fast = attention_reader["beta_fast"].as(32.f); attn_param_.beta_slow = attention_reader["beta_slow"].as(1.f); @@ -297,19 +305,27 @@ LlamaTritonModel::LlamaTritonModel(size_t tensor_para_size, engine_param_.num_tokens_per_iter = engine_reader["num_tokens_per_iter"].as(0); engine_param_.max_prefill_iters = engine_reader["max_prefill_iters"].as(1); - lora_param_.policy = ft::getLoraPolicy(reader["lora_config"]["lora_policy"].as("")); - lora_param_.r = lora_reader["lora_r"].as(0); - lora_param_.scale = lora_reader["lora_scale"].as(0); - lora_param_.max_wo_r = lora_reader["lora_max_wo_r"].as(0); - lora_param_.rank_pattern = getLoraPattern(lora_reader["lora_rank_pattern"].as(""), + lora_param_.policy = ft::getLoraPolicy(reader["lora_config"]["lora_policy"].as("")); + lora_param_.r = lora_reader["lora_r"].as(0); + lora_param_.scale = lora_reader["lora_scale"].as(0); + lora_param_.max_wo_r = lora_reader["lora_max_wo_r"].as(0); + lora_param_.rank_pattern = getLoraPattern(lora_reader["lora_rank_pattern"].as(""), [](const std::string& s) { return std::stoi(s); }); - lora_param_.scale_pattern = getLoraPattern(lora_reader["lora_scale_pattern"].as(""), + lora_param_.scale_pattern = getLoraPattern(lora_reader["lora_scale_pattern"].as(""), [](const std::string& s) { return std::stof(s); }); - moe_param_.expert_num = model_reader["expert_num"].as(0); + moe_param_.experts_per_token = model_reader["experts_per_token"].as(0); moe_param_.inter_size = model_reader["expert_inter_size"].as(0); - moe_param_.shared_gate = model_reader["moe_shared_gate"].as(0); - moe_param_.norm_topk = model_reader["moe_norm_topk"].as(false); + moe_param_.shared_gate = model_reader["moe_shared_gate"].as(); + moe_param_.norm_topk_prob = model_reader["norm_topk_prob"].as(); + moe_param_.routed_scale = model_reader["routed_scale"].as(1.f); + moe_param_.topk_group = model_reader["topk_group"].as(1); + moe_param_.topk_method = model_reader["topk_method"].as("greedy"); + moe_param_.n_group = model_reader["moe_group_num"].as(1); + YAML::Node expert_num = model_reader["expert_num"]; + for (auto it = expert_num.begin(); it != expert_num.end(); ++it) { + moe_param_.expert_num.push_back(it->as()); + } handleMissingParams(); @@ -321,19 +337,19 @@ LlamaTritonModel::LlamaTritonModel(size_t tensor_para_size, const std::string weight_type_str = model_reader["weight_type"].as(); if (weight_type_str == "fp16" || weight_type_str == "float16") { - weight_type_ = ft::WeightType::kFP16; + model_param_.weight_type = ft::WeightType::kFP16; } else if (weight_type_str == "bf16" || weight_type_str == "bfloat16") { - weight_type_ = ft::WeightType::kBF16; + model_param_.weight_type = ft::WeightType::kBF16; } else if (weight_type_str == "fp32") { - weight_type_ = ft::WeightType::kFP32; + model_param_.weight_type = ft::WeightType::kFP32; } else if (weight_type_str == "int8") { - weight_type_ = ft::WeightType::kINT8; + model_param_.weight_type = ft::WeightType::kINT8; } else if (weight_type_str == "int4") { - weight_type_ = ft::WeightType::kINT4; + model_param_.weight_type = ft::WeightType::kINT4; } else { std::cout << "[ERROR] Unsupported weight type: '" << weight_type_str << "'\n"; @@ -418,21 +434,8 @@ void LlamaTritonModel::createSharedWeights(int device_id, int rank) const int tensor_para_rank = rank % tensor_para_size_; const int pipeline_para_rank = rank / tensor_para_size_; ft::FT_CHECK(pipeline_para_size_ == 1 && pipeline_para_rank == 0); - weights_[device_id] = std::make_shared>(model_param_.head_num, - model_param_.kv_head_num, - model_param_.head_dim, - model_param_.hidden_units, - model_param_.inter_size, - model_param_.vocab_size, - model_param_.embedding_size, - model_param_.layer_num, - attn_bias_, - weight_type_, - group_size_, - lora_param_, - moe_param_, - tensor_para_size_, - tensor_para_rank); + weights_[device_id] = std::make_shared>( + model_param_, lora_param_, moe_param_, tensor_para_size_, tensor_para_rank); // model inited with model_dir if (model_dir_ != "") { weights_[device_id]->loadModel(model_dir_); @@ -488,9 +491,11 @@ std::string LlamaTritonModel::toString() std::stringstream ss; ss << "Model: " // << "\nhead_num: " << model_param_.head_num << "\nkv_head_num: " << model_param_.kv_head_num - << "\nsize_per_head: " << model_param_.head_dim << "\ninter_size: " << model_param_.inter_size + << "\nsize_per_head: " + << model_param_.head_dim + // << "\ninter_size: " << model_param_.inter_size << "\nnum_layer: " << model_param_.layer_num << "\nvocab_size: " << model_param_.vocab_size - << "\nattn_bias: " << attn_bias_ << "\nmax_batch_size: " << engine_param_.max_batch_size + << "\nattn_bias: " << model_param_.attn_bias << "\nmax_batch_size: " << engine_param_.max_batch_size << "\nmax_prefill_token_num: " << engine_param_.max_prefill_token_num << "\nmax_context_token_num: " << engine_param_.max_context_token_num << "\nnum_tokens_per_iter: " << engine_param_.num_tokens_per_iter @@ -501,8 +506,9 @@ std::string LlamaTritonModel::toString() << "\nenable_prefix_caching: " << engine_param_.enable_prefix_caching << "\nstart_id: " << model_param_.start_id << "\ntensor_para_size: " << tensor_para_size_ << "\npipeline_para_size: " << pipeline_para_size_ << "\nenable_custom_all_reduce: " << enable_custom_all_reduce_ << "\nmodel_name: " << model_name_ - << "\nmodel_dir: " << model_dir_ << "\nquant_policy: " << model_param_.quant_policy - << "\ngroup_size: " << group_size_ << "\nexpert_num: " << moe_param_.expert_num + << "\nmodel_dir: " << model_dir_ << "\nquant_policy: " << model_param_.quant_policy << "\ngroup_size: " + << model_param_.group_size + // << "\nexpert_num: " << moe_param_.expert_num << "\nexpert_per_token: " << moe_param_.experts_per_token << "\nmoe_method: " << moe_param_.method << std::endl; return ss.str(); diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.h b/src/turbomind/triton_backend/llama/LlamaTritonModel.h index 19a143e72..a6c1b862a 100644 --- a/src/turbomind/triton_backend/llama/LlamaTritonModel.h +++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.h @@ -91,9 +91,6 @@ struct LlamaTritonModel: public AbstractTransformerModel { ft::EngineParam engine_param_; size_t tensor_para_size_; size_t pipeline_para_size_; - ft::WeightType weight_type_; - bool attn_bias_; - int group_size_; std::shared_ptr shared_state_; // Weights & engine instances for the ranks diff --git a/src/turbomind/utils/allocator.h b/src/turbomind/utils/allocator.h index bdcb9bfc4..88c299c3d 100644 --- a/src/turbomind/utils/allocator.h +++ b/src/turbomind/utils/allocator.h @@ -281,7 +281,8 @@ class Allocator: public IAllocator { pointer_mapping_.erase(address); } else { - TM_LOG_WARNING("pointer_mapping_ does not have information of ptr at %p.", address); + FT_CHECK_WITH_INFO(0, + fmtstr("pointer_mapping_ does not have information of ptr at %p.", address).c_str()); } } *ptr = nullptr; diff --git a/src/turbomind/utils/cuda_utils.h b/src/turbomind/utils/cuda_utils.h index 2148fcc16..8311e6eb9 100644 --- a/src/turbomind/utils/cuda_utils.h +++ b/src/turbomind/utils/cuda_utils.h @@ -483,5 +483,24 @@ void compareTwoTensor( bool is_16xx_series(const char* name); +class CudaDeviceGuard { +public: + CudaDeviceGuard(int device) + { + cudaGetDevice(&last_device_id_); + if (device != last_device_id_) { + cudaSetDevice(device); + } + } + + ~CudaDeviceGuard() + { + cudaSetDevice(last_device_id_); + } + +private: + int last_device_id_{-1}; +}; + /* ************************** end of common utils ************************** */ } // namespace turbomind diff --git a/src/turbomind/utils/memory_utils.cu b/src/turbomind/utils/memory_utils.cu index f8bfb8efe..e9a79ea5a 100644 --- a/src/turbomind/utils/memory_utils.cu +++ b/src/turbomind/utils/memory_utils.cu @@ -26,77 +26,71 @@ namespace turbomind { template -void deviceMalloc(T** ptr, size_t size, bool is_random_initialize) +void deviceMalloc(T** ptr, size_t size, cudaStream_t st, bool is_random_initialize) { - FT_CHECK_WITH_INFO(size >= ((size_t)0), "Ask deviceMalloc size " + std::to_string(size) + "< 0 is invalid."); - check_cuda_error(cudaMalloc((void**)(ptr), sizeof(T) * size)); + check_cuda_error(cudaMallocAsync((void**)(ptr), sizeof(T) * size, st)); if (is_random_initialize) { - cudaRandomUniform(*ptr, size); + cudaRandomUniform(*ptr, size, st); } } -template void deviceMalloc(float** ptr, size_t size, bool is_random_initialize); -template void deviceMalloc(half** ptr, size_t size, bool is_random_initialize); +template void deviceMalloc(float** ptr, size_t size, cudaStream_t, bool is_random_initialize); +template void deviceMalloc(half** ptr, size_t size, cudaStream_t, bool is_random_initialize); #ifdef ENABLE_BF16 -template void deviceMalloc(__nv_bfloat16** ptr, size_t size, bool is_random_initialize); +template void deviceMalloc(__nv_bfloat16** ptr, size_t size, cudaStream_t, bool is_random_initialize); #endif -template void deviceMalloc(uint16_t** ptr, size_t size, bool is_random_initialize); -template void deviceMalloc(int** ptr, size_t size, bool is_random_initialize); -template void deviceMalloc(bool** ptr, size_t size, bool is_random_initialize); -template void deviceMalloc(char** ptr, size_t size, bool is_random_initialize); -template void deviceMalloc(int8_t** ptr, size_t size, bool is_random_initialize); +template void deviceMalloc(uint16_t** ptr, size_t size, cudaStream_t, bool is_random_initialize); +template void deviceMalloc(int** ptr, size_t size, cudaStream_t, bool is_random_initialize); +template void deviceMalloc(bool** ptr, size_t size, cudaStream_t, bool is_random_initialize); +template void deviceMalloc(char** ptr, size_t size, cudaStream_t, bool is_random_initialize); +template void deviceMalloc(int8_t** ptr, size_t size, cudaStream_t, bool is_random_initialize); #ifdef ENABLE_FP8 -template void deviceMalloc(__nv_fp8_e4m3** ptr, size_t size, bool is_random_initialize); +template void deviceMalloc(__nv_fp8_e4m3** ptr, size_t size, cudaStream_t, bool is_random_initialize); #endif template -void deviceMemSetZero(T* ptr, size_t size) -{ - check_cuda_error(cudaMemset(static_cast(ptr), 0, sizeof(T) * size)); -} - -template void deviceMemSetZero(float* ptr, size_t size); -template void deviceMemSetZero(half* ptr, size_t size); -template void deviceMemSetZero(int* ptr, size_t size); -template void deviceMemSetZero(uint32_t* ptr, size_t size); -template void deviceMemSetZero(bool* ptr, size_t size); -#ifdef ENABLE_FP8 -template void deviceMemSetZero(__nv_fp8_e4m3* ptr, size_t size); -#endif -#ifdef ENABLE_BF16 -template void deviceMemSetZero(__nv_bfloat16* ptr, size_t size); -#endif - -template -void deviceFree(T*& ptr) +void deviceFree(T*& ptr, cudaStream_t st) { if (ptr != NULL) { - check_cuda_error(cudaFree(ptr)); + check_cuda_error(cudaFreeAsync(ptr, st)); ptr = NULL; } } -template void deviceFree(float*& ptr); -template void deviceFree(half*& ptr); +template void deviceFree(float*& ptr, cudaStream_t); +template void deviceFree(half*& ptr, cudaStream_t); #ifdef ENABLE_BF16 -template void deviceFree(__nv_bfloat16*& ptr); +template void deviceFree(__nv_bfloat16*& ptr, cudaStream_t); #endif -template void deviceFree(unsigned short*& ptr); -template void deviceFree(int*& ptr); -template void deviceFree(bool*& ptr); -template void deviceFree(char*& ptr); -template void deviceFree(int8_t*& ptr); +template void deviceFree(unsigned short*& ptr, cudaStream_t); +template void deviceFree(int*& ptr, cudaStream_t); +template void deviceFree(bool*& ptr, cudaStream_t); +template void deviceFree(char*& ptr, cudaStream_t); +template void deviceFree(int8_t*& ptr, cudaStream_t); +template void deviceFree(void*& ptr, cudaStream_t); #ifdef ENABLE_FP8 -template void deviceFree(__nv_fp8_e4m3*& ptr); +template void deviceFree(__nv_fp8_e4m3*& ptr, cudaStream_t); #endif +namespace { + +template +__global__ void fill_kernel(T* devptr, size_t size, T value) +{ + const int idx = threadIdx.x + blockIdx.x * blockDim.x; + for (size_t i = idx; i < size; i += blockDim.x * gridDim.x) { + devptr[i] = value; + } +} + +} // namespace + template void deviceFill(T* devptr, size_t size, T value, cudaStream_t stream) { - T* arr = new T[size]; - std::fill(arr, arr + size, value); - check_cuda_error(cudaMemcpyAsync(devptr, arr, sizeof(T) * size, cudaMemcpyHostToDevice, stream)); - delete[] arr; + constexpr int threads = 512; + const int blocks = (size + threads - 1) / threads; + fill_kernel<<>>(devptr, size, value); } template void deviceFill(float* devptr, size_t size, float value, cudaStream_t stream); @@ -280,23 +274,23 @@ __global__ void cuda_random_uniform_kernel(char* buffer, const size_t size } template -void cudaRandomUniform(T* buffer, const size_t size) +void cudaRandomUniform(T* buffer, const size_t size, cudaStream_t st) { static int seq_offset = 0; - cuda_random_uniform_kernel<<<256, 256>>>(buffer, size, seq_offset); + cuda_random_uniform_kernel<<<256, 256, 0, st>>>(buffer, size, seq_offset); seq_offset += 256 * 256; } -template void cudaRandomUniform(float* buffer, const size_t size); -template void cudaRandomUniform(half* buffer, const size_t size); +template void cudaRandomUniform(float* buffer, const size_t size, cudaStream_t); +template void cudaRandomUniform(half* buffer, const size_t size, cudaStream_t); #ifdef ENABLE_BF16 -template void cudaRandomUniform(__nv_bfloat16* buffer, const size_t size); +template void cudaRandomUniform(__nv_bfloat16* buffer, const size_t size, cudaStream_t); #endif -template void cudaRandomUniform(int* buffer, const size_t size); -template void cudaRandomUniform(bool* buffer, const size_t size); -template void cudaRandomUniform(char* buffer, const size_t size); +template void cudaRandomUniform(int* buffer, const size_t size, cudaStream_t); +template void cudaRandomUniform(bool* buffer, const size_t size, cudaStream_t); +template void cudaRandomUniform(char* buffer, const size_t size, cudaStream_t); #ifdef ENABLE_FP8 -template void cudaRandomUniform(__nv_fp8_e4m3* buffer, const size_t size); +template void cudaRandomUniform(__nv_fp8_e4m3* buffer, const size_t size, cudaStream_t); #endif // loads data from binary file. If it succeeds, returns a non-empty vector. If loading fails or @@ -366,10 +360,10 @@ int loadWeightFromBinFunc(T* ptr, std::vector shape, std::string filenam } else { T_IN* ptr_2 = nullptr; - deviceMalloc(&ptr_2, host_array.size(), false); + deviceMalloc(&ptr_2, host_array.size(), nullptr, false); cudaH2Dcpy(ptr_2, host_array.data(), host_array.size()); invokeCudaD2DcpyConvert(ptr, ptr_2, host_array.size()); - deviceFree(ptr_2); + deviceFree(ptr_2, nullptr); } return 0; } diff --git a/src/turbomind/utils/memory_utils.h b/src/turbomind/utils/memory_utils.h index bb7a4f9c0..03a0ef7b3 100644 --- a/src/turbomind/utils/memory_utils.h +++ b/src/turbomind/utils/memory_utils.h @@ -23,16 +23,13 @@ namespace turbomind { template -void deviceMalloc(T** ptr, size_t size, bool is_random_initialize = true); +void deviceMalloc(T** ptr, size_t size, cudaStream_t st, bool is_random_initialize = false); template -void deviceMemSetZero(T* ptr, size_t size); +void deviceFree(T*& ptr, cudaStream_t st); template -void deviceFree(T*& ptr); - -template -void deviceFill(T* devptr, size_t size, T value, cudaStream_t stream = 0); +void deviceFill(T* devptr, size_t size, T value, cudaStream_t stream = {}); template void cudaD2Hcpy(T* tgt, const T* src, const size_t size); @@ -44,10 +41,10 @@ template void cudaD2Dcpy(T* tgt, const T* src, const size_t size); template -void cudaAutoCpy(T* tgt, const T* src, const size_t size, cudaStream_t stream = NULL); +void cudaAutoCpy(T* tgt, const T* src, const size_t size, cudaStream_t stream = {}); template -void cudaRandomUniform(T* buffer, const size_t size); +void cudaRandomUniform(T* buffer, const size_t size, cudaStream_t stream = {}); template int loadWeightFromBin(T* ptr, From 01f82e09c11b6866b8ebe862de2595ebe87e9733 Mon Sep 17 00:00:00 2001 From: zhabuye <74179177+zhabuye@users.noreply.github.com> Date: Fri, 29 Nov 2024 16:37:29 +0800 Subject: [PATCH 03/14] Add Ascend installation adapter (#2817) --- requirements/runtime_ascend.txt | 22 ++++++++++++++++++++++ requirements_ascend.txt | 4 ++++ setup.py | 22 ++++++++++++++++++---- 3 files changed, 44 insertions(+), 4 deletions(-) create mode 100644 requirements/runtime_ascend.txt create mode 100644 requirements_ascend.txt diff --git a/requirements/runtime_ascend.txt b/requirements/runtime_ascend.txt new file mode 100644 index 000000000..d87748e39 --- /dev/null +++ b/requirements/runtime_ascend.txt @@ -0,0 +1,22 @@ +accelerate>=0.29.3 +dlinfer-ascend +einops +fastapi +fire +mmengine-lite +numpy<2.0.0 +openai +outlines<0.1.0 +peft<=0.11.1 +pillow +protobuf +pydantic>2.0.0 +pynvml +safetensors +sentencepiece +shortuuid +tiktoken +torch<=2.4.0,>=2.0.0 +torchvision<=0.19.0,>=0.15.0 +transformers +uvicorn diff --git a/requirements_ascend.txt b/requirements_ascend.txt new file mode 100644 index 000000000..e844853ab --- /dev/null +++ b/requirements_ascend.txt @@ -0,0 +1,4 @@ +-r requirements/build.txt +-r requirements/runtime_ascend.txt +-r requirements/lite.txt +-r requirements/serve.txt diff --git a/setup.py b/setup.py index 32a69c600..7a08ac791 100644 --- a/setup.py +++ b/setup.py @@ -4,6 +4,14 @@ from setuptools import find_packages, setup +npu_available = False +try: + import torch_npu + + npu_available = torch_npu.npu.is_available() +except ImportError: + pass + pwd = os.path.dirname(__file__) version_file = 'lmdeploy/version.py' @@ -145,11 +153,17 @@ def gen_packages_items(): include_package_data=True, setup_requires=parse_requirements('requirements/build.txt'), tests_require=parse_requirements('requirements/test.txt'), - install_requires=parse_requirements('requirements/runtime.txt'), + install_requires=parse_requirements( + 'requirements/runtime_ascend.txt' + if npu_available else 'requirements/runtime.txt'), extras_require={ - 'all': parse_requirements('requirements.txt'), - 'lite': parse_requirements('requirements/lite.txt'), - 'serve': parse_requirements('requirements/serve.txt') + 'all': + parse_requirements('requirements_ascend.txt' + if npu_available else 'requirements.txt'), + 'lite': + parse_requirements('requirements/lite.txt'), + 'serve': + parse_requirements('requirements/serve.txt') }, has_ext_modules=check_ext_modules, classifiers=[ From 0b6dd1f23aa9b2239fc6d9c24314ee25bec3990c Mon Sep 17 00:00:00 2001 From: zhulinJulia24 <145004780+zhulinJulia24@users.noreply.github.com> Date: Fri, 29 Nov 2024 17:54:46 +0800 Subject: [PATCH 04/14] [CI] add more testcase for mllm models (#2791) * update * update * update * update * update * update * update * update * update --- autotest/config-v100.yaml | 16 +- autotest/config.yaml | 20 +- .../test_pipeline_chat_pytorch_llm.py | 2 - .../test_pipeline_chat_pytorch_mllm.py | 4 - .../test_pipeline_chat_turbomind_llm.py | 2 - .../test_pipeline_chat_turbomind_mllm.py | 4 - .../test_restful_chat_hf_pytorch_llm.py | 3 +- .../test_restful_chat_hf_pytorch_mllm.py | 3 +- .../test_restful_chat_hf_turbomind_llm.py | 3 +- .../test_restful_chat_hf_turbomind_mllm.py | 3 +- autotest/utils/pipeline_chat.py | 348 ++++++++++++++++++ autotest/utils/run_restful_chat.py | 15 +- docs/en/supported_models/supported_models.md | 4 +- .../supported_models/supported_models.md | 4 +- 14 files changed, 401 insertions(+), 30 deletions(-) diff --git a/autotest/config-v100.yaml b/autotest/config-v100.yaml index 41216cb73..507f81ceb 100644 --- a/autotest/config-v100.yaml +++ b/autotest/config-v100.yaml @@ -1,4 +1,5 @@ model_path: /nvme/qa_test_models +resource_path: /nvme/qa_test_models/resource dst_path: /nvme/qa_test_models/autotest_model log_path: /nvme/qa_test_models/autotest_model/log benchmark_path: /nvme/qa_test_models/benchmark-reports @@ -100,12 +101,22 @@ turbomind_quatization: - meta-llama/Meta-Llama-3-8B-Instruct - internlm/internlm-xcomposer2d5-7b - OpenGVLab/Mini-InternVL-Chat-2B-V1-5 + - Qwen/Qwen2-VL-2B-Instruct + - Qwen/Qwen2-VL-7B-Instruct - mistralai/Mistral-7B-Instruct-v0.3 - THUDM/glm-4-9b-chat + - deepseek-ai/deepseek-coder-1.3b-instruct + - codellama/CodeLlama-7b-Instruct-hf gptq: - internlm/internlm2_5-7b-chat no_kvint4: - openbmb/MiniCPM-V-2_6 + - Qwen/Qwen2-7B-Instruct + - Qwen/Qwen2-7B-Instruct-AWQ + - Qwen/Qwen2-1.5B-Instruct + - Qwen/Qwen2.5-0.5B-Instruct + - Qwen/Qwen2.5-7B-Instruct + - Qwen/Qwen2-7B-Instruct-GPTQ-Int4 no_kvint8: - deepseek-ai/DeepSeek-V2-Lite-Chat @@ -120,6 +131,10 @@ pytorch_quatization: no_kvint4: - OpenGVLab/InternVL2-1B - OpenGVLab/InternVL2-4B + - Qwen/Qwen2-7B-Instruct + - Qwen/Qwen2-1.5B-Instruct + - Qwen/Qwen2-VL-2B-Instruct + - Qwen/Qwen2-VL-7B-Instruct - deepseek-ai/DeepSeek-V2-Lite-Chat - microsoft/Phi-3-mini-4k-instruct - microsoft/Phi-3-vision-128k-instruct @@ -128,7 +143,6 @@ pytorch_quatization: no_kvint8: - deepseek-ai/DeepSeek-V2-Lite-Chat - longtext_model: - meta-llama/Meta-Llama-3-1-8B-Instruct - meta-llama/Meta-Llama-3-8B-Instruct diff --git a/autotest/config.yaml b/autotest/config.yaml index 88ca7c312..b4fd4e171 100644 --- a/autotest/config.yaml +++ b/autotest/config.yaml @@ -1,4 +1,5 @@ model_path: /nvme/qa_test_models +resource_path: /nvme/qa_test_models/resource dst_path: /nvme/qa_test_models/autotest_model log_path: /nvme/qa_test_models/autotest_model/log benchmark_path: /nvme/qa_test_models/benchmark-reports @@ -18,6 +19,7 @@ tp_config: Qwen2-7B-Instruct-GPTQ-Int4: 2 InternVL2-40B: 2 MiniCPM-V-2_6: 2 + Qwen2.5-72B-Instruct: 4 turbomind_chat_model: - meta-llama/Llama-3.2-1B-Instruct @@ -164,7 +166,11 @@ pytorch_base_model: turbomind_quatization: no_awq: + - Qwen/Qwen1.5-MoE-A2.7B-Chat + - Qwen/Qwen2-VL-2B-Instruct + - Qwen/Qwen2-VL-7B-Instruct - mistralai/Mistral-7B-Instruct-v0.3 + - mistralai/Mistral-Nemo-Instruct-2407 - deepseek-ai/deepseek-coder-1.3b-instruct - deepseek-ai/DeepSeek-V2-Lite-Chat - codellama/CodeLlama-7b-Instruct-hf @@ -172,6 +178,12 @@ turbomind_quatization: - internlm/internlm2_5-7b-chat no_kvint4: - openbmb/MiniCPM-V-2_6 + - Qwen/Qwen2-7B-Instruct + - Qwen/Qwen2-7B-Instruct-AWQ + - Qwen/Qwen2-1.5B-Instruct + - Qwen/Qwen2.5-0.5B-Instruct + - Qwen/Qwen2.5-7B-Instruct + - Qwen/Qwen2-7B-Instruct-GPTQ-Int4 no_kvint8: - deepseek-ai/DeepSeek-V2-Lite-Chat @@ -203,6 +215,10 @@ pytorch_quatization: no_kvint4: - OpenGVLab/InternVL2-1B - OpenGVLab/InternVL2-4B + - Qwen/Qwen2-7B-Instruct + - Qwen/Qwen2-1.5B-Instruct + - Qwen/Qwen2-VL-2B-Instruct + - Qwen/Qwen2-VL-7B-Instruct - deepseek-ai/DeepSeek-V2-Lite-Chat - microsoft/Phi-3-mini-4k-instruct - microsoft/Phi-3-vision-128k-instruct @@ -211,7 +227,6 @@ pytorch_quatization: no_kvint8: - deepseek-ai/DeepSeek-V2-Lite-Chat - longtext_model: - meta-llama/Meta-Llama-3-1-8B-Instruct - meta-llama/Meta-Llama-3-8B-Instruct @@ -227,7 +242,8 @@ benchmark_model: - internlm/internlm2_5-7b-chat - internlm/internlm2_5-20b-chat - THUDM/glm-4-9b-chat - - Qwen/Qwen2-7B-Instruct + - Qwen/Qwen2.5-7B-Instruct + - Qwen/Qwen2.5-72B-Instruct - mistralai/Mistral-7B-Instruct-v0.3 - mistralai/Mixtral-8x7B-Instruct-v0.1 - deepseek-ai/DeepSeek-V2-Lite-Chat diff --git a/autotest/tools/pipeline/test_pipeline_chat_pytorch_llm.py b/autotest/tools/pipeline/test_pipeline_chat_pytorch_llm.py index a828e17a0..58674fa17 100644 --- a/autotest/tools/pipeline/test_pipeline_chat_pytorch_llm.py +++ b/autotest/tools/pipeline/test_pipeline_chat_pytorch_llm.py @@ -67,8 +67,6 @@ def test_pipeline_chat_pytorch_tp2(config, common_case_config, model, exclude_dup=True)) def test_pipeline_chat_kvint4_tp1(config, common_case_config, model, worker_id): - if 'Qwen2' in model: - return # kvint4 for qwen2 is not support if 'gw' in worker_id: os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id) spawn_context = get_context('spawn') diff --git a/autotest/tools/pipeline/test_pipeline_chat_pytorch_mllm.py b/autotest/tools/pipeline/test_pipeline_chat_pytorch_mllm.py index 276ced5bc..8403ced94 100644 --- a/autotest/tools/pipeline/test_pipeline_chat_pytorch_mllm.py +++ b/autotest/tools/pipeline/test_pipeline_chat_pytorch_mllm.py @@ -50,8 +50,6 @@ def test_pipeline_chat_tp2(config, model, worker_id): quant_policy=4, model_type='vl_model')) def test_pipeline_chat_kvint4_tp1(config, model, worker_id): - if 'Qwen2' in model: - return # kvint4 for qwen2 is not support if 'gw' in worker_id: os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id) spawn_context = get_context('spawn') @@ -70,8 +68,6 @@ def test_pipeline_chat_kvint4_tp1(config, model, worker_id): quant_policy=4, model_type='vl_model')) def test_pipeline_chat_kvint4_tp2(config, model, worker_id): - if 'Qwen2' in model: - return # kvint4 for qwen2 is not support if 'gw' in worker_id: os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id, tp_num=2) diff --git a/autotest/tools/pipeline/test_pipeline_chat_turbomind_llm.py b/autotest/tools/pipeline/test_pipeline_chat_turbomind_llm.py index 17560e754..d1865175c 100644 --- a/autotest/tools/pipeline/test_pipeline_chat_turbomind_llm.py +++ b/autotest/tools/pipeline/test_pipeline_chat_turbomind_llm.py @@ -56,8 +56,6 @@ def test_pipeline_chat_tp2(config, common_case_config, model, worker_id): @pytest.mark.parametrize('model', get_all_model_list(tp_num=1, quant_policy=4)) def test_pipeline_chat_kvint4_tp1(config, common_case_config, model, worker_id): - if 'Qwen2' in model: - return # kvint4 for qwen2 is not support if 'gw' in worker_id: os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id) spawn_context = get_context('spawn') diff --git a/autotest/tools/pipeline/test_pipeline_chat_turbomind_mllm.py b/autotest/tools/pipeline/test_pipeline_chat_turbomind_mllm.py index 8f1bc7d8b..8c845fa77 100644 --- a/autotest/tools/pipeline/test_pipeline_chat_turbomind_mllm.py +++ b/autotest/tools/pipeline/test_pipeline_chat_turbomind_mllm.py @@ -50,8 +50,6 @@ def test_pipeline_chat_tp2(config, model, worker_id): quant_policy=4, model_type='vl_model')) def test_pipeline_chat_kvint4_tp1(config, model, worker_id): - if 'Qwen2' in model: - return # kvint4 for qwen2 is not support if 'gw' in worker_id: os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id) spawn_context = get_context('spawn') @@ -70,8 +68,6 @@ def test_pipeline_chat_kvint4_tp1(config, model, worker_id): quant_policy=4, model_type='vl_model')) def test_pipeline_chat_kvint4_tp2(config, model, worker_id): - if 'Qwen2' in model: - return # kvint4 for qwen2 is not support if 'gw' in worker_id: os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id, tp_num=2) diff --git a/autotest/tools/restful/test_restful_chat_hf_pytorch_llm.py b/autotest/tools/restful/test_restful_chat_hf_pytorch_llm.py index ab1f5595a..fc95e288c 100644 --- a/autotest/tools/restful/test_restful_chat_hf_pytorch_llm.py +++ b/autotest/tools/restful/test_restful_chat_hf_pytorch_llm.py @@ -67,8 +67,7 @@ def getKvintModelList(tp_num, quant_policy): 'tp_num': tp_num, 'extra': f'--quant-policy {quant_policy}' } for item in get_torch_model_list( - tp_num, quant_policy=quant_policy, exclude_dup=True) - if 'qwen2' not in item.lower() or quant_policy == 8] + tp_num, quant_policy=quant_policy, exclude_dup=True)] @pytest.mark.order(7) diff --git a/autotest/tools/restful/test_restful_chat_hf_pytorch_mllm.py b/autotest/tools/restful/test_restful_chat_hf_pytorch_mllm.py index b210733db..bf20c45e6 100644 --- a/autotest/tools/restful/test_restful_chat_hf_pytorch_mllm.py +++ b/autotest/tools/restful/test_restful_chat_hf_pytorch_mllm.py @@ -60,8 +60,7 @@ def getKvintModelList(tp_num, quant_policy: int = None): 'tp_num': tp_num, 'extra': f'--quant-policy {quant_policy}' } for item in get_torch_model_list( - tp_num, quant_policy=quant_policy, model_type='vl_model') - if 'qwen2' not in item.lower() or quant_policy == 8] + tp_num, quant_policy=quant_policy, model_type='vl_model')] @pytest.mark.order(7) diff --git a/autotest/tools/restful/test_restful_chat_hf_turbomind_llm.py b/autotest/tools/restful/test_restful_chat_hf_turbomind_llm.py index 91e65ee51..1c9131b32 100644 --- a/autotest/tools/restful/test_restful_chat_hf_turbomind_llm.py +++ b/autotest/tools/restful/test_restful_chat_hf_turbomind_llm.py @@ -66,8 +66,7 @@ def getKvintModelList(tp_num, quant_policy): 'cuda_prefix': None, 'tp_num': tp_num, 'extra': f'--quant-policy {quant_policy}' - } for item in get_all_model_list(tp_num, quant_policy=quant_policy) - if 'qwen2' not in item.lower() or quant_policy == 8] + } for item in get_all_model_list(tp_num, quant_policy=quant_policy)] @pytest.mark.order(7) diff --git a/autotest/tools/restful/test_restful_chat_hf_turbomind_mllm.py b/autotest/tools/restful/test_restful_chat_hf_turbomind_mllm.py index 091e18e6e..641f2f760 100644 --- a/autotest/tools/restful/test_restful_chat_hf_turbomind_mllm.py +++ b/autotest/tools/restful/test_restful_chat_hf_turbomind_mllm.py @@ -60,8 +60,7 @@ def getKvintModelList(tp_num, quant_policy: int = None): 'tp_num': tp_num, 'extra': f'--quant-policy {quant_policy}' } for item in get_all_model_list( - tp_num, quant_policy=quant_policy, model_type='vl_model') - if 'qwen2' not in item.lower() or quant_policy == 8] + tp_num, quant_policy=quant_policy, model_type='vl_model')] @pytest.mark.order(7) diff --git a/autotest/utils/pipeline_chat.py b/autotest/utils/pipeline_chat.py index 562a707ef..023e4ac14 100644 --- a/autotest/utils/pipeline_chat.py +++ b/autotest/utils/pipeline_chat.py @@ -3,7 +3,10 @@ from subprocess import PIPE import allure +import numpy as np import torch +from decord import VideoReader, cpu +from PIL import Image from pytest_assume.plugin import assume from utils.get_run_config import get_model_name, get_tp_num from utils.rule_condition_assert import assert_result @@ -13,6 +16,7 @@ from lmdeploy.utils import is_bf16_supported from lmdeploy.vl import load_image from lmdeploy.vl.constants import IMAGE_TOKEN +from lmdeploy.vl.utils import encode_image_base64 def run_pipeline_chat_test(config, @@ -275,6 +279,12 @@ def assert_pipeline_single_element(output, PIC1 = 'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg' # noqa E501 PIC2 = 'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/demo/resources/human-pose.jpg' # noqa E501 +PIC_BEIJING = 'https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Beijing_Small.jpeg' # noqa E501 +PIC_CHONGQING = 'https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Chongqing_Small.jpeg' # noqa E501 +PIC_REDPANDA = 'https://raw.githubusercontent.com/OpenGVLab/InternVL/main/internvl_chat/examples/image1.jpg' # noqa E501 +PIC_PANDA = 'https://raw.githubusercontent.com/OpenGVLab/InternVL/main/internvl_chat/examples/image2.jpg' # noqa E501 +DESC = 'What are the similarities and differences between these two images.' # noqa E501 +DESC_ZH = '两张图有什么相同和不同的地方.' # noqa E501 def run_pipeline_vl_chat_test(config, @@ -386,12 +396,350 @@ def run_pipeline_vl_chat_test(config, ', reason: Multi-turn example: ski not in ' + sess.response.text + '\n') + if 'internvl' in model_case.lower(): + internvl_vl_testcase(config, pipe, file) + internvl_vl_testcase(config, pipe, file, 'cn') + if 'minicpm' in model_case.lower(): + MiniCPM_vl_testcase(config, pipe, file) + if 'qwen' in model_case.lower(): + Qwen_vl_testcase(config, pipe, file) + file.close() del pipe torch.cuda.empty_cache() +def internvl_vl_testcase(config, pipe, file, lang='en'): + if lang == 'cn': + description = DESC_ZH + else: + description = DESC + # multi-image multi-round conversation, combined images + messages = [ + dict(role='user', + content=[ + dict(type='text', + text=f'{IMAGE_TOKEN}{IMAGE_TOKEN}\n{description}'), + dict(type='image_url', + image_url=dict(max_dynamic_patch=12, url=PIC_REDPANDA)), + dict(type='image_url', + image_url=dict(max_dynamic_patch=12, url=PIC_PANDA)) + ]) + ] + response = pipe(messages) + result = 'panda' in response.text.lower() or '熊猫' in response.text.lower() + file.writelines('result:' + str(result) + + ', reason: combined images: panda not in ' + + response.text + '\n') + + messages.append(dict(role='assistant', content=response.text)) + messages.append(dict(role='user', content=description)) + response = pipe(messages) + result = 'panda' in response.text.lower() or '熊猫' in response.text.lower() + file.writelines('result:' + str(result) + + ', reason: combined images second: panda not in ' + + response.text + '\n') + + # multi-image multi-round conversation, separate images + messages = [ + dict( + role='user', + content=[ + dict( + type='text', + text=f'Image-1: {IMAGE_TOKEN}\nImage-2: {IMAGE_TOKEN}\n' + + # noqa E251,E501 + description), + dict(type='image_url', + image_url=dict(max_dynamic_patch=12, url=PIC_REDPANDA)), + dict(type='image_url', + image_url=dict(max_dynamic_patch=12, url=PIC_PANDA)) + ]) + ] + response = pipe(messages) + result = 'panda' in response.text.lower() or '熊猫' in response.text.lower() + file.writelines('result:' + str(result) + + ', reason: separate images: panda not in ' + + response.text + '\n') + + messages.append(dict(role='assistant', content=response.text)) + messages.append(dict(role='user', content=description)) + response = pipe(messages) + result = 'panda' in response.text.lower() or '熊猫' in response.text.lower() + file.writelines('result:' + str(result) + + ', reason: separate images second: panda not in ' + + response.text + '\n') + + # video multi-round conversation + def get_index(bound, fps, max_frame, first_idx=0, num_segments=32): + if bound: + start, end = bound[0], bound[1] + else: + start, end = -100000, 100000 + start_idx = max(first_idx, round(start * fps)) + end_idx = min(round(end * fps), max_frame) + seg_size = float(end_idx - start_idx) / num_segments + frame_indices = np.array([ + int(start_idx + (seg_size / 2) + np.round(seg_size * idx)) + for idx in range(num_segments) + ]) + return frame_indices + + def load_video(video_path, bound=None, num_segments=32): + vr = VideoReader(video_path, ctx=cpu(0), num_threads=1) + max_frame = len(vr) - 1 + fps = float(vr.get_avg_fps()) + frame_indices = get_index(bound, + fps, + max_frame, + first_idx=0, + num_segments=num_segments) + imgs = [] + for frame_index in frame_indices: + img = Image.fromarray(vr[frame_index].asnumpy()).convert('RGB') + imgs.append(img) + return imgs + + resource_path = config.get('resource_path') + video_path = resource_path + '/red-panda.mp4' + imgs = load_video(video_path, num_segments=8) + + question = '' + for i in range(len(imgs)): + question = question + f'Frame{i+1}: {IMAGE_TOKEN}\n' + + if lang == 'cn': + question += '小熊猫在做什么?' + else: + question += 'What is the red panda doing?' + + content = [{'type': 'text', 'text': question}] + for img in imgs: + content.append({ + 'type': 'image_url', + 'image_url': { + 'max_dynamic_patch': 1, + 'url': f'data:image/jpeg;base64,{encode_image_base64(img)}' + } + }) + + messages = [dict(role='user', content=content)] + response = pipe(messages) + result = 'panda' in response.text.lower() or '熊猫' in response.text.lower() + file.writelines('result:' + str(result) + + ', reason: video images: red panda not in ' + + response.text + '\n') + + messages.append(dict(role='assistant', content=response.text)) + if lang == 'cn': + messages.append(dict(role='user', content='描述视频详情,不要重复')) + else: + messages.append( + dict(role='user', + content='Describe this video in detail. Don\'t repeat.')) + response = pipe(messages) + result = 'red panda' in response.text.lower( + ) or '熊猫' in response.text.lower() + file.writelines('result:' + str(result) + + ', reason: video images: red panda not in ' + + response.text + '\n') + + +def llava_vl_testcase(config, pipe, file): + # multi-image multi-round conversation, combined images + messages = [ + dict(role='user', + content=[ + dict(type='text', text='Describe the two images in detail.'), + dict(type='image_url', image_url=dict(url=PIC_BEIJING)), + dict(type='image_url', image_url=dict(url=PIC_CHONGQING)) + ]) + ] + response = pipe(messages) + result = 'buildings' in response.text.lower( + ) or '楼' in response.text.lower() or 'skyline' in response.text.lower( + ) or 'cityscape' in response.text.lower() + file.writelines('result:' + str(result) + + ', reason: combined images: buildings not in ' + + response.text + '\n') + + messages.append(dict(role='assistant', content=response.text)) + messages.append(dict(role='user', content=DESC)) + response = pipe(messages) + result = 'buildings' in response.text.lower( + ) or '楼' in response.text.lower() or 'skyline' in response.text.lower( + ) or 'cityscape' in response.text.lower() + file.writelines('result:' + str(result) + + ', reason: combined images second: buildings not in ' + + response.text + '\n') + + +def MiniCPM_vl_testcase(config, pipe, file): + # Chat with multiple images + messages = [ + dict(role='user', + content=[ + dict(type='text', text='Describe the two images in detail.'), + dict(type='image_url', + image_url=dict(max_slice_nums=9, url=PIC_REDPANDA)), + dict(type='image_url', + image_url=dict(max_slice_nums=9, url=PIC_PANDA)) + ]) + ] + response = pipe(messages) + result = 'panda' in response.text.lower() or '熊猫' in response.text.lower() + file.writelines('result:' + str(result) + + ', reason: multiple images: panda not in ' + + response.text + '\n') + + messages.append(dict(role='assistant', content=response.text)) + messages.append(dict(role='user', content=DESC)) + response = pipe(messages) + result = 'panda' in response.text.lower() or '熊猫' in response.text.lower() + file.writelines('result:' + str(result) + + ', reason: multiple images second: panda not in ' + + response.text + '\n') + + # In-context few-shot learning + EXAMPLE1 = 'https://github.com/user-attachments/assets/405d9147-95f6-4f78-8879-606a0aed6707' # noqa E251,E501 + EXAMPLE2 = 'https://github.com/user-attachments/assets/9f2c6ed9-2aa5-4189-9c4f-0b9753024ba1' # noqa E251,E501 + EXAMPLE3 = 'https://github.com/user-attachments/assets/f335b507-1957-4c22-84ae-ed69ff79df38' # noqa E251,E501 + question = 'production date' + messages = [ + dict(role='user', + content=[ + dict(type='text', text=question), + dict(type='image_url', image_url=dict(url=EXAMPLE1)), + ]), + dict(role='assistant', content='2021.08.29'), + dict(role='user', + content=[ + dict(type='text', text=question), + dict(type='image_url', image_url=dict(url=EXAMPLE2)), + ]), + dict(role='assistant', content='1999.05.15'), + dict(role='user', + content=[ + dict(type='text', text=question), + dict(type='image_url', image_url=dict(url=EXAMPLE3)), + ]) + ] + response = pipe(messages) + result = '2021' in response.text.lower() or '14' in response.text.lower() + file.writelines('result:' + str(result) + + ', reason: in context learning: 2021 or 14 not in ' + + response.text + '\n') + + # Chat with video + MAX_NUM_FRAMES = 64 # if cuda OOM set a smaller number + + def encode_video(video_path): + + def uniform_sample(length, n): + gap = len(length) / n + idxs = [int(i * gap + gap / 2) for i in range(n)] + return [length[i] for i in idxs] + + vr = VideoReader(video_path, ctx=cpu(0)) + sample_fps = round(vr.get_avg_fps() / 1) # FPS + frame_idx = [i for i in range(0, len(vr), sample_fps)] + if len(frame_idx) > MAX_NUM_FRAMES: + frame_idx = uniform_sample(frame_idx, MAX_NUM_FRAMES) + frames = vr.get_batch(frame_idx).asnumpy() + frames = [Image.fromarray(v.astype('uint8')) for v in frames] + print('num frames:', len(frames)) + return frames + + resource_path = config.get('resource_path') + video_path = resource_path + '/red-panda.mp4' + frames = encode_video(video_path) + question = 'Describe the video' + + content = [dict(type='text', text=question)] + for frame in frames: + content.append( + dict(type='image_url', + image_url=dict( + use_image_id=False, + max_slice_nums=2, + url=f'data:image/jpeg;base64,{encode_image_base64(frame)}' + ))) + + messages = [dict(role='user', content=content)] + response = pipe(messages) + result = 'red panda' in response.text.lower( + ) or '熊猫' in response.text.lower() + file.writelines('result:' + str(result) + + ', reason: video example: panda not in ' + response.text + + '\n') + + +def Qwen_vl_testcase(config, pipe, file): + # multi-image multi-round conversation, combined images + messages = [ + dict(role='user', + content=[ + dict(type='text', text='Describe the two images in detail.'), + dict(type='image_url', image_url=dict(url=PIC_BEIJING)), + dict(type='image_url', image_url=dict(url=PIC_CHONGQING)) + ]) + ] + response = pipe(messages) + result = 'buildings' in response.text.lower( + ) or '楼' in response.text.lower() or 'skyline' in response.text.lower( + ) or 'cityscape' in response.text.lower() + file.writelines('result:' + str(result) + + ', reason: combined images: buildings not in ' + + response.text + '\n') + + messages.append(dict(role='assistant', content=response.text)) + messages.append(dict(role='user', content=DESC)) + response = pipe(messages) + result = 'buildings' in response.text.lower( + ) or '楼' in response.text.lower() or 'skyline' in response.text.lower( + ) or 'cityscape' in response.text.lower() + file.writelines('result:' + str(result) + + ', reason: combined images second: buildings not in ' + + response.text + '\n') + + # image resolution for performance boost + min_pixels = 64 * 28 * 28 + max_pixels = 64 * 28 * 28 + messages = [ + dict(role='user', + content=[ + dict(type='text', text='Describe the two images in detail.'), + dict(type='image_url', + image_url=dict(min_pixels=min_pixels, + max_pixels=max_pixels, + url=PIC_BEIJING)), + dict(type='image_url', + image_url=dict(min_pixels=min_pixels, + max_pixels=max_pixels, + url=PIC_CHONGQING)) + ]) + ] + response = pipe(messages) + result = 'ski' in response.text.lower() or '滑雪' in response.text.lower() + result = 'buildings' in response.text.lower( + ) or '楼' in response.text.lower() or 'skyline' in response.text.lower( + ) or 'cityscape' in response.text.lower() + file.writelines('result:' + str(result) + + ', reason: performance boost: buildings not in ' + + response.text + '\n') + + messages.append(dict(role='assistant', content=response.text)) + messages.append(dict(role='user', content=DESC)) + response = pipe(messages) + result = 'buildings' in response.text.lower( + ) or '楼' in response.text.lower() or 'skyline' in response.text.lower( + ) or 'cityscape' in response.text.lower() + file.writelines('result:' + str(result) + + ', reason: performance boost second: buildings not in ' + + response.text + '\n') + + def assert_pipeline_vl_chat_log(config, model_case, worker_id): log_path = config.get('log_path') diff --git a/autotest/utils/run_restful_chat.py b/autotest/utils/run_restful_chat.py index 77af1975b..082a61bcd 100644 --- a/autotest/utils/run_restful_chat.py +++ b/autotest/utils/run_restful_chat.py @@ -282,6 +282,7 @@ def get_model(url): PIC = 'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg' # noqa E501 +PIC2 = 'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/demo/resources/human-pose.jpg' # noqa E501 def run_vl_testcase(config, port: int = DEFAULT_PORT): @@ -307,6 +308,11 @@ def run_vl_testcase(config, port: int = DEFAULT_PORT): 'image_url': { 'url': PIC, }, + }, { + 'type': 'image_url', + 'image_url': { + 'url': PIC2, + }, }], }] @@ -315,8 +321,6 @@ def run_vl_testcase(config, port: int = DEFAULT_PORT): temperature=0.8, top_p=0.8) file.writelines(str(response).lower() + '\n') - assert 'tiger' in str(response).lower() or '虎' in str( - response).lower(), response api_client = APIClient(http_url) model_name = api_client.available_models[0] @@ -324,7 +328,12 @@ def run_vl_testcase(config, port: int = DEFAULT_PORT): messages=prompt_messages): continue file.writelines(str(item) + '\n') - assert 'tiger' in str(item).lower() or '虎' in str(item).lower(), item allure.attach.file(restful_log, attachment_type=allure.attachment_type.TEXT) + + assert 'tiger' in str(response).lower() or '虎' in str( + response).lower() or 'ski' in str(response).lower() or '滑雪' in str( + response).lower(), response + assert 'tiger' in str(item).lower() or '虎' in str(item).lower( + ) or 'ski' in str(item).lower() or '滑雪' in str(item).lower(), item diff --git a/docs/en/supported_models/supported_models.md b/docs/en/supported_models/supported_models.md index da5224125..cd43e79c9 100644 --- a/docs/en/supported_models/supported_models.md +++ b/docs/en/supported_models/supported_models.md @@ -19,7 +19,7 @@ The following tables detail the models supported by LMDeploy's TurboMind engine | Qwen | 1.8B - 72B | LLM | Yes | Yes | Yes | Yes | | Qwen1.5 | 1.8B - 110B | LLM | Yes | Yes | Yes | Yes | | Qwen2 | 0.5B - 72B | LLM | Yes | Yes | Yes | Yes | -| Mistral | 7B | LLM | Yes | Yes | Yes | Yes | +| Mistral | 7B | LLM | Yes | Yes | Yes | No | | Mixtral | 8x7B, 8x22B | LLM | Yes | Yes | Yes | Yes | | Qwen-VL | 7B | MLLM | Yes | Yes | Yes | Yes | | DeepSeek-VL | 7B | MLLM | Yes | Yes | Yes | Yes | @@ -36,7 +36,7 @@ The following tables detail the models supported by LMDeploy's TurboMind engine | MiniGeminiLlama | 7B | MLLM | Yes | - | - | Yes | | GLM4 | 9B | LLM | Yes | Yes | Yes | Yes | | CodeGeeX4 | 9B | LLM | Yes | Yes | Yes | - | -| Molmo | 7B-D,72B | MLLM | Yes | Yes | Yes | NO | +| Molmo | 7B-D,72B | MLLM | Yes | Yes | Yes | No | "-" means not verified yet. diff --git a/docs/zh_cn/supported_models/supported_models.md b/docs/zh_cn/supported_models/supported_models.md index 502e91b6d..7ec36d235 100644 --- a/docs/zh_cn/supported_models/supported_models.md +++ b/docs/zh_cn/supported_models/supported_models.md @@ -19,7 +19,7 @@ | Qwen | 1.8B - 72B | LLM | Yes | Yes | Yes | Yes | | Qwen1.5 | 1.8B - 110B | LLM | Yes | Yes | Yes | Yes | | Qwen2 | 0.5B - 72B | LLM | Yes | Yes | Yes | Yes | -| Mistral | 7B | LLM | Yes | Yes | Yes | Yes | +| Mistral | 7B | LLM | Yes | Yes | Yes | No | | Mixtral | 8x7B, 8x22B | LLM | Yes | Yes | Yes | Yes | | Qwen-VL | 7B | MLLM | Yes | Yes | Yes | Yes | | DeepSeek-VL | 7B | MLLM | Yes | Yes | Yes | Yes | @@ -36,7 +36,7 @@ | MiniGeminiLlama | 7B | MLLM | Yes | - | - | Yes | | GLM4 | 9B | LLM | Yes | Yes | Yes | Yes | | CodeGeeX4 | 9B | LLM | Yes | Yes | Yes | - | -| Molmo | 7B-D,72B | MLLM | Yes | Yes | Yes | NO | +| Molmo | 7B-D,72B | MLLM | Yes | Yes | Yes | No | “-” 表示还没有验证。 From 4ede6314aac338e3b141fe9c909233421d7b636f Mon Sep 17 00:00:00 2001 From: Li Zhang Date: Fri, 29 Nov 2024 18:43:46 +0800 Subject: [PATCH 05/14] refactor turbomind (2/N) (#2818) --- CMakeLists.txt | 2 +- lmdeploy/turbomind/turbomind.py | 8 +- src/turbomind/models/llama/LlamaBatch.h | 4 +- src/turbomind/models/llama/LlamaV2.h | 6 +- src/turbomind/python/bind.cpp | 214 ++++++------- .../triton_backend/llama/LlamaTritonModel.cc | 165 +++++----- .../triton_backend/llama/LlamaTritonModel.h | 64 ++-- .../llama/LlamaTritonModelInstance.cc | 206 +++++-------- .../llama/LlamaTritonModelInstance.h | 36 +-- .../transformer_triton_backend.cpp | 52 ++-- .../transformer_triton_backend.hpp | 283 ++---------------- src/turbomind/utils/Tensor.h | 10 + src/turbomind/utils/instance_comm.h | 16 - 13 files changed, 370 insertions(+), 696 deletions(-) delete mode 100644 src/turbomind/utils/instance_comm.h diff --git a/CMakeLists.txt b/CMakeLists.txt index ff2ac7dde..356da56f5 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -304,7 +304,7 @@ link_directories( # add_subdirectory(3rdparty) add_subdirectory(src) -add_subdirectory(examples) +# add_subdirectory(examples) if(BUILD_TEST) add_subdirectory(tests/csrc) diff --git a/lmdeploy/turbomind/turbomind.py b/lmdeploy/turbomind/turbomind.py index 05bc3e400..a1b2fff94 100644 --- a/lmdeploy/turbomind/turbomind.py +++ b/lmdeploy/turbomind/turbomind.py @@ -358,12 +358,10 @@ def _forward_callback(self, result, ctx): self.que.put((False, result)) def _forward_thread(self, inputs): - instance_comm = self.tm_model.model_comm.create_instance_comm( - self.gpu_count) def _func(): try: - output = self.model_inst.forward(inputs, instance_comm) + output = self.model_inst.forward(inputs) except Exception as e: logger.error(f'unhandled exception: {e}') self.que.put((-1, None)) @@ -377,12 +375,10 @@ def _async_forward_callback(self, result, ctx, que: LifoQueue): que.put((False, result)) def _async_forward_thread(self, inputs, que: LifoQueue): - instance_comm = self.tm_model.model_comm.create_instance_comm( - self.gpu_count) def _func(): try: - output = self.model_inst.forward(inputs, instance_comm) + output = self.model_inst.forward(inputs) except Exception as e: logger.error(f'unhandled exception: {e}') que.put((-1, None)) diff --git a/src/turbomind/models/llama/LlamaBatch.h b/src/turbomind/models/llama/LlamaBatch.h index 9c6694899..f952da6ba 100644 --- a/src/turbomind/models/llama/LlamaBatch.h +++ b/src/turbomind/models/llama/LlamaBatch.h @@ -12,7 +12,6 @@ #include "src/turbomind/utils/allocator.h" #include "src/turbomind/utils/cublasMMWrapper.h" #include "src/turbomind/utils/cuda_utils.h" -#include "src/turbomind/utils/instance_comm.h" #include #include #include @@ -32,8 +31,7 @@ struct SharedState { }; struct Control { - AbstractInstanceComm* comm; - Request::Callback callback; + Request::Callback callback; }; struct BatchState { diff --git a/src/turbomind/models/llama/LlamaV2.h b/src/turbomind/models/llama/LlamaV2.h index 658282f5e..a0d35b887 100644 --- a/src/turbomind/models/llama/LlamaV2.h +++ b/src/turbomind/models/llama/LlamaV2.h @@ -21,6 +21,9 @@ #pragma once +#include +#include + #include "src/turbomind/layers/DynamicDecodeLayer.h" #include "src/turbomind/models/llama/Barrier.h" #include "src/turbomind/models/llama/LlamaBatch.h" @@ -31,10 +34,7 @@ #include "src/turbomind/models/llama/unified_decoder.h" #include "src/turbomind/utils/allocator.h" #include "src/turbomind/utils/cublasMMWrapper.h" -#include "src/turbomind/utils/instance_comm.h" #include "src/turbomind/utils/nccl_utils.h" -#include -#include namespace turbomind { diff --git a/src/turbomind/python/bind.cpp b/src/turbomind/python/bind.cpp index 5a344d954..71792a4be 100644 --- a/src/turbomind/python/bind.cpp +++ b/src/turbomind/python/bind.cpp @@ -1,34 +1,38 @@ // Copyright (c) OpenMMLab. All rights reserved. -#include "src/turbomind/python/dlpack.h" -#include "src/turbomind/triton_backend/llama/LlamaTritonModel.h" -#include "src/turbomind/triton_backend/transformer_triton_backend.hpp" -#include "src/turbomind/utils/cuda_utils.h" -#include "src/turbomind/utils/nccl_utils.h" -#include #include +#include + +#include + #include #include #include #include #include -#include + +#include "src/turbomind/python/dlpack.h" +#include "src/turbomind/triton_backend/llama/LlamaTritonModel.h" +#include "src/turbomind/triton_backend/transformer_triton_backend.hpp" +#include "src/turbomind/utils/Tensor.h" +#include "src/turbomind/utils/cuda_utils.h" +#include "src/turbomind/utils/nccl_utils.h" namespace py = pybind11; namespace ft = turbomind; using namespace pybind11::literals; // prepare to bind container -using TensorVector = std::vector; +using TensorVector = std::vector; PYBIND11_MAKE_OPAQUE(TensorVector); -using TensorMap = std::unordered_map; +using TensorMap = std::unordered_map; PYBIND11_MAKE_OPAQUE(TensorMap); static const char kDlTensorCapsuleName[] = "dltensor"; -DLDevice getDLDevice(triton::Tensor& tensor) +DLDevice getDLDevice(ft::Tensor& tensor) { int device_id = 0; - if (tensor.where == triton::MEMORY_GPU) { + if (tensor.where == ft::MEMORY_GPU) { cudaPointerAttributes ptr_attr; cudaPointerGetAttributes(&ptr_attr, tensor.data); device_id = ptr_attr.device; @@ -37,13 +41,13 @@ DLDevice getDLDevice(triton::Tensor& tensor) DLDevice device{kDLCPU, device_id}; switch (tensor.where) { - case triton::MEMORY_CPU: + case ft::MEMORY_CPU: device.device_type = DLDeviceType::kDLCPU; break; - case triton::MEMORY_CPU_PINNED: + case ft::MEMORY_CPU_PINNED: device.device_type = DLDeviceType::kDLCUDAHost; break; - case triton::MEMORY_GPU: + case ft::MEMORY_GPU: device.device_type = DLDeviceType::kDLCUDA; break; default: @@ -53,62 +57,62 @@ DLDevice getDLDevice(triton::Tensor& tensor) return device; } -DLManagedTensor* TritonTensorToDLManagedTensor(triton::Tensor& tensor) +DLManagedTensor* TritonTensorToDLManagedTensor(ft::Tensor& tensor) { DLDevice device = getDLDevice(tensor); DLDataType data_type{0, 0, 1}; switch (tensor.type) { - case triton::TYPE_BOOL: + case ft::TYPE_BOOL: data_type.code = DLDataTypeCode::kDLBool; data_type.bits = 8; break; - case triton::TYPE_UINT8: + case ft::TYPE_UINT8: data_type.code = DLDataTypeCode::kDLUInt; data_type.bits = 8; break; - case triton::TYPE_UINT16: + case ft::TYPE_UINT16: data_type.code = DLDataTypeCode::kDLUInt; data_type.bits = 16; break; - case triton::TYPE_UINT32: + case ft::TYPE_UINT32: data_type.code = DLDataTypeCode::kDLUInt; data_type.bits = 32; break; - case triton::TYPE_UINT64: + case ft::TYPE_UINT64: data_type.code = DLDataTypeCode::kDLUInt; data_type.bits = 64; break; - case triton::TYPE_INT8: - case triton::TYPE_BYTES: + case ft::TYPE_INT8: + case ft::TYPE_BYTES: data_type.code = DLDataTypeCode::kDLInt; data_type.bits = 8; break; - case triton::TYPE_INT16: + case ft::TYPE_INT16: data_type.code = DLDataTypeCode::kDLInt; data_type.bits = 16; break; - case triton::TYPE_INT32: + case ft::TYPE_INT32: data_type.code = DLDataTypeCode::kDLInt; data_type.bits = 32; break; - case triton::TYPE_INT64: + case ft::TYPE_INT64: data_type.code = DLDataTypeCode::kDLInt; data_type.bits = 64; break; - case triton::TYPE_FP16: + case ft::TYPE_FP16: data_type.code = DLDataTypeCode::kDLFloat; data_type.bits = 16; break; - case triton::TYPE_FP32: + case ft::TYPE_FP32: data_type.code = DLDataTypeCode::kDLFloat; data_type.bits = 32; break; - case triton::TYPE_FP64: + case ft::TYPE_FP64: data_type.code = DLDataTypeCode::kDLFloat; data_type.bits = 64; break; - case triton::TYPE_BF16: + case ft::TYPE_BF16: data_type.code = DLDataTypeCode::kDLBfloat; data_type.bits = 16; break; @@ -125,78 +129,78 @@ DLManagedTensor* TritonTensorToDLManagedTensor(triton::Tensor& tensor) return new DLManagedTensor{dl_tensor, nullptr, [](DLManagedTensor* dlmt) { delete dlmt; }}; } -triton::MemoryType getMemoryType(DLDevice device) +ft::MemoryType getMemoryType(DLDevice device) { switch (device.device_type) { case DLDeviceType::kDLCUDAHost: - return triton::MemoryType::MEMORY_CPU_PINNED; + return ft::MemoryType::MEMORY_CPU_PINNED; case DLDeviceType::kDLCUDA: - return triton::MemoryType::MEMORY_GPU; + return ft::MemoryType::MEMORY_GPU; case DLDeviceType::kDLCPU: default: - return triton::MemoryType::MEMORY_CPU; + return ft::MemoryType::MEMORY_CPU; } } -triton::DataType getDataType(DLDataType data_type) +ft::DataType getDataType(DLDataType data_type) { switch (data_type.code) { case DLDataTypeCode::kDLUInt: switch (data_type.bits) { case 8: - return triton::TYPE_UINT8; + return ft::TYPE_UINT8; case 16: - return triton::TYPE_UINT16; + return ft::TYPE_UINT16; case 32: - return triton::TYPE_UINT32; + return ft::TYPE_UINT32; case 64: - return triton::TYPE_UINT64; + return ft::TYPE_UINT64; default: - return triton::TYPE_INVALID; + return ft::TYPE_INVALID; } break; case DLDataTypeCode::kDLInt: switch (data_type.bits) { case 8: - return triton::TYPE_INT8; + return ft::TYPE_INT8; case 16: - return triton::TYPE_INT16; + return ft::TYPE_INT16; case 32: - return triton::TYPE_INT32; + return ft::TYPE_INT32; case 64: - return triton::TYPE_INT64; + return ft::TYPE_INT64; default: - return triton::TYPE_INVALID; + return ft::TYPE_INVALID; } break; case DLDataTypeCode::kDLFloat: switch (data_type.bits) { case 16: - return triton::TYPE_FP16; + return ft::TYPE_FP16; case 32: - return triton::TYPE_FP32; + return ft::TYPE_FP32; case 64: - return triton::TYPE_FP64; + return ft::TYPE_FP64; default: - return triton::TYPE_INVALID; + return ft::TYPE_INVALID; } break; case DLDataTypeCode::kDLBfloat: switch (data_type.bits) { case 16: - return triton::TYPE_BF16; + return ft::TYPE_BF16; default: - return triton::TYPE_INVALID; + return ft::TYPE_INVALID; } break; case DLDataTypeCode::kDLBool: - return triton::TYPE_BOOL; + return ft::TYPE_BOOL; default: - return triton::TYPE_INVALID; + return ft::TYPE_INVALID; } } -std::shared_ptr DLManagedTensorToTritonTensor(DLManagedTensor* tensor) +std::shared_ptr DLManagedTensorToTritonTensor(DLManagedTensor* tensor) { auto& dl_tensor = tensor->dl_tensor; auto where = getMemoryType(dl_tensor.device); @@ -205,7 +209,7 @@ std::shared_ptr DLManagedTensorToTritonTensor(DLManagedTensor* t std::vector shape(dl_tensor.shape, dl_tensor.shape + dl_tensor.ndim); auto data = dl_tensor.data; - return std::make_shared(where, dtype, shape, data); + return std::make_shared(where, dtype, shape, data); } DLTensor GetDLTensor(py::object obj) @@ -270,70 +274,65 @@ PYBIND11_MODULE(_turbomind, m) // custom comm py::class_>(m, "AbstractCustomComm"); - // instance comm - py::class_(m, "AbstractInstanceComm"); - // data type - py::enum_(m, "DataType") - .value("TYPE_INVALID", triton::DataType::TYPE_INVALID) - .value("TYPE_BOOL", triton::DataType::TYPE_BOOL) - .value("TYPE_UINT8", triton::DataType::TYPE_UINT8) - .value("TYPE_UINT16", triton::DataType::TYPE_UINT16) - .value("TYPE_UINT32", triton::DataType::TYPE_UINT32) - .value("TYPE_UINT64", triton::DataType::TYPE_UINT64) - .value("TYPE_INT8", triton::DataType::TYPE_INT8) - .value("TYPE_INT16", triton::DataType::TYPE_INT16) - .value("TYPE_INT32", triton::DataType::TYPE_INT32) - .value("TYPE_INT64", triton::DataType::TYPE_INT64) - .value("TYPE_FP16", triton::DataType::TYPE_FP16) - .value("TYPE_FP32", triton::DataType::TYPE_FP32) - .value("TYPE_FP64", triton::DataType::TYPE_FP64) - .value("TYPE_BYTES", triton::DataType::TYPE_BYTES) - .value("TYPE_BF16", triton::DataType::TYPE_BF16); + py::enum_(m, "DataType") + .value("TYPE_INVALID", ft::DataType::TYPE_INVALID) + .value("TYPE_BOOL", ft::DataType::TYPE_BOOL) + .value("TYPE_UINT8", ft::DataType::TYPE_UINT8) + .value("TYPE_UINT16", ft::DataType::TYPE_UINT16) + .value("TYPE_UINT32", ft::DataType::TYPE_UINT32) + .value("TYPE_UINT64", ft::DataType::TYPE_UINT64) + .value("TYPE_INT8", ft::DataType::TYPE_INT8) + .value("TYPE_INT16", ft::DataType::TYPE_INT16) + .value("TYPE_INT32", ft::DataType::TYPE_INT32) + .value("TYPE_INT64", ft::DataType::TYPE_INT64) + .value("TYPE_FP16", ft::DataType::TYPE_FP16) + .value("TYPE_FP32", ft::DataType::TYPE_FP32) + .value("TYPE_FP64", ft::DataType::TYPE_FP64) + .value("TYPE_BYTES", ft::DataType::TYPE_BYTES) + .value("TYPE_BF16", ft::DataType::TYPE_BF16); // memory type - py::enum_(m, "MemoryType") - .value("MEMORY_CPU", triton::MemoryType::MEMORY_CPU) - .value("MEMORY_CPU_PINNED", triton::MemoryType::MEMORY_CPU_PINNED) - .value("MEMORY_GPU", triton::MemoryType::MEMORY_GPU); + py::enum_(m, "MemoryType") + .value("MEMORY_CPU", ft::MemoryType::MEMORY_CPU) + .value("MEMORY_CPU_PINNED", ft::MemoryType::MEMORY_CPU_PINNED) + .value("MEMORY_GPU", ft::MemoryType::MEMORY_GPU); // tensor - py::class_>(m, "Tensor") - .def_readonly("where", &triton::Tensor::where) - .def_readonly("type", &triton::Tensor::type) - .def_readonly("shape", &triton::Tensor::shape) - .def_readonly("data", &triton::Tensor::data) - .def(py::init([](const triton::MemoryType where, - const triton::DataType type, - const std::vector& shape, - const long data) { - auto data_ptr = reinterpret_cast(data); - return new triton::Tensor(where, type, shape, data_ptr); - })) + py::class_>(m, "Tensor") + .def_readonly("where", &ft::Tensor::where) + .def_readonly("type", &ft::Tensor::type) + .def_readonly("shape", &ft::Tensor::shape) + .def_readonly("data", &ft::Tensor::data) + .def(py::init( + [](const ft::MemoryType where, const ft::DataType type, const std::vector& shape, const long data) { + auto data_ptr = reinterpret_cast(data); + return new ft::Tensor(where, type, shape, data_ptr); + })) .def( "view", - [](triton::Tensor* self, triton::DataType new_type) { - return new triton::Tensor(self->where, new_type, self->shape, self->data); + [](ft::Tensor* self, ft::DataType new_type) { + return new ft::Tensor(self->where, new_type, self->shape, self->data); }, "new_type"_a) .def( "view", - [](triton::Tensor* self, std::vector new_shape) { - return new triton::Tensor(self->where, self->type, new_shape, self->data); + [](ft::Tensor* self, std::vector new_shape) { + return new ft::Tensor(self->where, self->type, new_shape, self->data); }, "new_shape"_a) .def( "copy_from", - [](triton::Tensor* self, py::object obj) { + [](ft::Tensor* self, py::object obj) { py::capsule cap = obj.attr("__dlpack__")(); DLManagedTensor* dlmt = static_cast(PyCapsule_GetPointer(cap.ptr(), kDlTensorCapsuleName)); auto src = DLManagedTensorToTritonTensor(dlmt); switch (self->type) { - case triton::TYPE_FP16: - case triton::TYPE_FP32: - case triton::TYPE_INT32: - case triton::TYPE_BF16: { + case ft::TYPE_FP16: + case ft::TYPE_FP32: + case ft::TYPE_INT32: + case ft::TYPE_BF16: { auto num_element = std::accumulate(src->shape.begin(), src->shape.end(), 1LL, std::multiplies()); auto num_bytes = num_element * dlmt->dl_tensor.dtype.bits / 8; @@ -348,7 +347,7 @@ PYBIND11_MODULE(_turbomind, m) "tensor"_a) .def( "__dlpack__", - [](triton::Tensor* self, long stream) { + [](ft::Tensor* self, long stream) { DLManagedTensor* dlmt = TritonTensorToDLManagedTensor(*self); return py::capsule(dlmt, kDlTensorCapsuleName, [](PyObject* obj) { DLManagedTensor* dlmt = @@ -364,7 +363,7 @@ PYBIND11_MODULE(_turbomind, m) }); }, "stream"_a = 0) - .def("__dlpack_device__", [](triton::Tensor* self) { + .def("__dlpack_device__", [](ft::Tensor* self) { auto device = getDLDevice(*self); return std::tuple(int(device.device_type), device.device_id); }); @@ -380,19 +379,19 @@ PYBIND11_MODULE(_turbomind, m) "dl_managed_tensor"_a); // transformer model instance + using ft::AbstractTransformerModelInstance; py::bind_map>(m, "TensorMap"); py::class_(m, "AbstractTransformerModelInstance") .def( "forward", - [](AbstractTransformerModelInstance* model, - std::shared_ptr input_tensors, - ft::AbstractInstanceComm* inst_comm) { return model->forward(input_tensors, inst_comm); }, + [](AbstractTransformerModelInstance* model, std::shared_ptr input_tensors) { + return model->forward(input_tensors); + }, py::call_guard(), - "input_tensors"_a, - "inst_comm"_a = nullptr) + "input_tensors"_a) .def( "register_callback", - [](AbstractTransformerModelInstance* self, triton_stream_cb_t cb, py::object ctx) { + [](AbstractTransformerModelInstance* self, ft::triton_stream_cb_t cb, py::object ctx) { self->registerCallback(cb, ctx.ptr()); }, "callback"_a, @@ -400,6 +399,8 @@ PYBIND11_MODULE(_turbomind, m) .def("unregister_callback", &AbstractTransformerModelInstance::unRegisterCallback); // transformer model + using ft::AbstractTransformerModel; + using ft::LlamaTritonModel; py::class_>(m, "AbstractTransformerModel") .def_static( "create_llama_model", @@ -463,7 +464,6 @@ PYBIND11_MODULE(_turbomind, m) return ret; }, "world_size"_a) - .def("create_instance_comm", &AbstractTransformerModel::createInstanceComm, "size"_a) .def( "create_model_instance", [](AbstractTransformerModel* model, diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc index 1c7c5eb46..40c5ac890 100644 --- a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc +++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc @@ -27,17 +27,18 @@ #include "src/turbomind/models/llama/LlamaDenseWeight.h" #include "src/turbomind/models/llama/context.h" #include "src/turbomind/models/llama/llama_params.h" +#include "src/turbomind/utils/allocator.h" +#include "src/turbomind/utils/cuda_utils.h" + #include "src/turbomind/triton_backend/llama/LlamaTritonModel.h" #include "src/turbomind/triton_backend/llama/LlamaTritonModelInstance.h" #include "src/turbomind/triton_backend/transformer_triton_backend.hpp" -#include "src/turbomind/utils/allocator.h" -#include "src/turbomind/utils/cuda_utils.h" -namespace ft = turbomind; +namespace turbomind { -static std::optional get_moe_method() +static std::optional get_moe_method() { - static const auto value = []() -> std::optional { + static const auto value = []() -> std::optional { const auto p = std::getenv("TM_MOE_METHOD"); if (p) { std::string str(p); @@ -45,10 +46,10 @@ static std::optional get_moe_method() x = std::tolower(x); } if (str == "naive") { - return ft::MoeParam::kNaive; + return MoeParam::kNaive; } else if (str == "fused") { - return ft::MoeParam::kFused; + return MoeParam::kFused; } else { std::cerr << "[WARNING] unrecognised MoE method: " << str << "\n"; @@ -67,7 +68,7 @@ std::shared_ptr AbstractTransformerModel::createLlamaM } catch (const YAML::Exception& e) { std::cerr << "Error reading YAML config: " << e.what() << std::endl; - ft::FT_CHECK(false); + FT_CHECK(false); } const auto ft_instance_hyperparameter = reader["ft_instance_hyperparameter"]; @@ -91,7 +92,7 @@ std::shared_ptr AbstractTransformerModel::createLlamaM model_dir); #else TM_LOG_ERROR("[ERROR] Turbomind is not built with ENABLE_BF16"); - ft::FT_CHECK(false); + FT_CHECK(false); #endif } else { @@ -103,7 +104,7 @@ std::shared_ptr AbstractTransformerModel::createLlamaM model_dir); #else TM_LOG_ERROR("[ERROR] Turbomind is not built with ENABLE_BF32"); - ft::FT_CHECK(false); + FT_CHECK(false); #endif } return nullptr; @@ -205,10 +206,10 @@ void LlamaTritonModel::handleMissingParams() template LlamaTritonModel::~LlamaTritonModel() { - ft::FT_CHECK(weights_.size() == engines_.size()); + FT_CHECK(weights_.size() == engines_.size()); for (int device_id = 0; device_id < (int)engines_.size(); ++device_id) { // Set device id before destructing CUDA resources - ft::check_cuda_error(cudaSetDevice(device_id)); + check_cuda_error(cudaSetDevice(device_id)); engines_[device_id].reset(); weights_[device_id].reset(); } @@ -222,7 +223,7 @@ LlamaTritonModel::LlamaTritonModel(size_t tensor_para_size, std::string config): tensor_para_size_(tensor_para_size), pipeline_para_size_(pipeline_para_size), - weights_(ft::getDeviceCount()), + weights_(getDeviceCount()), enable_custom_all_reduce_(enable_custom_all_reduce) { FT_CHECK_WITH_INFO(!(config.empty() && model_dir.empty()), "invalid init options"); @@ -242,7 +243,7 @@ LlamaTritonModel::LlamaTritonModel(size_t tensor_para_size, } catch (const YAML::Exception& e) { std::cerr << "Error reading YAML config: " << e.what() << std::endl; - ft::FT_CHECK(false); + FT_CHECK(false); } const auto model_reader = reader["model_config"]; @@ -305,7 +306,7 @@ LlamaTritonModel::LlamaTritonModel(size_t tensor_para_size, engine_param_.num_tokens_per_iter = engine_reader["num_tokens_per_iter"].as(0); engine_param_.max_prefill_iters = engine_reader["max_prefill_iters"].as(1); - lora_param_.policy = ft::getLoraPolicy(reader["lora_config"]["lora_policy"].as("")); + lora_param_.policy = getLoraPolicy(reader["lora_config"]["lora_policy"].as("")); lora_param_.r = lora_reader["lora_r"].as(0); lora_param_.scale = lora_reader["lora_scale"].as(0); lora_param_.max_wo_r = lora_reader["lora_max_wo_r"].as(0); @@ -329,75 +330,75 @@ LlamaTritonModel::LlamaTritonModel(size_t tensor_para_size, handleMissingParams(); - shared_state_ = std::make_shared(); - shared_state_->barrier = std::make_shared(tensor_para_size); + shared_state_ = std::make_shared(); + shared_state_->barrier = std::make_shared(tensor_para_size); - const auto device_count = ft::getDeviceCount(); + const auto device_count = getDeviceCount(); engines_.resize(device_count); const std::string weight_type_str = model_reader["weight_type"].as(); if (weight_type_str == "fp16" || weight_type_str == "float16") { - model_param_.weight_type = ft::WeightType::kFP16; + model_param_.weight_type = WeightType::kFP16; } else if (weight_type_str == "bf16" || weight_type_str == "bfloat16") { - model_param_.weight_type = ft::WeightType::kBF16; + model_param_.weight_type = WeightType::kBF16; } else if (weight_type_str == "fp32") { - model_param_.weight_type = ft::WeightType::kFP32; + model_param_.weight_type = WeightType::kFP32; } else if (weight_type_str == "int8") { - model_param_.weight_type = ft::WeightType::kINT8; + model_param_.weight_type = WeightType::kINT8; } else if (weight_type_str == "int4") { - model_param_.weight_type = ft::WeightType::kINT4; + model_param_.weight_type = WeightType::kINT4; } else { std::cout << "[ERROR] Unsupported weight type: '" << weight_type_str << "'\n"; - ft::FT_CHECK(0); + FT_CHECK(0); } if (auto method = get_moe_method()) { moe_param_.method = *method; } else { - moe_param_.method = ft::MoeParam::kFused; + moe_param_.method = MoeParam::kFused; } TM_LOG_INFO("%s", toString().c_str()); } template -std::unique_ptr> LlamaTritonModel::createSharedModelInstance( - int device_id, - int rank, - std::pair, std::vector> nccl_params, - std::shared_ptr custom_all_reduce_comm) +std::unique_ptr> +LlamaTritonModel::createSharedModelInstance(int device_id, + int rank, + std::pair, std::vector> nccl_params, + std::shared_ptr custom_all_reduce_comm) { - ft::check_cuda_error(cudaSetDevice(device_id)); + check_cuda_error(cudaSetDevice(device_id)); const int comms_rank = device_id % (tensor_para_size_ * pipeline_para_size_); - auto ctx = std::make_unique>(device_id); + auto ctx = std::make_unique>(device_id); - ft::NcclParam tensor_para = nccl_params.first[comms_rank]; - ft::NcclParam pipeline_para = nccl_params.second[comms_rank]; + NcclParam tensor_para = nccl_params.first[comms_rank]; + NcclParam pipeline_para = nccl_params.second[comms_rank]; - ft::FT_CHECK(tensor_para.world_size_ == tensor_para_size_); - ft::FT_CHECK(pipeline_para.world_size_ == pipeline_para_size_); + FT_CHECK(tensor_para.world_size_ == tensor_para_size_); + FT_CHECK(pipeline_para.world_size_ == pipeline_para_size_); - auto model = std::make_unique>(model_param_, // - attn_param_, - moe_param_, - lora_param_, - tensor_para, - *ctx, - engine_param_.max_batch_size, - weights_[device_id]); + auto model = std::make_unique>(model_param_, // + attn_param_, + moe_param_, + lora_param_, + tensor_para, + *ctx, + engine_param_.max_batch_size, + weights_[device_id]); - auto engine = std::make_unique>(engine_param_, // - std::move(model), - std::move(ctx), - shared_state_, - device_id); + auto engine = std::make_unique>(engine_param_, // + std::move(model), + std::move(ctx), + shared_state_, + device_id); // Wait for pinned buffers to be allocated for all ranks, otherwise tuning will hang // due to concurrent kernel launch & cudaMallocHost @@ -413,14 +414,14 @@ std::unique_ptr LlamaTritonModel::createModelInstance(int device_id, int rank, cudaStream_t stream, - std::pair, std::vector>, - std::shared_ptr) + std::pair, std::vector>, + std::shared_ptr) { - ft::check_cuda_error(cudaSetDevice(device_id)); + check_cuda_error(cudaSetDevice(device_id)); - ft::FT_CHECK(engines_[device_id] != nullptr); + FT_CHECK(engines_[device_id] != nullptr); - auto allocator = std::make_unique>(device_id, false); + auto allocator = std::make_unique>(device_id, false); allocator->setStream(stream); @@ -430,12 +431,12 @@ LlamaTritonModel::createModelInstance(int device_id, template void LlamaTritonModel::createSharedWeights(int device_id, int rank) { - ft::check_cuda_error(cudaSetDevice(device_id)); + check_cuda_error(cudaSetDevice(device_id)); const int tensor_para_rank = rank % tensor_para_size_; const int pipeline_para_rank = rank / tensor_para_size_; - ft::FT_CHECK(pipeline_para_size_ == 1 && pipeline_para_rank == 0); - weights_[device_id] = std::make_shared>( - model_param_, lora_param_, moe_param_, tensor_para_size_, tensor_para_rank); + FT_CHECK(pipeline_para_size_ == 1 && pipeline_para_rank == 0); + weights_[device_id] = + std::make_shared>(model_param_, lora_param_, moe_param_, tensor_para_size_, tensor_para_rank); // model inited with model_dir if (model_dir_ != "") { weights_[device_id]->loadModel(model_dir_); @@ -444,37 +445,41 @@ void LlamaTritonModel::createSharedWeights(int device_id, int rank) } template -TensorMap LlamaTritonModel::getParams(int deviceId, int rank) +std::unordered_map LlamaTritonModel::getParams(int deviceId, int rank) { - ft::check_cuda_error(cudaSetDevice(deviceId)); + check_cuda_error(cudaSetDevice(deviceId)); + // shared_weight should be created before getParams - ft::FT_CHECK(weights_[deviceId] != nullptr); - ft::TensorMap output = weights_[deviceId]->getParams(); - TensorMap result; + FT_CHECK(weights_[deviceId] != nullptr); + + TensorMap output = weights_[deviceId]->getParams(); + + std::unordered_map result; for (auto [name, tensor] : output) { - result.emplace(name, triton::Tensor{tensor.where, tensor.type, tensor.shape, tensor.data}); + result.insert({{name, Tensor{tensor.where, tensor.type, tensor.shape, tensor.data}}}); } + return result; } template void LlamaTritonModel::processWeights(int device_id, int rank) { - ft::check_cuda_error(cudaSetDevice(device_id)); - ft::FT_CHECK(weights_[device_id] != nullptr); + check_cuda_error(cudaSetDevice(device_id)); + FT_CHECK(weights_[device_id] != nullptr); cudaDeviceProp props{}; - ft::check_cuda_error(cudaGetDeviceProperties(&props, device_id)); + check_cuda_error(cudaGetDeviceProperties(&props, device_id)); weights_[device_id]->prepare(props); - ft::sync_check_cuda_error(); + sync_check_cuda_error(); } template -void LlamaTritonModel::createEngine(int device_id, - int rank, - std::pair, std::vector> nccl_params, - std::shared_ptr custom_all_reduce_comm) +void LlamaTritonModel::createEngine(int device_id, + int rank, + std::pair, std::vector> nccl_params, + std::shared_ptr custom_all_reduce_comm) { auto engine = createSharedModelInstance(device_id, rank, nccl_params, custom_all_reduce_comm); @@ -515,17 +520,11 @@ std::string LlamaTritonModel::toString() } template -void LlamaTritonModel::createCustomComms( - std::vector>* custom_all_reduce_comms, int world_size) +void LlamaTritonModel::createCustomComms(std::vector>* custom_all_reduce_comms, + int world_size) { - using commDataType = typename ft::CustomARCommTypeConverter::Type; - ft::initCustomAllReduceComm(custom_all_reduce_comms, enable_custom_all_reduce_, world_size); -} - -template -std::unique_ptr LlamaTritonModel::createInstanceComm(int size) -{ - return nullptr; + using commDataType = typename CustomARCommTypeConverter::Type; + initCustomAllReduceComm(custom_all_reduce_comms, enable_custom_all_reduce_, world_size); } template @@ -547,3 +546,5 @@ template struct LlamaTritonModel; #ifdef ENABLE_BF16 template struct LlamaTritonModel<__nv_bfloat16>; #endif + +} // namespace turbomind diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.h b/src/turbomind/triton_backend/llama/LlamaTritonModel.h index a6c1b862a..8f473cd4c 100644 --- a/src/turbomind/triton_backend/llama/LlamaTritonModel.h +++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.h @@ -31,7 +31,7 @@ #include #include -namespace ft = turbomind; +namespace turbomind { template struct LlamaTritonModel: public AbstractTransformerModel { @@ -44,27 +44,25 @@ struct LlamaTritonModel: public AbstractTransformerModel { ~LlamaTritonModel() override; std::unique_ptr - createModelInstance(int deviceId, - int rank, - cudaStream_t stream, - std::pair, std::vector> nccl_params, - std::shared_ptr custom_all_reduce_comm = nullptr) override; + createModelInstance(int deviceId, + int rank, + cudaStream_t stream, + std::pair, std::vector> nccl_params, + std::shared_ptr custom_all_reduce_comm = nullptr) override; void createSharedWeights(int deviceId, int rank) override; - TensorMap getParams(int deviceId, int rank) override; + std::unordered_map getParams(int deviceId, int rank) override; void processWeights(int deviceId, int rank) override; - void createEngine(int device_id, - int rank, - std::pair, std::vector> nccl_params, - std::shared_ptr) override; + void createEngine(int device_id, + int rank, + std::pair, std::vector> nccl_params, + std::shared_ptr) override; - void createCustomComms(std::vector>* custom_all_reduce_comms, - int world_size) override; - - std::unique_ptr createInstanceComm(int size) override; + void createCustomComms(std::vector>* custom_all_reduce_comms, + int world_size) override; void handleMissingParams(); @@ -78,24 +76,24 @@ struct LlamaTritonModel: public AbstractTransformerModel { int getPipelineParaSize() override; private: - std::unique_ptr> - createSharedModelInstance(int deviceId, - int rank, - std::pair, std::vector> nccl_params, - std::shared_ptr custom_all_reduce_comm = nullptr); - - ft::ModelParam model_param_; - ft::AttentionParam attn_param_; - ft::MoeParam moe_param_; - ft::LoraParam lora_param_; - ft::EngineParam engine_param_; - size_t tensor_para_size_; - size_t pipeline_para_size_; - - std::shared_ptr shared_state_; + std::unique_ptr> + createSharedModelInstance(int deviceId, + int rank, + std::pair, std::vector> nccl_params, + std::shared_ptr custom_all_reduce_comm = nullptr); + + ModelParam model_param_; + AttentionParam attn_param_; + MoeParam moe_param_; + LoraParam lora_param_; + EngineParam engine_param_; + size_t tensor_para_size_; + size_t pipeline_para_size_; + + std::shared_ptr shared_state_; // Weights & engine instances for the ranks - std::vector>> weights_; - std::vector>> engines_; + std::vector>> weights_; + std::vector>> engines_; bool is_fp16_; int enable_custom_all_reduce_ = 0; @@ -105,3 +103,5 @@ struct LlamaTritonModel: public AbstractTransformerModel { ffi_api_lock_ctrl_t ffi_lock_ = nullptr; }; + +} // namespace turbomind diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModelInstance.cc b/src/turbomind/triton_backend/llama/LlamaTritonModelInstance.cc index 8221f932c..976fc9cc1 100644 --- a/src/turbomind/triton_backend/llama/LlamaTritonModelInstance.cc +++ b/src/turbomind/triton_backend/llama/LlamaTritonModelInstance.cc @@ -31,78 +31,23 @@ #include #include -namespace ft = turbomind; +namespace turbomind { template -void triton_stream_callback(std::unordered_map* output_tensors, void* ctx) +void triton_stream_callback(std::unordered_map* outputs, void* ctx) { - LlamaTritonModelInstance* model = reinterpret_cast*>(ctx); - auto result = LlamaTritonModelInstance::convert_outputs(*output_tensors); - - model->stream_cb_(result, model->stream_ctx_); + LlamaTritonModelInstance* model = reinterpret_cast*>(ctx); + model->stream_cb_(std::make_shared>(*outputs), model->stream_ctx_); } template -LlamaTritonModelInstance::LlamaTritonModelInstance(ft::Engine& instance, - std::unique_ptr> allocator, - int device_id): +LlamaTritonModelInstance::LlamaTritonModelInstance(Engine& instance, + std::unique_ptr> allocator, + int device_id): device_id_{device_id}, instance_(&instance), allocator_(std::move(allocator)) { } -template -std::unordered_map LlamaTritonModelInstance::convert_inputs( - std::shared_ptr> input_tensors) -{ - TM_LOG_DEBUG(__PRETTY_FUNCTION__); - - const size_t request_batch_size = input_tensors->at("input_ids").shape[0]; - const size_t input_data_len = input_tensors->at("input_ids").shape[1]; - h_total_output_lengths_ = - (uint32_t*)std::realloc((void*)h_total_output_lengths_, request_batch_size * sizeof(uint32_t)); - - std::unordered_map ft_input_tensors{}; - - for (auto t = input_tensors->begin(); t != input_tensors->end(); ++t) { - if (ft_input_tensors.count(t->first) == 0) { - ft_input_tensors.insert({t->first, t->second.convertTritonTensorToFt()}); - } - } - - return ft_input_tensors; -} - -template -std::shared_ptr> -LlamaTritonModelInstance::convert_outputs(const std::unordered_map& output_tensors) -{ - TM_LOG_DEBUG(__PRETTY_FUNCTION__); - std::unordered_map* outputs_mapping = - new std::unordered_map(); - - for (auto it = output_tensors.begin(); it != output_tensors.end(); it++) { - outputs_mapping->insert({it->first, triton::Tensor::convertFtTensorToTriton(it->second)}); - } - - return std::shared_ptr>(outputs_mapping); -} - -template -std::shared_ptr> -LlamaTritonModelInstance::forward(std::shared_ptr> input_tensors) -{ - ft::FT_CHECK(false); - return nullptr; -} - -template -std::shared_ptr> -LlamaTritonModelInstance::forward(std::shared_ptr> input_tensors) -{ - ft::FT_CHECK(false); - return nullptr; -} - template std::string format_vector(const std::vector& vec) { @@ -118,120 +63,109 @@ std::string format_vector(const std::vector& vec) } template -std::shared_ptr> -LlamaTritonModelInstance::forward(std::shared_ptr> input_tensors, - ft::AbstractInstanceComm* instance_comm) +std::shared_ptr> +LlamaTritonModelInstance::forward(std::shared_ptr> inputs) { TM_LOG_DEBUG(__PRETTY_FUNCTION__); // In some cases, this is needed to trigger the creation of CUDA context, or later `cudaMallocAsync` will die - ft::check_cuda_error(cudaSetDevice(device_id_)); + check_cuda_error(cudaSetDevice(device_id_)); - FT_CHECK_WITH_INFO(input_tensors->at("input_ids").shape.size() == 2, - "input_tensors->at(\"input_ids\").shape.size() == 2"); - FT_CHECK_WITH_INFO(input_tensors->at("input_lengths").shape.size() == 1, - "input_tensors->at(\"input_lengths\").shape.size() == 1"); + FT_CHECK_WITH_INFO(inputs->at("input_ids").shape.size() == 2, "inputs->at(\"input_ids\").shape.size() == 2"); + FT_CHECK_WITH_INFO(inputs->at("input_lengths").shape.size() == 1, + "inputs->at(\"input_lengths\").shape.size() == 1"); - const uint32_t request_batch_size = input_tensors->at("input_ids").shape[0]; - const uint32_t max_request_output_len = (size_t)*std::max_element( - (int*)input_tensors->at("request_output_len").data, - (int*)input_tensors->at("request_output_len").data + input_tensors->at("request_output_len").shape[0]); + const uint32_t request_batch_size = inputs->at("input_ids").shape[0]; + const uint32_t max_request_output_len = (size_t)*std::max_element((int*)inputs->at("request_output_len").data, + (int*)inputs->at("request_output_len").data + + inputs->at("request_output_len").shape[0]); // const uint32_t total_output_len = max_request_output_len + input_tensors->at("input_ids").shape[1]; - const uint32_t beam_width = - input_tensors->count("beam_width") ? (size_t)(*(uint*)input_tensors->at("beam_width").data) : 1; + const uint32_t beam_width = inputs->count("beam_width") ? (size_t)(*(uint*)inputs->at("beam_width").data) : 1; FT_CHECK_WITH_INFO(beam_width == 1, "Beam search is not implemented"); - std::unordered_map ft_input_tensors = convert_inputs(input_tensors); + h_total_output_lengths_ = + (uint32_t*)std::realloc((void*)h_total_output_lengths_, request_batch_size * sizeof(uint32_t)); - const size_t max_input_len = input_tensors->at("input_ids").shape[1]; - const bool is_return_logits = - input_tensors->count("is_return_logits") && *(bool*)input_tensors->at("is_return_logits").data; + const size_t max_input_len = inputs->at("input_ids").shape[1]; + const bool is_return_logits = inputs->count("is_return_logits") && *(bool*)inputs->at("is_return_logits").data; const size_t vocab_size = instance_->model().vocab_size(); allocateBuffer(request_batch_size, max_input_len, beam_width, instance_->session_len(), is_return_logits); - std::unordered_map output_tensors = std::unordered_map{ + std::unordered_map outputs{ {"output_ids", - ft::Tensor{ft::MEMORY_CPU, - ft::TYPE_UINT32, - std::vector{request_batch_size, beam_width, (size_t)instance_->session_len()}, - d_output_ids_}}, + Tensor{MEMORY_CPU, + TYPE_UINT32, + std::vector{request_batch_size, beam_width, (size_t)instance_->session_len()}, + d_output_ids_}}, {"sequence_length", - ft::Tensor{ft::MEMORY_CPU, - ft::TYPE_UINT32, - std::vector{request_batch_size, beam_width}, - d_sequence_lengths_}}}; - - if (input_tensors->count("is_return_log_probs") && *((bool*)input_tensors->at("is_return_log_probs").data)) { - output_tensors.insert({"output_log_probs", - ft::Tensor{ft::MEMORY_GPU, - ft::TYPE_FP32, - std::vector{request_batch_size, beam_width, max_request_output_len}, - d_output_log_probs_}}); - output_tensors.insert({"cum_log_probs", - ft::Tensor{ft::MEMORY_GPU, - ft::TYPE_FP32, - std::vector{request_batch_size, beam_width}, - d_cum_log_probs_}}); + Tensor{MEMORY_CPU, TYPE_UINT32, std::vector{request_batch_size, beam_width}, d_sequence_lengths_}}}; + + if (inputs->count("is_return_log_probs") && *((bool*)inputs->at("is_return_log_probs").data)) { + outputs.insert({"output_log_probs", + Tensor{MEMORY_GPU, + TYPE_FP32, + std::vector{request_batch_size, beam_width, max_request_output_len}, + d_output_log_probs_}}); + outputs.insert( + {"cum_log_probs", + Tensor{MEMORY_GPU, TYPE_FP32, std::vector{request_batch_size, beam_width}, d_cum_log_probs_}}); } - if (input_tensors->count("logprobs")) { + if (inputs->count("logprobs")) { size_t max_logprob_length = std::min((int)max_request_output_len, instance_->session_len()) + 1; h_logprob_vals_ = (float*)std::realloc( - h_logprob_vals_, sizeof(float) * request_batch_size * beam_width * max_logprob_length * ft::kMaxLogProb); - h_logprob_indexes_ = (uint32_t*)std::realloc(h_logprob_indexes_, - sizeof(uint32_t) * request_batch_size * beam_width - * max_logprob_length * ft::kMaxLogProb); - h_logprob_nums_ = (uint32_t*)std::realloc( + h_logprob_vals_, sizeof(float) * request_batch_size * beam_width * max_logprob_length * kMaxLogProb); + h_logprob_indexes_ = (uint32_t*)std::realloc( + h_logprob_indexes_, sizeof(uint32_t) * request_batch_size * beam_width * max_logprob_length * kMaxLogProb); + h_logprob_nums_ = (uint32_t*)std::realloc( h_logprob_nums_, sizeof(uint32_t) * request_batch_size * beam_width * max_logprob_length); - output_tensors.insert( - {{"logprob_vals", - ft::Tensor{ft::MEMORY_CPU, - ft::TYPE_FP32, - std::vector{request_batch_size, beam_width, max_logprob_length, ft::kMaxLogProb}, - h_logprob_vals_}}}); - - output_tensors.insert( - {{"logprob_indexes", - ft::Tensor{ft::MEMORY_CPU, - ft::TYPE_UINT32, - std::vector{request_batch_size, beam_width, max_logprob_length, ft::kMaxLogProb}, - h_logprob_indexes_}}}); - - output_tensors.insert({{"logprob_nums", - ft::Tensor{ft::MEMORY_CPU, - ft::TYPE_UINT32, - std::vector{request_batch_size, beam_width, max_logprob_length}, - h_logprob_nums_}}}); + outputs.insert({{"logprob_vals", + Tensor{MEMORY_CPU, + TYPE_FP32, + std::vector{request_batch_size, beam_width, max_logprob_length, kMaxLogProb}, + h_logprob_vals_}}}); + + outputs.insert({{"logprob_indexes", + Tensor{MEMORY_CPU, + TYPE_UINT32, + std::vector{request_batch_size, beam_width, max_logprob_length, kMaxLogProb}, + h_logprob_indexes_}}}); + + outputs.insert({{"logprob_nums", + Tensor{MEMORY_CPU, + TYPE_UINT32, + std::vector{request_batch_size, beam_width, max_logprob_length}, + h_logprob_nums_}}}); } if (is_return_logits) { - output_tensors.insert( - {"logits", - {ft::MEMORY_GPU, ft::TYPE_FP32, {request_batch_size, max_input_len, vocab_size}, d_output_logits_}}); + outputs.insert( + {{"logits", {MEMORY_GPU, TYPE_FP32, {request_batch_size, max_input_len, vocab_size}, d_output_logits_}}}); } try { - ft::Request::Callback callback; + Request::Callback callback; if (stream_cb_) { - callback = [this](std::unordered_map* outputs) { + callback = [this](std::unordered_map* outputs) { triton_stream_callback(outputs, this); }; } - ft::check_cuda_error(cudaStreamSynchronize(allocator_->returnStream())); - instance_->Submit(&output_tensors, &ft_input_tensors, {instance_comm, callback}); + check_cuda_error(cudaStreamSynchronize(allocator_->returnStream())); + + instance_->Submit(&outputs, inputs.get(), {callback}); // ! stream synced by the model before returning } catch (...) { h_exception_ = std::current_exception(); - output_tensors.insert({"error_message", ft::Tensor{ft::MEMORY_CPU, ft::TYPE_BYTES, {1}, &h_exception_}}); + outputs.insert({"error_message", Tensor{MEMORY_CPU, TYPE_BYTES, {1}, &h_exception_}}); } - return convert_outputs(output_tensors); + return std::make_shared>(std::move(outputs)); } template @@ -278,3 +212,5 @@ template struct LlamaTritonModelInstance; #ifdef ENABLE_BF16 template struct LlamaTritonModelInstance<__nv_bfloat16>; #endif + +} // namespace turbomind diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModelInstance.h b/src/turbomind/triton_backend/llama/LlamaTritonModelInstance.h index 08088c05d..2cf69b9fa 100644 --- a/src/turbomind/triton_backend/llama/LlamaTritonModelInstance.h +++ b/src/turbomind/triton_backend/llama/LlamaTritonModelInstance.h @@ -20,41 +20,29 @@ #pragma once +#include + #include "src/turbomind/models/llama/LlamaBatch.h" #include "src/turbomind/models/llama/LlamaV2.h" #include "src/turbomind/triton_backend/llama/LlamaTritonModel.h" #include "src/turbomind/triton_backend/transformer_triton_backend.hpp" -#include -namespace ft = turbomind; +namespace turbomind { template struct LlamaTritonModelInstance: AbstractTransformerModelInstance { - LlamaTritonModelInstance(ft::Engine& instance, - std::unique_ptr> allocator, - int device_id); - ~LlamaTritonModelInstance(); - - std::shared_ptr> - forward(std::shared_ptr> input_tensors) override; + LlamaTritonModelInstance(Engine& instance, + std::unique_ptr> allocator, + int device_id); + ~LlamaTritonModelInstance() override; - std::shared_ptr> - forward(std::shared_ptr> input_tensors) override; - - std::shared_ptr> - forward(std::shared_ptr> input_tensors, - ft::AbstractInstanceComm*) override; - - static std::shared_ptr> - convert_outputs(const std::unordered_map& output_tensors); + virtual std::shared_ptr> + forward(std::shared_ptr> input_tensors) override; private: - ft::Engine* instance_; - const std::unique_ptr> allocator_; - - std::unordered_map - convert_inputs(std::shared_ptr> input_tensors); + Engine* instance_; + const std::unique_ptr> allocator_; void allocateBuffer(const size_t request_batch_size, const size_t max_input_len, @@ -88,3 +76,5 @@ struct LlamaTritonModelInstance: AbstractTransformerModelInstance { uint32_t* h_total_output_lengths_ = nullptr; std::exception_ptr h_exception_ = nullptr; }; + +} // namespace turbomind diff --git a/src/turbomind/triton_backend/transformer_triton_backend.cpp b/src/turbomind/triton_backend/transformer_triton_backend.cpp index 16c64b17d..acf5e06e8 100644 --- a/src/turbomind/triton_backend/transformer_triton_backend.cpp +++ b/src/turbomind/triton_backend/transformer_triton_backend.cpp @@ -21,62 +21,66 @@ #include "src/turbomind/triton_backend/transformer_triton_backend.hpp" #include "src/turbomind/utils/nccl_utils.h" -std::pair, std::vector> +namespace turbomind { + +std::pair, std::vector> AbstractTransformerModel::createNcclParams(const int node_id, const int device_id_start, const bool multi_node) { - const int gpu_count = ft::getDeviceCount(); + const int gpu_count = getDeviceCount(); const int tensor_para_size = getTensorParaSize(); const int pipeline_para_size = getPipelineParaSize(); const int local_comm_size = multi_node ? gpu_count : tensor_para_size * pipeline_para_size; - ft::FT_CHECK(tensor_para_size > 0 && pipeline_para_size > 0); - ft::FT_CHECK(device_id_start + (int)local_comm_size <= gpu_count); + FT_CHECK(tensor_para_size > 0 && pipeline_para_size > 0); + FT_CHECK(device_id_start + (int)local_comm_size <= gpu_count); - std::vector nccl_ids; + std::vector nccl_ids; if (tensor_para_size > 1 || pipeline_para_size > 1) { nccl_ids.resize(tensor_para_size + pipeline_para_size); if (node_id == 0) { for (uint32_t i = 0; i < nccl_ids.size(); i++) { - ft::ftNcclGetUniqueId(nccl_ids[i]); + ftNcclGetUniqueId(nccl_ids[i]); } } } - std::vector tensor_para_params(local_comm_size); - std::vector pipeline_para_params(local_comm_size); + std::vector tensor_para_params(local_comm_size); + std::vector pipeline_para_params(local_comm_size); // Don't init comm when size == 1 if (tensor_para_size > 1) { - const auto group_id = ft::ftNcclNextGroupId(); - ft::ftNcclGroupStart(); + const auto group_id = ftNcclNextGroupId(); + ftNcclGroupStart(); for (int gid = device_id_start; gid < device_id_start + local_comm_size; gid++) { int rank = node_id * gpu_count + gid - device_id_start; int tensor_para_rank = rank % tensor_para_size; int pipeline_para_rank = rank / tensor_para_size; - ft::NcclUid tensor_para_nccl_uid = nccl_ids[pipeline_para_rank]; - ft::check_cuda_error(cudaSetDevice(gid)); - ft::ftNcclCommInitRank( + NcclUid tensor_para_nccl_uid = nccl_ids[pipeline_para_rank]; + check_cuda_error(cudaSetDevice(gid)); + ftNcclCommInitRank( tensor_para_params[gid - device_id_start], tensor_para_rank, tensor_para_size, tensor_para_nccl_uid); tensor_para_params[gid - device_id_start].group_id_ = group_id; } - ft::ftNcclGroupEnd(); + ftNcclGroupEnd(); } if (pipeline_para_size > 1) { - const auto group_id = ft::ftNcclNextGroupId(); - ft::ftNcclGroupStart(); + const auto group_id = ftNcclNextGroupId(); + ftNcclGroupStart(); for (int gid = device_id_start; gid < device_id_start + local_comm_size; gid++) { int rank = node_id * gpu_count + gid - device_id_start; int tensor_para_rank = rank % tensor_para_size; int pipeline_para_rank = rank / tensor_para_size; - ft::NcclUid pipeline_para_nccl_uid = nccl_ids[pipeline_para_size + tensor_para_rank]; - ft::check_cuda_error(cudaSetDevice(gid)); - ft::ftNcclCommInitRank(pipeline_para_params[gid - device_id_start], - pipeline_para_rank, - pipeline_para_size, - pipeline_para_nccl_uid); + NcclUid pipeline_para_nccl_uid = nccl_ids[pipeline_para_size + tensor_para_rank]; + check_cuda_error(cudaSetDevice(gid)); + ftNcclCommInitRank(pipeline_para_params[gid - device_id_start], + pipeline_para_rank, + pipeline_para_size, + pipeline_para_nccl_uid); pipeline_para_params[gid - device_id_start].group_id_ = group_id; } - ft::ftNcclGroupEnd(); + ftNcclGroupEnd(); } - return std::pair, std::vector>(tensor_para_params, pipeline_para_params); + return std::pair, std::vector>(tensor_para_params, pipeline_para_params); } + +} // namespace turbomind diff --git a/src/turbomind/triton_backend/transformer_triton_backend.hpp b/src/turbomind/triton_backend/transformer_triton_backend.hpp index 066d75a78..6d49df457 100644 --- a/src/turbomind/triton_backend/transformer_triton_backend.hpp +++ b/src/turbomind/triton_backend/transformer_triton_backend.hpp @@ -30,242 +30,11 @@ #include "src/turbomind/utils/Tensor.h" #include "src/turbomind/utils/custom_ar_comm.h" -#include "src/turbomind/utils/instance_comm.h" #include "src/turbomind/utils/nccl_utils.h" -namespace ft = turbomind; +namespace turbomind { -namespace triton { -#ifdef USE_TRITONSERVER_DATATYPE - -#include "triton/core/tritonbackend.h" -#include "triton/core/tritonserver.h" - -#ifndef TRITONSERVER_API_VERSION_MAJOR -#error TRITONSERVER_API_VERSION_MAJOR Undefined! -#endif - -#ifndef TRITONSERVER_API_VERSION_MINOR -#error TRITONSERVER_API_VERSION_MINOR Undefined! -#endif - -#if (TRITONSERVER_API_VERSION_MAJOR == 1 && TRITONSERVER_API_VERSION_MINOR >= 17) \ - || (TRITONSERVER_API_VERSION_MAJOR > 1) -#define ENABLE_TRITON_BF16 1 -#endif - -typedef TRITONSERVER_DataType DataType; -typedef TRITONSERVER_MemoryType MemoryType; - -constexpr TRITONSERVER_DataType TYPE_INVALID = TRITONSERVER_TYPE_INVALID; -constexpr TRITONSERVER_DataType TYPE_BOOL = TRITONSERVER_TYPE_BOOL; -constexpr TRITONSERVER_DataType TYPE_UINT8 = TRITONSERVER_TYPE_UINT8; -constexpr TRITONSERVER_DataType TYPE_UINT16 = TRITONSERVER_TYPE_UINT16; -constexpr TRITONSERVER_DataType TYPE_UINT32 = TRITONSERVER_TYPE_UINT32; -constexpr TRITONSERVER_DataType TYPE_UINT64 = TRITONSERVER_TYPE_UINT64; -constexpr TRITONSERVER_DataType TYPE_INT8 = TRITONSERVER_TYPE_INT8; -constexpr TRITONSERVER_DataType TYPE_INT16 = TRITONSERVER_TYPE_INT16; -constexpr TRITONSERVER_DataType TYPE_INT32 = TRITONSERVER_TYPE_INT32; -constexpr TRITONSERVER_DataType TYPE_INT64 = TRITONSERVER_TYPE_INT64; -constexpr TRITONSERVER_DataType TYPE_FP16 = TRITONSERVER_TYPE_FP16; -constexpr TRITONSERVER_DataType TYPE_FP32 = TRITONSERVER_TYPE_FP32; -constexpr TRITONSERVER_DataType TYPE_FP64 = TRITONSERVER_TYPE_FP64; -constexpr TRITONSERVER_DataType TYPE_BYTES = TRITONSERVER_TYPE_BYTES; - -#ifdef ENABLE_TRITON_BF16 -constexpr TRITONSERVER_DataType TYPE_BF16 = TRITONSERVER_TYPE_BF16; -#endif -constexpr TRITONSERVER_MemoryType MEMORY_CPU = TRITONSERVER_MEMORY_CPU; -constexpr TRITONSERVER_MemoryType MEMORY_CPU_PINNED = TRITONSERVER_MEMORY_CPU_PINNED; -constexpr TRITONSERVER_MemoryType MEMORY_GPU = TRITONSERVER_MEMORY_GPU; - -#else - -typedef ft::DataType DataType; -typedef ft::MemoryType MemoryType; - -constexpr DataType TYPE_INVALID = ft::TYPE_INVALID; -constexpr DataType TYPE_BOOL = ft::TYPE_BOOL; -constexpr DataType TYPE_UINT8 = ft::TYPE_UINT8; -constexpr DataType TYPE_UINT16 = ft::TYPE_UINT16; -constexpr DataType TYPE_UINT32 = ft::TYPE_UINT32; -constexpr DataType TYPE_UINT64 = ft::TYPE_UINT64; -constexpr DataType TYPE_INT8 = ft::TYPE_INT8; -constexpr DataType TYPE_INT16 = ft::TYPE_INT16; -constexpr DataType TYPE_INT32 = ft::TYPE_INT32; -constexpr DataType TYPE_INT64 = ft::TYPE_INT64; -constexpr DataType TYPE_FP16 = ft::TYPE_FP16; -constexpr DataType TYPE_FP32 = ft::TYPE_FP32; -constexpr DataType TYPE_FP64 = ft::TYPE_FP64; -constexpr DataType TYPE_BYTES = ft::TYPE_BYTES; -constexpr DataType TYPE_BF16 = ft::TYPE_BF16; -constexpr MemoryType MEMORY_CPU = ft::MEMORY_CPU; -constexpr MemoryType MEMORY_CPU_PINNED = ft::MEMORY_CPU_PINNED; -constexpr MemoryType MEMORY_GPU = ft::MEMORY_GPU; - -#endif - -struct Tensor { - const MemoryType where; - const DataType type; - const std::vector shape; - const void* data; - - Tensor(const MemoryType _where, const DataType _type, const std::vector _shape, const void* _data): - where(_where), type(_type), shape(_shape), data(_data) - { - } - - static ft::DataType convertTritonTypeToFt(DataType tmp_type) - { - ft::DataType ft_data_type; - switch (tmp_type) { - case TYPE_INVALID: - ft_data_type = ft::DataType::TYPE_INVALID; - break; - case TYPE_BOOL: - ft_data_type = ft::DataType::TYPE_BOOL; - break; - case TYPE_UINT8: - ft_data_type = ft::DataType::TYPE_UINT8; - break; - case TYPE_UINT16: - ft_data_type = ft::DataType::TYPE_UINT16; - break; - case TYPE_UINT32: - ft_data_type = ft::DataType::TYPE_UINT32; - break; - case TYPE_UINT64: - ft_data_type = ft::DataType::TYPE_UINT64; - break; - case TYPE_INT8: - ft_data_type = ft::DataType::TYPE_INT8; - break; - case TYPE_INT16: - ft_data_type = ft::DataType::TYPE_INT16; - break; - case TYPE_INT32: - ft_data_type = ft::DataType::TYPE_INT32; - break; - case TYPE_INT64: - ft_data_type = ft::DataType::TYPE_INT64; - break; - case TYPE_FP16: - ft_data_type = ft::DataType::TYPE_FP16; - break; - case TYPE_FP32: - ft_data_type = ft::DataType::TYPE_FP32; - break; - case TYPE_FP64: - ft_data_type = ft::DataType::TYPE_FP64; - break; -#ifdef ENABLE_TRITON_BF16 - case TYPE_BF16: - ft_data_type = ft::DataType::TYPE_BF16; - break; -#endif - case TYPE_BYTES: - ft_data_type = ft::DataType::TYPE_BYTES; - break; - default: - FT_CHECK_WITH_INFO(false, "Unknown data type with type id: " + std::to_string(tmp_type)); - break; - } - return ft_data_type; - } - - ft::Tensor convertTritonTensorToFt() - { - ft::DataType ft_data_type = convertTritonTypeToFt(type); - ft::MemoryType ft_memory_type; - switch (where) { - case MEMORY_CPU: - ft_memory_type = ft::MemoryType::MEMORY_CPU; - break; - case MEMORY_CPU_PINNED: - ft_memory_type = ft::MemoryType::MEMORY_CPU_PINNED; - break; - case MEMORY_GPU: - ft_memory_type = ft::MemoryType::MEMORY_GPU; - break; - } - return ft::Tensor{ft_memory_type, ft_data_type, shape, data}; - } - - static Tensor convertFtTensorToTriton(ft::Tensor ft_tensor) - { - DataType triton_data_type; - switch (ft_tensor.type) { - case TYPE_INVALID: - triton_data_type = TYPE_INVALID; - break; - case TYPE_BOOL: - triton_data_type = TYPE_BOOL; - break; - case TYPE_UINT8: - triton_data_type = TYPE_UINT8; - break; - case TYPE_UINT16: - triton_data_type = TYPE_UINT16; - break; - case TYPE_UINT32: - triton_data_type = TYPE_UINT32; - break; - case TYPE_UINT64: - triton_data_type = TYPE_UINT64; - break; - case TYPE_INT8: - triton_data_type = TYPE_INT8; - break; - case TYPE_INT16: - triton_data_type = TYPE_INT16; - break; - case TYPE_INT32: - triton_data_type = TYPE_INT32; - break; - case TYPE_INT64: - triton_data_type = TYPE_INT64; - break; - case TYPE_FP16: - triton_data_type = TYPE_FP16; - break; - case TYPE_FP32: - triton_data_type = TYPE_FP32; - break; - case TYPE_FP64: - triton_data_type = TYPE_FP64; - break; -#ifdef ENABLE_TRITON_BF16 - case TYPE_BF16: - triton_data_type = TYPE_BF16; - break; -#endif - case TYPE_BYTES: - triton_data_type = TYPE_BYTES; - break; - default: - FT_CHECK_WITH_INFO(false, "Unknown data type with type id: " + std::to_string(ft_tensor.type)); - break; - } - MemoryType triton_memory_type; - switch (ft_tensor.where) { - case MEMORY_CPU: - triton_memory_type = MEMORY_CPU; - break; - case MEMORY_CPU_PINNED: - triton_memory_type = MEMORY_CPU_PINNED; - break; - case MEMORY_GPU: - triton_memory_type = MEMORY_GPU; - break; - } - return Tensor{triton_memory_type, triton_data_type, ft_tensor.shape, ft_tensor.data}; - } -}; - -} // namespace triton - -using triton_stream_cb_t = std::function>, void*)>; +using triton_stream_cb_t = std::function>, void*)>; struct AbstractTransformerModel; struct AbstractTransformerModelInstance; @@ -273,17 +42,8 @@ struct AbstractTransformerModelInstance; struct AbstractTransformerModelInstance { virtual ~AbstractTransformerModelInstance() = default; - virtual std::shared_ptr> - forward(std::shared_ptr> input_tensors) = 0; - - virtual std::shared_ptr> - forward(std::shared_ptr> input_tensors) = 0; - - virtual std::shared_ptr> - forward(std::shared_ptr> input_tensors, ft::AbstractInstanceComm*) - { - return forward(input_tensors); - } + virtual std::shared_ptr> + forward(std::shared_ptr> input_tensors) = 0; void registerCallback(triton_stream_cb_t cb, void* ctx) { @@ -301,43 +61,38 @@ struct AbstractTransformerModelInstance { void* stream_ctx_ = nullptr; }; -using TensorMap = std::unordered_map; - struct AbstractTransformerModel { static std::shared_ptr createLlamaModel(std::string model_dir); virtual ~AbstractTransformerModel() = default; - virtual std::pair, std::vector> + virtual std::pair, std::vector> createNcclParams(const int node_id, const int device_id_start = 0, const bool multi_node = false); - virtual void createCustomComms(std::vector>* custom_all_reduce_comms, - int world_size) = 0; - - virtual std::unique_ptr createInstanceComm(int size) - { - return nullptr; - } + virtual void createCustomComms(std::vector>* custom_all_reduce_comms, + int world_size) = 0; virtual std::unique_ptr - createModelInstance(int deviceId, - int rank, - cudaStream_t stream, - std::pair, std::vector> nccl_params, - std::shared_ptr custom_all_reduce_comm = nullptr) = 0; + createModelInstance(int deviceId, + int rank, + cudaStream_t stream, + std::pair, std::vector> nccl_params, + std::shared_ptr custom_all_reduce_comm = nullptr) = 0; virtual void createSharedWeights(int deviceId, int rank) = 0; - virtual TensorMap getParams(int deviceId, int rank) = 0; + virtual std::unordered_map getParams(int deviceId, int rank) = 0; virtual void processWeights(int deviceId, int rank) = 0; - virtual void createEngine(int device_id, - int rank, - std::pair, std::vector> nccl_params, - std::shared_ptr) = 0; + virtual void createEngine(int device_id, + int rank, + std::pair, std::vector> nccl_params, + std::shared_ptr) = 0; virtual std::string toString() = 0; virtual int getTensorParaSize() = 0; virtual int getPipelineParaSize() = 0; }; + +} // namespace turbomind diff --git a/src/turbomind/utils/Tensor.h b/src/turbomind/utils/Tensor.h index 6214f6bbc..b2b8524e0 100644 --- a/src/turbomind/utils/Tensor.h +++ b/src/turbomind/utils/Tensor.h @@ -515,6 +515,16 @@ class TensorMap { return tensor_map_.end(); } + int count(const std::string& key) const + { + return tensor_map_.count(key); + } + + bool empty() const + { + return tensor_map_.empty(); + } + std::string toString(); static TensorMap fromNpyFolder(const std::string& base_folder); void saveNpy(const std::string& base_folder); diff --git a/src/turbomind/utils/instance_comm.h b/src/turbomind/utils/instance_comm.h deleted file mode 100644 index 5a25360a0..000000000 --- a/src/turbomind/utils/instance_comm.h +++ /dev/null @@ -1,16 +0,0 @@ -#pragma once - -namespace turbomind { - -class AbstractInstanceComm { -public: - virtual ~AbstractInstanceComm() = default; - - virtual void barrier() = 0; - - virtual void setSharedObject(void*) = 0; - - virtual void* getSharedObject() = 0; -}; - -} // namespace turbomind From ad21c4d73ac856ddc1fc96b9b54231ae266199bd Mon Sep 17 00:00:00 2001 From: Lyu Han Date: Mon, 2 Dec 2024 13:58:27 +0800 Subject: [PATCH 06/14] add openssh-server installation in dockerfile (#2830) * add openssh-server installation in dockerfile * add sudo --- docker/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 664dc7271..caa58ee63 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -13,7 +13,7 @@ ARG PYTHON_VERSION=3.10 ARG TORCH_VERSION=2.3.0 ARG TORCHVISION_VERSION=0.18.0 -RUN apt-get update -y && apt-get install -y software-properties-common wget vim git curl &&\ +RUN apt-get update -y && apt-get install -y software-properties-common wget vim git curl openssh-server ssh sudo &&\ curl https://sh.rustup.rs -sSf | sh -s -- -y &&\ add-apt-repository ppa:deadsnakes/ppa -y && apt-get update -y && apt-get install -y --no-install-recommends \ ninja-build rapidjson-dev libgoogle-glog-dev gdb python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \ From 776677a43961cc985eb03c0197c0adf620b9ebc5 Mon Sep 17 00:00:00 2001 From: zhabuye <74179177+zhabuye@users.noreply.github.com> Date: Mon, 2 Dec 2024 14:01:05 +0800 Subject: [PATCH 07/14] Add version restrictions in runtime_ascend.txt to ensure functionality (#2836) --- requirements/runtime_ascend.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/runtime_ascend.txt b/requirements/runtime_ascend.txt index d87748e39..05d74bbe7 100644 --- a/requirements/runtime_ascend.txt +++ b/requirements/runtime_ascend.txt @@ -1,5 +1,5 @@ accelerate>=0.29.3 -dlinfer-ascend +dlinfer-ascend>=0.1.2 einops fastapi fire From b91ce9a259d3af4bba14c05b968fdf24373545d6 Mon Sep 17 00:00:00 2001 From: AllentDan <41138331+AllentDan@users.noreply.github.com> Date: Mon, 2 Dec 2024 15:26:29 +0800 Subject: [PATCH 08/14] Fix gemma2 accuracy through the correct softcapping logic (#2842) * Fix gemma2 accuracy through the correct softcapping logic * remove debugging codes --- lmdeploy/pytorch/kernels/cuda/flashattention.py | 17 ++++++++++------- lmdeploy/pytorch/kernels/cuda/pagedattention.py | 6 ++++-- lmdeploy/pytorch/models/gemma.py | 9 ++++++++- 3 files changed, 22 insertions(+), 10 deletions(-) diff --git a/lmdeploy/pytorch/kernels/cuda/flashattention.py b/lmdeploy/pytorch/kernels/cuda/flashattention.py index 7521a3e2b..34a11ae03 100644 --- a/lmdeploy/pytorch/kernels/cuda/flashattention.py +++ b/lmdeploy/pytorch/kernels/cuda/flashattention.py @@ -49,7 +49,7 @@ def softcapping(qk, logit_softcapping: tl.constexpr): @triton.jit def _prefill_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, q1, k1_ptrs, - loop_start, loop_end, qk_scale, history_mask, + loop_start, loop_end, sm_scale, history_mask, kv_min_loc, causal_mask: tl.constexpr, window_size: tl.constexpr, logit_softcapping: tl.constexpr, BLOCK_N: tl.constexpr, @@ -71,8 +71,9 @@ def _prefill_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, q1, k1_ptrs, qk += tl.dot(q1, k1) if causal_mask: - qk *= qk_scale + qk *= sm_scale qk = softcapping(qk, logit_softcapping) + qk = qk * tl_log2(math.e) qk_mask = (history_mask[:, None]) >= (start_n + offs_n[None, :]) if window_size > 0: qk_mask = qk_mask and ( @@ -85,8 +86,9 @@ def _prefill_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, q1, k1_ptrs, m_i_new = tl.maximum(m_i, tl.max(qk, 1)) qk -= m_i_new[:, None] elif window_size > 0: - qk *= qk_scale + qk *= sm_scale qk = softcapping(qk, logit_softcapping) + qk = qk * tl_log2(math.e) qk_mask = ((start_n + offs_n[None, :]) >= kv_min_loc[:, None]) qk = tl.where( qk_mask, @@ -96,11 +98,13 @@ def _prefill_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, q1, k1_ptrs, m_i_new = tl.maximum(m_i, tl.max(qk, 1)) qk -= m_i_new[:, None] elif logit_softcapping > 0: - qk *= qk_scale + qk *= sm_scale qk = softcapping(qk, logit_softcapping) + qk = qk * tl_log2(math.e) m_i_new = tl.maximum(m_i, tl.max(qk, 1)) qk -= m_i_new[:, None] else: + qk_scale = sm_scale * tl_log2(math.e) m_i_new = tl.maximum(m_i, tl.max(qk, 1) * qk_scale) qk = qk * qk_scale - m_i_new[:, None] @@ -256,7 +260,6 @@ def _flash_prefill_fwd_kernel( l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32) - qk_scale = sm_scale * tl_log2(math.e) history_mask = history_len + start_m * BLOCK_M + tl.arange(0, BLOCK_M) loop_end = (history_len + start_m * BLOCK_M) // BLOCK_N * BLOCK_N @@ -270,7 +273,7 @@ def _flash_prefill_fwd_kernel( k1_ptrs, loop_start, loop_end, - qk_scale, + sm_scale, history_mask, kv_min_loc, causal_mask=False, @@ -291,7 +294,7 @@ def _flash_prefill_fwd_kernel( k1_ptrs, loop_start, loop_end, - qk_scale, + sm_scale, history_mask, kv_min_loc, causal_mask=True, diff --git a/lmdeploy/pytorch/kernels/cuda/pagedattention.py b/lmdeploy/pytorch/kernels/cuda/pagedattention.py index bbd6d3cf7..fe44ca434 100644 --- a/lmdeploy/pytorch/kernels/cuda/pagedattention.py +++ b/lmdeploy/pytorch/kernels/cuda/pagedattention.py @@ -205,11 +205,12 @@ def _fwd_grouped_split_kernel( qk += tl.dot(q, k) if BLOCK_DMODEL1 != 0: qk += tl.dot(q1, k1) - qk *= sm_scale * tl_log2(math.e) + qk *= sm_scale if logit_softcapping > 0.0: qk = qk / logit_softcapping qk = tanh(qk) qk = qk * logit_softcapping + qk = qk * tl_log2(math.e) # NOTE: inf - inf = nan, and nan will leads to error if start_n + BLOCK_N > history_len or window_size > 0: qk_mask = history_len >= (start_n + offs_n) @@ -491,11 +492,12 @@ def _fwd_grouped_split_quant_kernel( qk += tl.dot(q, k) if BLOCK_DMODEL1 != 0: qk += tl.dot(q1, k1) - qk *= sm_scale * tl_log2(math.e) + qk *= sm_scale if logit_softcapping > 0.0: qk = qk / logit_softcapping qk = tanh(qk) qk = qk * logit_softcapping + qk = qk * tl_log2(math.e) # NOTE: inf - inf = nan, and nan will leads to error if start_n + BLOCK_N > history_len or window_size > 0: qk_mask = history_len >= (start_n + offs_n) diff --git a/lmdeploy/pytorch/models/gemma.py b/lmdeploy/pytorch/models/gemma.py index 450767bda..ca36f1565 100644 --- a/lmdeploy/pytorch/models/gemma.py +++ b/lmdeploy/pytorch/models/gemma.py @@ -383,6 +383,8 @@ def __init__(self, bias=False, dtype=dtype, device=device) + self.final_logit_softcapping = getattr(config, + 'final_logit_softcapping', None) def forward( self, @@ -405,7 +407,12 @@ def forward( def get_logits(self, hidden_states: torch.Tensor): """compute logits of the model output.""" - return self.lm_head(hidden_states) + logits = self.lm_head(hidden_states) + if self.final_logit_softcapping is not None: + logits = logits / self.final_logit_softcapping + logits = torch.tanh(logits) + logits = logits * self.final_logit_softcapping + return logits def get_input_embeddings(self): """get input embeddings.""" From c158d1877bc31aeb49f8e1b16536a882246bc130 Mon Sep 17 00:00:00 2001 From: Lyu Han Date: Mon, 2 Dec 2024 16:09:00 +0800 Subject: [PATCH 09/14] fix accessing before initialization (#2845) * fix accessing before initialization * fix linting --- .../models/llama/LlamaDecoderLayerWeight.cc | 44 +++++++++---------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc b/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc index 0a2a3be17..393a6a0e8 100644 --- a/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc +++ b/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc @@ -68,6 +68,28 @@ LlamaDecoderLayerWeight::LlamaDecoderLayerWeight(int layer_id, tensor_para_size_(tp_size), tensor_para_rank_(tp_rank) { + self_attn_weights = LlamaAttentionWeight{hidden_units_, + size_per_head_, + head_num_, + kv_head_num_, + model.mla, + attn_bias_, + tensor_para_size_, + weight_type_, + model.group_size}; + + ffn_weights = LlamaFfnWeight{ + hidden_units_, + inter_size_, + tensor_para_size_, + weight_type_, + model.group_size, + weight_type_ == WeightType::kINT4 && is_fuse_silu_act(), + }; + + moe_weights = MoeFfnWeight{ + layer_id, moe_param, hidden_units_, weight_type_, model.group_size, tensor_para_size_, is_fuse_silu_act()}; + if (lora_param.policy == LoraPolicy::kPlora) { std::vector keys = { "attention.w_qkv", "attention.wo", "feed_forward.w1", "feed_forward.w2", "feed_forward.w3"}; @@ -106,28 +128,6 @@ LlamaDecoderLayerWeight::LlamaDecoderLayerWeight(int layer_id, } fused_up_and_gate_ = ffn_weights.gating.lora.policy != LoraPolicy::kPlora; - - self_attn_weights = LlamaAttentionWeight{hidden_units_, - size_per_head_, - head_num_, - kv_head_num_, - model.mla, - attn_bias_, - tensor_para_size_, - weight_type_, - model.group_size}; - - ffn_weights = LlamaFfnWeight{ - hidden_units_, - inter_size_, - tensor_para_size_, - weight_type_, - model.group_size, - weight_type_ == WeightType::kINT4 && is_fuse_silu_act(), - }; - - moe_weights = MoeFfnWeight{ - layer_id, moe_param, hidden_units_, weight_type_, model.group_size, tensor_para_size_, is_fuse_silu_act()}; } template From 986ad17c173d2052cb9b6eb7a8e866cf917e6991 Mon Sep 17 00:00:00 2001 From: q yao Date: Mon, 2 Dec 2024 16:18:43 +0800 Subject: [PATCH 10/14] better kv allocate (#2814) * better allocate * update max session len --- lmdeploy/pytorch/engine/cache_engine.py | 135 +++++++++------------ lmdeploy/pytorch/engine/engine.py | 26 +++- lmdeploy/pytorch/engine/engine_instance.py | 13 +- lmdeploy/pytorch/engine/model_agent.py | 4 +- 4 files changed, 80 insertions(+), 98 deletions(-) diff --git a/lmdeploy/pytorch/engine/cache_engine.py b/lmdeploy/pytorch/engine/cache_engine.py index 8eaa56394..e393adeed 100644 --- a/lmdeploy/pytorch/engine/cache_engine.py +++ b/lmdeploy/pytorch/engine/cache_engine.py @@ -54,7 +54,7 @@ def __init__( self.cache_stream = torch.cuda.Stream() assert self.cache_stream != torch.cuda.current_stream() # Initialize the events for stream synchronization. - self.events = [torch.cuda.Event() for _ in range(self.num_layers)] + self.events = torch.cuda.Event() logger.debug( f'Initialize cache engine with {cache_config.num_gpu_blocks}' @@ -156,80 +156,60 @@ def get_value_block_shape(self, local=local, ) - def allocate_gpu_cache(self): - """allocate caches on GPU.""" - gpu_cache: List[KVCache] = [] + def _allocate_cache(self, num_blocks: int, device: torch.device): + """allocate cache implement.""" key_block_shape = self.get_key_block_shape(local=True) value_block_shape = self.get_value_block_shape(local=True) - for _ in range(self.num_layers): - key_blocks = torch.empty( - size=(self.num_gpu_blocks, *key_block_shape), - dtype=self.kv_cache_dtype, - device='cuda', + num_layers = self.num_layers + kv_cache_dtype = self.kv_cache_dtype + + key_cache = torch.empty( + size=(num_layers, num_blocks, *key_block_shape), + dtype=kv_cache_dtype, + device=device, + ) + value_cache = torch.empty( + size=(num_layers, num_blocks, *value_block_shape), + dtype=kv_cache_dtype, + device=device, + ) + + output = (key_cache, value_cache) + + if self.cache_config.quant_policy in (4, 8): + dtype = self.model_config.dtype + key_sz_cache = torch.empty( + size=(num_layers, num_blocks, *key_block_shape[:-1], 2), + dtype=dtype, + device=device, ) - value_blocks = torch.empty( - size=(self.num_gpu_blocks, *value_block_shape), - dtype=self.kv_cache_dtype, - device='cuda', + val_sz_cache = torch.empty( + size=(num_layers, num_blocks, *value_block_shape[:-1], 2), + dtype=dtype, + device=device, ) - if self.cache_config.quant_policy in (4, 8): - key_scales_zeros = torch.empty( - size=(self.num_gpu_blocks, *key_block_shape[:-1], 2), - dtype=self.model_config.dtype, - device='cuda', - ) - value_scales_zeros = torch.empty( - size=(self.num_gpu_blocks, *value_block_shape[:-1], 2), - dtype=self.model_config.dtype, - device='cuda', - ) - gpu_cache.append((key_blocks, value_blocks, key_scales_zeros, - value_scales_zeros)) - else: - gpu_cache.append((key_blocks, value_blocks)) - - return gpu_cache + output = output + (key_sz_cache, val_sz_cache) + + return output + + def allocate_gpu_cache(self): + """allocate caches on GPU.""" + caches = self._allocate_cache(self.num_gpu_blocks, 'cuda') + self.full_gpu_cache = caches + self.local_gpu_cache = list(zip(*caches)) + return self.local_gpu_cache def allocate_cpu_cache(self): """allocate caches on Host.""" - cpu_cache: List[KVCache] = [] - key_block_shape = self.get_key_block_shape(local=True) - value_block_shape = self.get_value_block_shape(local=True) - - # TODO: pin memory might need be banned on wsl - pin_memory = True + caches = self._allocate_cache(self.num_gpu_blocks, 'cpu') - for _ in range(self.num_layers): - key_blocks = torch.empty( - size=(self.num_cpu_blocks, *key_block_shape), - dtype=self.kv_cache_dtype, - pin_memory=pin_memory, - ) - value_blocks = torch.empty( - size=(self.num_cpu_blocks, *value_block_shape), - dtype=self.kv_cache_dtype, - pin_memory=pin_memory, - ) - if self.cache_config.quant_policy in (4, 8): - key_scales_zeros = torch.empty( - size=(self.num_cpu_blocks, *key_block_shape[:-1], 2), - dtype=self.model_config.dtype, - pin_memory=pin_memory, - ) - value_scales_zeros = torch.empty( - size=(self.num_cpu_blocks, *value_block_shape[:-1], 2), - dtype=self.model_config.dtype, - pin_memory=pin_memory, - ) - cpu_cache.append((key_blocks, value_blocks, key_scales_zeros, - value_scales_zeros)) - else: - cpu_cache.append((key_blocks, value_blocks)) - return cpu_cache + self.full_cpu_cache = caches + self.local_cpu_cache = list(zip(*caches)) + return self.local_cpu_cache @torch.inference_mode() - def _swap(self, src: List[KVCache], dst: List[KVCache], + def _swap(self, src: List[torch.Tensor], dst: List[torch.Tensor], src_to_dst: Dict[int, int]): """Move caches from src memory to dst memory. @@ -238,18 +218,19 @@ def _swap(self, src: List[KVCache], dst: List[KVCache], dst (List[KVCache]): Destination cache. src_to_dst (Dict[int, int]): Map between src and dst. """ + BLOCKS_PER_COPY = 2 + num_copy = len(src_to_dst) + src_idx, dst_idx = list(zip(*src_to_dst.items())) + src_idx = torch.tensor(src_idx, device=src[0].device) + dst_idx = torch.tensor(dst_idx, device=dst[0].device) with torch.cuda.stream(self.cache_stream): - for i in range(self.num_layers): - src_key_cache, src_value_cache = src[i] - dst_key_cache, dst_value_cache = dst[i] - - for src_id, dst_id in src_to_dst.items(): - if isinstance(dst_key_cache[dst_id], torch.Tensor): - dst_key_cache[dst_id].copy_(src_key_cache[src_id]) - dst_value_cache[dst_id].copy_(src_value_cache[src_id]) - - event = self.events[i] - event.record(stream=self.cache_stream) + for scache, dcache in zip(src, dst): + for idx in range(0, num_copy, BLOCKS_PER_COPY): + sidx = src_idx[idx:idx + BLOCKS_PER_COPY] + didx = dst_idx[idx:idx + BLOCKS_PER_COPY] + sdata = scache[:, sidx] + dcache.index_copy_(1, didx, sdata.to(dcache.device)) + self.events.record(stream=self.cache_stream) def swap_in(self, src_to_dst: Dict[int, int]) -> None: """Move cache from Host to Device. @@ -257,7 +238,7 @@ def swap_in(self, src_to_dst: Dict[int, int]) -> None: Args: src_to_dst (Dict[int, int]): Map between src and dst. """ - self._swap(self.local_cpu_cache, self.local_gpu_cache, src_to_dst) + self._swap(self.full_cpu_cache, self.full_gpu_cache, src_to_dst) def swap_out(self, src_to_dst: Dict[int, int]) -> None: """Move cache from Device to Host. @@ -265,7 +246,7 @@ def swap_out(self, src_to_dst: Dict[int, int]) -> None: Args: src_to_dst (Dict[int, int]): Map between src and dst. """ - self._swap(self.local_gpu_cache, self.local_cpu_cache, src_to_dst) + self._swap(self.full_gpu_cache, self.full_cpu_cache, src_to_dst) @classmethod def get_cache_block_size(cls, diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index cffe13bbd..26b507e9d 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -164,6 +164,7 @@ def __init__(self, self.cache_config = cache_config self.backend_config = backend_config self.stream = self.model_agent.stream + self.max_session_len = self._get_max_session_len() self.req_manager = self._bind_request_manager() @@ -261,6 +262,20 @@ def _response(self, data=data, err_msg=err_msg)) + def _get_max_session_len(self): + """get max session len.""" + session_len = self.scheduler_config.max_session_len + max_tokens = (self.cache_config.num_gpu_blocks * + self.cache_config.block_size) + window_size = self.cache_config.window_size + if window_size > 0 and window_size <= max_tokens: + max_tokens = (1 << 63) - 1 + if session_len is None: + session_len = max_tokens + else: + session_len = min(max_tokens, session_len) + return session_len + def _on_add_session(self, reqs: Request, **kwargs): """on add session callback.""" for req in reqs: @@ -315,12 +330,11 @@ def __update_bad_words(msg): def __update_max_new_tokens(msg): """update max new tokens.""" - max_session_len = self.scheduler_config.max_session_len - if max_session_len is not None: - sampling_param = msg.sampling_param - sampling_param.max_new_tokens = min( - sampling_param.max_new_tokens, - max_session_len - msg.num_all_tokens()) + max_session_len = self.max_session_len + sampling_param = msg.sampling_param + sampling_param.max_new_tokens = min( + sampling_param.max_new_tokens, + max_session_len - msg.num_all_tokens()) for req in reqs: session_id = req.data['session_id'] diff --git a/lmdeploy/pytorch/engine/engine_instance.py b/lmdeploy/pytorch/engine/engine_instance.py index 3e741c7ba..455ab1ccb 100644 --- a/lmdeploy/pytorch/engine/engine_instance.py +++ b/lmdeploy/pytorch/engine/engine_instance.py @@ -89,21 +89,10 @@ class EngineInstance: """ def __init__(self, engine: Engine): - - def __get_max_input_len(engine): - """get max input len.""" - cache_config = engine.cache_config - max_input_len = (cache_config.block_size * - cache_config.num_gpu_blocks) - window_size = cache_config.window_size - if window_size > 0 and window_size <= max_input_len: - max_input_len = (1 << 63) - 1 - return max_input_len - self.engine = engine self.req_sender = engine.req_manager.build_sender() - self.max_input_len = __get_max_input_len(self.engine) + self.max_input_len = self.engine.max_session_len def __del__(self): """Destructor.""" diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index 74938de81..2877f5937 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -120,9 +120,7 @@ def cache_swapping(cache_engine: CacheEngine, swap_in_map: dict, issued_cache_op = True if issued_cache_op: - cache_events = cache_engine.events - for event in cache_events: - event.wait() + cache_engine.events.wait() @torch.inference_mode() From 6734c71ffc0e94323854eb6ed139dbe621e71a9d Mon Sep 17 00:00:00 2001 From: AllentDan <41138331+AllentDan@users.noreply.github.com> Date: Mon, 2 Dec 2024 19:22:01 +0800 Subject: [PATCH 11/14] Update internvl chat template (#2832) * Add internvl2-5 chat template * fix template, using original internlm2 template --- lmdeploy/model.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/lmdeploy/model.py b/lmdeploy/model.py index 47aaaa4e8..a4355ea13 100644 --- a/lmdeploy/model.py +++ b/lmdeploy/model.py @@ -597,9 +597,32 @@ def match(cls, model_path: str) -> Optional[str]: path = model_path.lower() if ('internvl2' in path and 'internvl2-4b' not in path) or 'mono-internvl' in path: + if 'internvl2.5' in path or 'internvl2_5' in path: + return None return 'internvl2-internlm2' +@MODELS.register_module(name='internvl2_5') +class InternVL2_5(InternLM2Chat7B): + + def __init__( + self, + meta_instruction='你是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。', # noqa + **kwargs): + super().__init__(meta_instruction=meta_instruction, **kwargs) + + @classmethod + def match(cls, model_path: str) -> Optional[str]: + """Return the model_name that was registered to MODELS. + + Args: + model_path (str): the model path used for matching. + """ + path = model_path.lower() + if 'internvl2.5' in path or 'internvl2_5' in path: + return 'internvl2_5' + + @MODELS.register_module(name=['internlm-xcomposer2', 'internlm-xcomposer2d5']) class InternLMXComposer2Chat7B(InternLMChat7B): """Chat template and generation parameters of InternLM-XComposer2-7b.""" From 8fbfed685f328c7fff6ec46a17dfcd0a50d2a685 Mon Sep 17 00:00:00 2001 From: q yao Date: Tue, 3 Dec 2024 11:14:35 +0800 Subject: [PATCH 12/14] profile throughput without new threads (#2826) * profile throughput without threads * optimize main loop * fix torch.event * fix python>3.11 * optimize tp * reduce cudagraph copy * optimize fill kv cache * optimize silu and mul * optimize apply rotary * remove executor * remove kernel * remove num_heads==1 --- benchmark/profile_throughput.py | 38 ++-- lmdeploy/pytorch/backends/cuda/attention.py | 5 +- lmdeploy/pytorch/engine/engine.py | 22 ++- lmdeploy/pytorch/engine/logits_process.py | 30 ++- lmdeploy/pytorch/engine/model_agent.py | 65 +------ lmdeploy/pytorch/kernels/cuda/activation.py | 105 +++------- .../kernels/cuda/apply_rotary_pos_emb.py | 43 +---- .../pytorch/kernels/cuda/fill_kv_cache.py | 182 ++++++------------ lmdeploy/pytorch/models/utils/cudagraph.py | 28 ++- tests/pytorch/engine/test_logits_process.py | 3 +- 10 files changed, 177 insertions(+), 344 deletions(-) diff --git a/benchmark/profile_throughput.py b/benchmark/profile_throughput.py index 58786d9c8..4f06fad4f 100644 --- a/benchmark/profile_throughput.py +++ b/benchmark/profile_throughput.py @@ -1,12 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved. import argparse +import asyncio import csv import json import os import random import time from queue import Queue -from threading import Thread from typing import List, Tuple, Union import numpy as np @@ -86,15 +86,15 @@ def __init__(self, model_path: str, self.csv = csv self.pbar = None - def _inference(self, req_queue: Queue, res_queue: Queue, session_id: int, - temperature: float, top_p: float, top_k: int, - stream_output: bool): + async def _inference(self, req_queue: Queue, res_queue: Queue, + session_id: int, temperature: float, top_p: float, + top_k: int, stream_output: bool): model_inst = self.tm_model.create_instance() stats = [] # get each generated token's latency per_token_latency_stats = [] for prompt, input_seqlen, output_seqlen in iter( - req_queue.get, [None, None, None]): + req_queue.get_nowait, [None, None, None]): _per_token_latency_stats = [0] * (output_seqlen + 1) prev = time.perf_counter() n_prev_token = 0 @@ -102,7 +102,7 @@ def _inference(self, req_queue: Queue, res_queue: Queue, session_id: int, input_ids = self.tokenizer(prompt).input_ids state = DetokenizeState(len(input_ids)) - for outputs in model_inst.stream_infer( + async for outputs in model_inst.async_stream_infer( session_id, input_ids=input_ids, gen_config=GenerationConfig(max_new_tokens=output_seqlen, @@ -123,7 +123,7 @@ def _inference(self, req_queue: Queue, res_queue: Queue, session_id: int, prev = now # for pytorch engine to restart a session if isinstance(model_inst, EngineInstance): - model_inst.end(session_id) + await model_inst.async_end(session_id) assert output_seqlen <= n_token <= output_seqlen + 1, \ f'Error. session_id({session_id}) request {output_seqlen} ' \ f'tokens, but generate {n_token} tokens.\n' \ @@ -139,13 +139,12 @@ def _inference(self, req_queue: Queue, res_queue: Queue, session_id: int, # skip the first token latency per_token_latency_stats.append(_per_token_latency_stats[1:]) self.pbar.update(1) - res_queue.put((session_id, stats, per_token_latency_stats)) + res_queue.put_nowait((session_id, stats, per_token_latency_stats)) def process_request(self, requests, concurrency, temperature, top_p, top_k, stream_output): res_queue = Queue() req_queue = Queue() - threads = [] self.pbar = tqdm(total=len(requests)) @@ -157,18 +156,20 @@ def process_request(self, requests, concurrency, temperature, top_p, top_k, start = time.time() + event_loop = asyncio.new_event_loop() + asyncio.set_event_loop(event_loop) + # start threads + tasks = [] for i in range(concurrency): - t = Thread(target=self._inference, - args=(req_queue, res_queue, i, temperature, top_p, - top_k, stream_output), - daemon=True) - t.start() - threads.append(t) + task = self._inference(req_queue, res_queue, i, temperature, top_p, + top_k, stream_output) + tasks.append(task) + + async def _gather_tasks(tasks): + return await asyncio.gather(*tasks) - # wait for finish - for t in threads: - t.join() + event_loop.run_until_complete(_gather_tasks(tasks)) elapsed_time = time.time() - start @@ -333,7 +334,6 @@ def main(): block_size=args.cache_block_seq_len, max_batch_size=args.concurrency, tp=args.tp, - thread_safe=True, eager_mode=args.eager_mode, enable_prefix_caching=args.enable_prefix_caching, quant_policy=args.quant_policy, diff --git a/lmdeploy/pytorch/backends/cuda/attention.py b/lmdeploy/pytorch/backends/cuda/attention.py index d01d6fe9b..8261b869f 100644 --- a/lmdeploy/pytorch/backends/cuda/attention.py +++ b/lmdeploy/pytorch/backends/cuda/attention.py @@ -94,7 +94,10 @@ def forward( kv_seqlens = attn_metadata.kv_seqlens kv_flatten_size = attn_metadata.kv_flatten_size quant_policy = attn_metadata.quant_policy - max_q_seqlen = query.numel() // (query.size(-1) * query.size(-2)) + if attn_metadata.is_decoding: + max_q_seqlen = 1 + else: + max_q_seqlen = query.numel() // (query.size(-1) * query.size(-2)) fill_max_q_seqlen = max_q_seqlen if attn_metadata.fill_seqlens is not None: fill_seqlens = attn_metadata.fill_seqlens diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index 26b507e9d..b7a803a7a 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -172,6 +172,7 @@ def __init__(self, self._start_loop() self._create_buffers() self.engine_instance = self.create_instance() + self._output_stream = torch.cuda.Stream() @classmethod def from_pretrained(cls, @@ -673,7 +674,8 @@ async def __long_context_single_forward(inputs): return ret def _make_infer_outputs(self, next_token_ids: torch.LongTensor, - logits: torch.Tensor, stopped: torch.Tensor): + logits: torch.Tensor, stopped: torch.Tensor, + event: torch.cuda.Event): """make infer output.""" def __get_out_token_ids(token: torch.Tensor, msg: SchedulerSequence, @@ -694,6 +696,11 @@ def __get_q_start_loc(): else: return seq_length.cumsum(0) - seq_length + with torch.cuda.stream(self._output_stream): + event.wait() + next_token_ids = next_token_ids.cpu() + stopped = stopped.cpu() + running = self._running is_run = [seq.status == MessageStatus.RUNNING for seq in running] stopped = stopped.tolist() @@ -755,6 +762,8 @@ def __update_inputs(next_token_ids): logger.debug(': ' f'batch_size={inputs.seq_length.size(0)} ' f'num_tokens={inputs.input_ids.size(-1)}') + if self.gpu_count == 1: + inputs = inputs.to_device('cuda') is_decoding = inputs.is_decoding if all_ids is not None: all_ids = all_ids.cuda() @@ -785,10 +794,11 @@ def __update_inputs(next_token_ids): next_token_ids, sampling_inputs.stop_words, num_appendable_ids) # send output - stopped = stopped.cpu() - finish = stopped.all().item() or (idx == loop_count - 1) + finish = (idx == loop_count - 1) finish = finish or _check_finish(self.scheduler, idx) - output = (next_token_ids.cpu(), logits, stopped) + event = torch.cuda.Event() + event.record() + output = (next_token_ids, logits, stopped, event) output_que.put_nowait((finish, output)) if finish: @@ -951,9 +961,9 @@ async def __step(): try: if isinstance(out, Exception): raise out - next_token_ids, logits, stopped = out + next_token_ids, logits, stopped, event = out step_outputs = self._make_infer_outputs( - next_token_ids, logits, stopped) + next_token_ids, logits, stopped, event) __send_resps(step_outputs) except Exception as e: raise e diff --git a/lmdeploy/pytorch/engine/logits_process.py b/lmdeploy/pytorch/engine/logits_process.py index 54740a4fb..24cb336d7 100644 --- a/lmdeploy/pytorch/engine/logits_process.py +++ b/lmdeploy/pytorch/engine/logits_process.py @@ -21,10 +21,9 @@ def _process_temperature_(scores: torch.Tensor, temperature: torch.Tensor): def _process_bad_words_(scores: torch.Tensor, bad_words: torch.LongTensor, + mask: torch.BoolTensor, filter_value: float = -float('inf')): """process bad words.""" - mask = bad_words >= 0 - bad_words = bad_words.where(mask, 0) filtered_scores = scores.gather(1, bad_words) filtered_scores[mask] = filter_value scores.scatter_(1, bad_words, filtered_scores) @@ -127,7 +126,9 @@ def _guided_sampling(response_formats: Tuple[Dict], scores: torch.Tensor, class SamplingInputs: temperature: torch.Tensor = None bad_words: torch.LongTensor = None + bad_mask: torch.BoolTensor = None stop_words: torch.LongTensor = None + stop_mask: torch.BoolTensor = None repetition_penalty: torch.Tensor = None top_k: torch.LongTensor = None top_p: torch.Tensor = None @@ -200,9 +201,11 @@ def __get_bad_words(bad_words): """get bad words.""" max_bw_len = max(len(bw) for bw in bad_words) if max_bw_len == 0: - return None + return None, None if all(len(bw) == max_bw_len for bw in bad_words): - return torch.tensor(bad_words) + ret = torch.tensor(bad_words) + mask = torch.ones_like(ret, dtype=bool) + return ret, mask ret = torch.full((batch_size, max_bw_len), -1, dtype=torch.int64) for idx, bw in enumerate(bad_words): bw_len = len(bw) @@ -210,7 +213,10 @@ def __get_bad_words(bad_words): continue bw = ret.new_tensor(bw) ret[idx, :bw_len] = bw - return ret + + mask = ret >= 0 + ret = ret.where(mask, 0) + return ret, mask __gather_params() @@ -221,8 +227,8 @@ def __get_bad_words(bad_words): temperature = torch.tensor(temperature) - bad_words = __get_bad_words(bad_words) - stop_words = __get_bad_words(stop_words) + bad_words, bad_mask = __get_bad_words(bad_words) + stop_words, stop_mask = __get_bad_words(stop_words) max_top_k = max(top_k) if min(top_k) <= 0: @@ -243,7 +249,9 @@ def __get_bad_words(bad_words): sampling_input = cls( temperature=temperature, bad_words=bad_words, + bad_mask=bad_mask, stop_words=stop_words, + stop_mask=stop_mask, repetition_penalty=repetition_penalty, top_k=top_k, top_p=top_p, @@ -326,12 +334,14 @@ def __call__(self, all_ids: torch.LongTensor, bad_words = sampling_inputs.bad_words if bad_words is not None: - scores = _process_bad_words_(scores, bad_words) + bad_mask = sampling_inputs.bad_mask + scores = _process_bad_words_(scores, bad_words, bad_mask) stop_words = sampling_inputs.stop_words if stop_words is not None: - stop_words = torch.where(self.ignore_eos[:, None], stop_words, -1) - scores = _process_bad_words_(scores, stop_words) + stop_mask = sampling_inputs.stop_mask + stop_mask = torch.where(self.ignore_eos[:, None], stop_mask, False) + scores = _process_bad_words_(scores, stop_words, stop_mask) scores = _guided_sampling(sampling_inputs.response_formats, scores, guided_input_ids, self.tokenizer) diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index 2877f5937..59d77f264 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -162,10 +162,6 @@ def __init__(self, model_config: ModelConfig, cache_config: CacheConfig): self.model_config = model_config self.cache_config = cache_config - def get_block_numel(self): - """get block nelement.""" - raise NotImplementedError('Not implemented') - async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap): """model forward. @@ -177,17 +173,6 @@ async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, """ raise NotImplementedError('Not implemented.') - def forward(self, inputs: ModelInputs, swap_in_map: SwapMap, - swap_out_map: SwapMap): - """model forward. - - Args: - inputs (Dict): The input data comes from _make_inputs. - swap_in_map (SwapMap): Cache maps to swap in. - swap_out_map (SwapMap): Cache maps to swap out. - """ - raise NotImplementedError('Not implemented.') - def get_logits(self, hidden_states: torch.Tensor): """get logits of model output.""" raise NotImplementedError('Not implemented.') @@ -255,11 +240,6 @@ def _build_model(self, device=device) return patched_model - def get_block_numel(self): - """get block nelement.""" - k_cache = self.cache_engine.local_gpu_cache[0][0] - return k_cache[0].numel() - def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap): cache_swapping(self.cache_engine, @@ -274,21 +254,6 @@ def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap, ) return output - def forward(self, inputs: ModelInputs, swap_in_map: SwapMap, - swap_out_map: SwapMap): - """model forward. - - Args: - inputs (Dict): The input data comes from _make_inputs. - swap_in_map (SwapMap): Cache maps to swap in. - swap_out_map (SwapMap): Cache maps to swap out. - """ - output = self._forward_impl(inputs, - swap_in_map=swap_in_map, - swap_out_map=swap_out_map) - self.stream.synchronize() - return output - async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap): """model forward. @@ -301,8 +266,9 @@ async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, output = self._forward_impl(inputs, swap_in_map=swap_in_map, swap_out_map=swap_out_map) - await asyncio.get_event_loop().run_in_executor(None, - self.stream.synchronize) + await asyncio.sleep(0) + while not self.stream.query(): + await asyncio.sleep(0) return output def get_logits(self, hidden_states: torch.Tensor): @@ -688,11 +654,6 @@ def _build_model( return model, cache_engine, cache_config - def get_block_numel(self): - """get block nelement.""" - k_cache = self.cache_engine.local_gpu_cache[0][0] - return k_cache[0].numel() - def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap): """forward impl.""" @@ -713,21 +674,6 @@ def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap, ) return output - def forward(self, inputs: ModelInputs, swap_in_map: SwapMap, - swap_out_map: SwapMap): - """model forward. - - Args: - inputs (Dict): The input data comes from _make_inputs. - swap_in_map (SwapMap): Cache maps to swap in. - swap_out_map (SwapMap): Cache maps to swap out. - """ - output = self._forward_impl(inputs, - swap_in_map=swap_in_map, - swap_out_map=swap_out_map) - self.stream.synchronize() - return output - async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap): """model forward. @@ -740,8 +686,9 @@ async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, output = self._forward_impl(inputs, swap_in_map=swap_in_map, swap_out_map=swap_out_map) - await asyncio.get_event_loop().run_in_executor(None, - self.stream.synchronize) + await asyncio.sleep(0) + while not self.stream.query(): + await asyncio.sleep(0) return output def get_logits(self, hidden_states: torch.Tensor): diff --git a/lmdeploy/pytorch/kernels/cuda/activation.py b/lmdeploy/pytorch/kernels/cuda/activation.py index 2533840a9..9a00e7354 100644 --- a/lmdeploy/pytorch/kernels/cuda/activation.py +++ b/lmdeploy/pytorch/kernels/cuda/activation.py @@ -7,10 +7,8 @@ TRITON_VERSION = version.parse(triton.__version__) if TRITON_VERSION >= version.parse('3.0.0'): - fast_expf = tl.math.exp else: - tanh = tl.math.tanh fast_expf = tl.math.fast_expf @@ -26,63 +24,29 @@ def _silu_and_mul_kernel( BLOCK_SIZE_N: tl.constexpr, ): """silu and mul kernel.""" - m_id = tl.program_id(0) + n_block_id = tl.program_id(0) + m_id = tl.program_id(1) up_ptr = gateup_ptr + N * stride_gun - offs_n = tl.arange(0, BLOCK_SIZE_N) + offs_n = n_block_id * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) gate_ptrs = gateup_ptr + m_id * stride_gum + offs_n * stride_gun up_ptrs = up_ptr + m_id * stride_gum + offs_n * stride_gun out_ptrs = out_ptr + m_id * stride_om + offs_n * stride_on - for _ in range(0, N, BLOCK_SIZE_N): - gate = tl.load(gate_ptrs).to(tl.float32) - up = tl.load(up_ptrs).to(tl.float32) - - gate = gate / (1 + fast_expf(-gate)) - out = gate * up - - tl.store(out_ptrs, out) - - gate_ptrs += BLOCK_SIZE_N * stride_gun - up_ptrs += BLOCK_SIZE_N * stride_gun - out_ptrs += BLOCK_SIZE_N * stride_on - - -@triton.jit -def _silu_and_mul_no_align_kernel( - gateup_ptr, - out_ptr, - N: tl.constexpr, - stride_gum: tl.constexpr, - stride_gun: tl.constexpr, - stride_om: tl.constexpr, - stride_on: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, -): - """silu and mul kernel.""" - m_id = tl.program_id(0) - - up_ptr = gateup_ptr + N * stride_gun - - offs_n = tl.arange(0, BLOCK_SIZE_N) - gate_ptrs = gateup_ptr + m_id * stride_gum + offs_n * stride_gun - up_ptrs = up_ptr + m_id * stride_gum + offs_n * stride_gun - out_ptrs = out_ptr + m_id * stride_om + offs_n * stride_on - - for n in range(0, N, BLOCK_SIZE_N): - mask = n + offs_n < N - gate = tl.load(gate_ptrs, mask=mask).to(tl.float32) - up = tl.load(up_ptrs, mask=mask).to(tl.float32) - - gate = gate / (1 + fast_expf(-gate)) - out = gate * up + if N % BLOCK_SIZE_N == 0: + mask = None + else: + mask = offs_n < N + gate = tl.load(gate_ptrs, mask=mask) + up = tl.load(up_ptrs, mask=mask) + gate = gate.to(tl.float32) + up = up.to(tl.float32) - tl.store(out_ptrs, out, mask=mask) + gate = gate / (1 + fast_expf(-gate)) + out = gate * up - gate_ptrs += BLOCK_SIZE_N * stride_gun - up_ptrs += BLOCK_SIZE_N * stride_gun - out_ptrs += BLOCK_SIZE_N * stride_on + tl.store(out_ptrs, out, mask=mask) def silu_and_mul(gate_up: torch.Tensor, out: torch.Tensor = None): @@ -96,31 +60,22 @@ def silu_and_mul(gate_up: torch.Tensor, out: torch.Tensor = None): out = gate_up.new_empty(out_shape) BLOCK_SIZE_N = triton.next_power_of_2(N) - BLOCK_SIZE_N = min(BLOCK_SIZE_N, 1024) + BLOCK_SIZE_N = min(BLOCK_SIZE_N, 512) num_warps = 4 - num_stages = 2 - grid = (M, ) - if N % BLOCK_SIZE_N == 0: - _silu_and_mul_kernel[grid](gate_up, - out, - N, - stride_gum=gate_up.stride(0), - stride_gun=gate_up.stride(1), - stride_om=out.stride(0), - stride_on=out.stride(1), - BLOCK_SIZE_N=BLOCK_SIZE_N, - num_warps=num_warps, - num_stages=num_stages) - else: - _silu_and_mul_no_align_kernel[grid](gate_up, - out, - N, - stride_gum=gate_up.stride(0), - stride_gun=gate_up.stride(1), - stride_om=out.stride(0), - stride_on=out.stride(1), - BLOCK_SIZE_N=BLOCK_SIZE_N, - num_warps=num_warps, - num_stages=num_stages) + num_stages = 1 + grid = ( + triton.cdiv(N, BLOCK_SIZE_N), + M, + ) + _silu_and_mul_kernel[grid](gate_up, + out, + N, + stride_gum=gate_up.stride(0), + stride_gun=gate_up.stride(1), + stride_om=out.stride(0), + stride_on=out.stride(1), + BLOCK_SIZE_N=BLOCK_SIZE_N, + num_warps=num_warps, + num_stages=num_stages) return out diff --git a/lmdeploy/pytorch/kernels/cuda/apply_rotary_pos_emb.py b/lmdeploy/pytorch/kernels/cuda/apply_rotary_pos_emb.py index 9e14dc6a0..f9d5f2f17 100644 --- a/lmdeploy/pytorch/kernels/cuda/apply_rotary_pos_emb.py +++ b/lmdeploy/pytorch/kernels/cuda/apply_rotary_pos_emb.py @@ -4,35 +4,9 @@ import triton.language as tl from torch import Tensor -from .triton_utils import get_kernel_meta, wrap_jit_func - - -@wrap_jit_func(type_hint=dict( - Q=Tensor, - K=Tensor, - COS=Tensor, - SIN=Tensor, - POS=Tensor, - Q_EMB=Tensor, - K_EMB=Tensor, - seq_len=int, - stride_qs=int, - stride_qh=int, - stride_qd=int, - stride_ks=int, - stride_kh=int, - stride_kd=int, - stride_qes=int, - stride_qeh=int, - stride_qed=int, - stride_kes=int, - stride_keh=int, - stride_ked=int, - half_size=torch.int32, - BLOCK=torch.int32, - BLOCK_QH=torch.int32, - BLOCK_N=torch.int32, -)) +from .triton_utils import get_kernel_meta + + @triton.jit(do_not_specialize=('seq_len', )) def apply_rotary_pos_emb_qk_kernel( Q, @@ -60,8 +34,8 @@ def apply_rotary_pos_emb_qk_kernel( BLOCK_N: tl.constexpr, ): """apply rotary on key AND query kernel.""" - seq_block_id = tl.program_id(0) - head_id = tl.program_id(1) + seq_block_id = tl.program_id(1) + head_id = tl.program_id(0) pos_offset = seq_block_id * BLOCK + tl.arange(0, BLOCK) pos_mask = pos_offset < seq_len @@ -158,10 +132,13 @@ def apply_rotary_pos_emb(q: Tensor, num_heads_q = q.size(-2) num_heads_k = k.size(-2) num_warps = 4 - num_stages = 4 + num_stages = 1 kernel_meta = get_kernel_meta(q) - grid = [triton.cdiv(seq_len, BLOCK), num_heads_q + num_heads_k] + grid = [ + num_heads_q + num_heads_k, + triton.cdiv(seq_len, BLOCK), + ] apply_rotary_pos_emb_qk_kernel[grid](q, k, cos, diff --git a/lmdeploy/pytorch/kernels/cuda/fill_kv_cache.py b/lmdeploy/pytorch/kernels/cuda/fill_kv_cache.py index 9ef614fad..93bd89f48 100644 --- a/lmdeploy/pytorch/kernels/cuda/fill_kv_cache.py +++ b/lmdeploy/pytorch/kernels/cuda/fill_kv_cache.py @@ -1,12 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. from typing import Literal -import torch import triton import triton.language as tl from torch import Tensor -from .triton_utils import get_kernel_meta, wrap_jit_func +from .triton_utils import get_kernel_meta @triton.jit @@ -38,37 +37,6 @@ def _quant_int4(val1, val2): return q_val, scales, zeros -@wrap_jit_func(type_hint=dict( - KStates=Tensor, - VStates=Tensor, - KCaches=Tensor, - VCaches=Tensor, - QStartLoc=Tensor, - QSeqLens=Tensor, - KVSeqLens=Tensor, - BlockOffsets=Tensor, - num_heads=torch.int32, - head_dim=torch.int32, - stride_kss=int, - stride_ksh=int, - stride_ksd=int, - stride_vss=int, - stride_vsh=int, - stride_vsd=int, - stride_kcn=int, - stride_kcb=int, - stride_kch=int, - stride_kcd=int, - stride_vcn=int, - stride_vcb=int, - stride_vch=int, - stride_vcd=int, - stride_boff=int, - BLOCK=torch.int32, - BLOCK_D=torch.int32, - BLOCK_DV=torch.int32, - BLOCK_H=torch.int32, -)) @triton.jit def _fill_kv_cache_kernel( KStates, @@ -79,7 +47,7 @@ def _fill_kv_cache_kernel( QSeqLens, KVSeqLens, BlockOffsets, - num_heads: tl.constexpr, + is_decoding: tl.constexpr, head_dim: tl.constexpr, head_dim_v: tl.constexpr, stride_kss, @@ -100,108 +68,70 @@ def _fill_kv_cache_kernel( BLOCK: tl.constexpr, BLOCK_D: tl.constexpr, BLOCK_DV: tl.constexpr, - BLOCK_H: tl.constexpr, ): """fill kv cache kernel.""" - batch_id = tl.program_id(0) + batch_id = tl.program_id(2) + head_id = tl.program_id(0) block_id = tl.program_id(1) - # initialize - h_off = tl.arange(0, BLOCK_H) - d_off = tl.arange(0, BLOCK_D) - q_startloc = tl.load(QStartLoc + batch_id) q_seqlen = tl.load(QSeqLens + batch_id) kv_seqlen = tl.load(KVSeqLens + batch_id) history_seqlen = kv_seqlen - q_seqlen - block0_first_tokenloc = history_seqlen % BLOCK - - state_token_offset = tl.maximum(block_id * BLOCK - block0_first_tokenloc, - 0) - kv_block_id = _div_up(history_seqlen + 1, BLOCK) - 1 + block_id - kv_block_id = min(kv_block_id, stride_boff - 1) - block_off = tl.load(BlockOffsets + batch_id * stride_boff + kv_block_id) + kv_block_id = history_seqlen // BLOCK + block_id - cur_startloc = q_startloc + state_token_offset - ks_ptr = KStates + cur_startloc * stride_kss - vs_ptr = VStates + cur_startloc * stride_vss + if kv_seqlen <= 0: + return - kc_ptr = KCaches + block_off * stride_kcn - vc_ptr = VCaches + block_off * stride_vcn + if kv_block_id * BLOCK >= kv_seqlen: + return - c_first_tokenloc = block0_first_tokenloc - if block_id != 0: - c_first_tokenloc *= 0 - c_last_tokenloc = tl.minimum( - BLOCK, q_seqlen + block0_first_tokenloc - block_id * BLOCK) + if is_decoding: + page_offs = tl.full((1, ), history_seqlen % BLOCK, dtype=tl.int32) + kv_mask = tl.full((1, ), 1, dtype=tl.int1) + q_offs = tl.full((1, ), q_startloc, dtype=tl.int32) + else: + page_offs = tl.arange(0, BLOCK) + kv_offs = kv_block_id * BLOCK + page_offs + kv_mask = (kv_offs >= history_seqlen) & (kv_offs < kv_seqlen) + token_off = q_startloc + kv_block_id * BLOCK - history_seqlen + q_offs = token_off + page_offs - for bidx in range(c_first_tokenloc, c_last_tokenloc): - sidx = bidx - c_first_tokenloc - mask = (h_off[:, None] < num_heads) & (d_off[None, :] < head_dim) - k = tl.load(ks_ptr + sidx * stride_kss + h_off[:, None] * stride_ksh + - d_off[None, :] * stride_ksd, - mask=mask) - tl.store(kc_ptr + bidx * stride_kcb + h_off[:, None] * stride_kch + - d_off[None, :] * stride_kcd, - k, - mask=mask) + block_off = tl.load(BlockOffsets + batch_id * stride_boff + kv_block_id) - if BLOCK_DV > 0: - dv_off = tl.arange(0, BLOCK_DV) - maskv = (h_off[:, None] < num_heads) & (dv_off[None, :] < - head_dim_v) - v = tl.load(vs_ptr + sidx * stride_vss + - h_off[:, None] * stride_vsh + - dv_off[None, :] * stride_vsd, - mask=maskv) - tl.store(vc_ptr + bidx * stride_vcb + h_off[:, None] * stride_vch + - dv_off[None, :] * stride_vcd, - v, - mask=maskv) + d_off = tl.arange(0, BLOCK_D) + mask_ks = kv_mask[:, None] + mask_kc = mask_ks & (d_off[None, :] < head_dim) + d_off = d_off % head_dim + + ks_ptr = KStates + head_id * stride_ksh + ks_ptrs = ks_ptr + q_offs[:, + None] * stride_kss + d_off[None, :] * stride_ksd + kc_ptr = KCaches + block_off * stride_kcn + head_id * stride_kch + kc_ptrs = kc_ptr + page_offs[:, None] * stride_kcb + d_off[ + None, :] * stride_kcd + + if BLOCK_DV > 0: + dv_off = tl.arange(0, BLOCK_DV) + mask_vs = kv_mask[:, None] + mask_vc = mask_vs & (dv_off[None, :] < head_dim_v) + dv_off = dv_off % head_dim_v + vs_ptr = VStates + head_id * stride_vsh + vs_ptrs = vs_ptr + q_offs[:, None] * stride_vss + dv_off[ + None, :] * stride_vsd + vc_ptr = VCaches + block_off * stride_vcn + head_id * stride_vch + vc_ptrs = vc_ptr + page_offs[:, None] * stride_vcb + dv_off[ + None, :] * stride_vcd + + k = tl.load(ks_ptrs, mask=mask_ks) + if BLOCK_DV > 0: + v = tl.load(vs_ptrs, mask=mask_vs) + tl.store(kc_ptrs, k, mask=mask_kc) + if BLOCK_DV > 0: + tl.store(vc_ptrs, v, mask=mask_vc) -@wrap_jit_func(type_hint=dict( - KStates=Tensor, - VStates=Tensor, - KCaches=Tensor, - VCaches=Tensor, - KScalesZeros=Tensor, - VScalesZeros=Tensor, - QStartLoc=Tensor, - QSeqLens=Tensor, - KVSeqLens=Tensor, - BlockOffsets=Tensor, - num_heads=torch.int32, - head_dim=torch.int32, - stride_kss=int, - stride_ksh=int, - stride_ksd=int, - stride_vss=int, - stride_vsh=int, - stride_vsd=int, - stride_kcn=int, - stride_kcb=int, - stride_kch=int, - stride_kcd=int, - stride_vcn=int, - stride_vcb=int, - stride_vch=int, - stride_vcd=int, - stride_kszn=int, - stride_kszb=int, - stride_kszh=int, - stride_kszd=int, - stride_vszn=int, - stride_vszb=int, - stride_vszh=int, - stride_vszd=int, - stride_boff=int, - BLOCK=torch.int32, - BLOCK_D=torch.int32, - BLOCK_DV=torch.int32, - BLOCK_H=torch.int32, -)) @triton.jit def _fill_kv_cache_quant_kernel( KStates, @@ -394,15 +324,19 @@ def fill_kv_cache(k_states: Tensor, num_heads = k_caches.size(h_dim) head_dim = k_caches.size(d_dim) head_dim_v = v_states.size(-1) - max_num_blocks = triton.cdiv(max_q_seq_length, block_size) + 1 + if max_q_seq_length == 1: + max_num_blocks = 1 + else: + max_num_blocks = triton.cdiv(max_q_seq_length, block_size) + 1 BLOCK = block_size BLOCK_H = triton.next_power_of_2(num_heads) BLOCK_D = triton.next_power_of_2(head_dim) BLOCK_DV = triton.next_power_of_2(head_dim_v) - grid = [batch_size, max_num_blocks] kernel_meta = get_kernel_meta(k_states) if quant_policy == 0: + grid = [num_heads, max_num_blocks, batch_size] + is_decoding = max_num_blocks == 1 _fill_kv_cache_kernel[grid]( k_states, v_states, @@ -412,7 +346,7 @@ def fill_kv_cache(k_states: Tensor, q_seq_length, kv_seq_length, block_offsets, - num_heads=num_heads, + is_decoding=is_decoding, head_dim=head_dim, head_dim_v=head_dim_v, stride_kss=k_states.stride(-3), @@ -433,12 +367,12 @@ def fill_kv_cache(k_states: Tensor, BLOCK=BLOCK, BLOCK_D=BLOCK_D, BLOCK_DV=BLOCK_DV, - BLOCK_H=BLOCK_H, num_warps=4, num_stages=3, **kernel_meta, ) else: + grid = [batch_size, max_num_blocks] _fill_kv_cache_quant_kernel[grid]( k_states, v_states, diff --git a/lmdeploy/pytorch/models/utils/cudagraph.py b/lmdeploy/pytorch/models/utils/cudagraph.py index 149376e4b..74d090a9a 100644 --- a/lmdeploy/pytorch/models/utils/cudagraph.py +++ b/lmdeploy/pytorch/models/utils/cudagraph.py @@ -70,15 +70,14 @@ def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, *args, input_buffers['block_offsets'] = torch.zeros((max_batches, num_blocks), dtype=torch.int64, device=device) - input_buffers['q_start_loc'] = torch.zeros(max_batches, - dtype=torch.int64, - device=device) - input_buffers['q_seqlens'] = torch.zeros(max_batches, - dtype=torch.int64, - device=device) - input_buffers['kv_seqlens'] = torch.zeros(max_batches, - dtype=torch.int64, - device=device) + + input_buffers['qkv_lens'] = torch.zeros(3, + max_batches, + dtype=torch.int64, + device=device) + input_buffers['q_start_loc'] = input_buffers['qkv_lens'][0] + input_buffers['q_seqlens'] = input_buffers['qkv_lens'][1] + input_buffers['kv_seqlens'] = input_buffers['qkv_lens'][2] input_buffers['local_adapter_ids'] = torch.zeros(max_batches, dtype=torch.int64, device=device) @@ -111,13 +110,10 @@ def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, input_buffers['position_ids'][:, :num_tokens] = position_ids input_buffers[ 'block_offsets'][:batch_size, :num_blocks] = block_offsets - if q_seqlens.data_ptr() != input_buffers['q_seqlens'].data_ptr(): - input_buffers['q_seqlens'].zero_() - input_buffers['q_seqlens'][:batch_size] = q_seqlens - if kv_seqlens.data_ptr() != input_buffers['kv_seqlens'].data_ptr(): - input_buffers['kv_seqlens'].zero_() - input_buffers['kv_seqlens'][:batch_size] = kv_seqlens - input_buffers['q_start_loc'][:batch_size] = q_start_loc + + qkv = torch.stack((q_start_loc, q_seqlens, kv_seqlens)) + input_buffers['qkv_lens'].zero_() + input_buffers['qkv_lens'][:, :batch_size] = qkv if inputs_embeds is not None: emb_size = inputs_embeds.size(-1) if 'inputs_embeds' not in input_buffers: diff --git a/tests/pytorch/engine/test_logits_process.py b/tests/pytorch/engine/test_logits_process.py index 5c5fdbdc1..69c831541 100644 --- a/tests/pytorch/engine/test_logits_process.py +++ b/tests/pytorch/engine/test_logits_process.py @@ -35,8 +35,9 @@ def test_process_bad_words(): [4, 4], [-1, -1], ]) + mask = bad_words >= 0 - out_scores = _process_bad_words_(scores, bad_words) + out_scores = _process_bad_words_(scores, bad_words.where(mask, 0), mask) for score, bw in zip(out_scores, bad_words): bw = bw.tolist() From 0dedd73e5727776e2392b6f7256e0f66d0c48c8b Mon Sep 17 00:00:00 2001 From: q yao Date: Tue, 3 Dec 2024 14:44:22 +0800 Subject: [PATCH 13/14] fix the logic to verify whether AutoAWQ has been successfully installed (#2844) --- lmdeploy/pytorch/backends/cuda/awq_modules.py | 2 -- lmdeploy/pytorch/backends/cuda/op_backend.py | 6 +++++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/lmdeploy/pytorch/backends/cuda/awq_modules.py b/lmdeploy/pytorch/backends/cuda/awq_modules.py index f3cbf8bee..8159bbf55 100644 --- a/lmdeploy/pytorch/backends/cuda/awq_modules.py +++ b/lmdeploy/pytorch/backends/cuda/awq_modules.py @@ -53,8 +53,6 @@ class AwqLinearW4A16Impl(LinearW4A16Impl): def __init__(self, in_features: int, out_features: int, w_bit: int, group_size: int): - from awq.modules.linear.gemm import AWQ_INSTALLED - assert AWQ_INSTALLED self.in_features = in_features self.out_features = out_features self.w_bit = w_bit diff --git a/lmdeploy/pytorch/backends/cuda/op_backend.py b/lmdeploy/pytorch/backends/cuda/op_backend.py index 3e7fc2372..d796f8e19 100644 --- a/lmdeploy/pytorch/backends/cuda/op_backend.py +++ b/lmdeploy/pytorch/backends/cuda/op_backend.py @@ -48,7 +48,11 @@ def get_layer_impl_builder(cls, layer_type: OpType): from .activation import TritonSiluAndMulBuilder return TritonSiluAndMulBuilder elif layer_type == OpType.LinearW4A16: - from awq.modules.linear.gemm import AWQ_INSTALLED + try: + from awq.modules.linear.gemm import awq_ext # noqa: F401 + AWQ_INSTALLED = True + except Exception: + AWQ_INSTALLED = False if AWQ_INSTALLED: from .awq_modules import AwqLinearW4A16Builder return AwqLinearW4A16Builder From efa8ac032005091a17f6c4555917d400c44486ba Mon Sep 17 00:00:00 2001 From: Lyu Han Date: Tue, 3 Dec 2024 14:45:57 +0800 Subject: [PATCH 14/14] check whether backend_config is None or not before accessing its attr (#2848) --- lmdeploy/api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lmdeploy/api.py b/lmdeploy/api.py index 2b4204a53..42b7c6e4c 100644 --- a/lmdeploy/api.py +++ b/lmdeploy/api.py @@ -71,7 +71,7 @@ def pipeline(model_path: str, task, pipeline_class = get_task(model_path) if task == 'vlm': - if backend_config.enable_prefix_caching: + if backend_config and backend_config.enable_prefix_caching: backend_config.enable_prefix_caching = False logger.warning('VLM does not support prefix caching.')