Merge remote-tracking branch 'origin/main' into rope
irexyc committed Dec 9, 2024
2 parents 832ade3 + 4f7e50b commit 2beb46f
Showing 49 changed files with 704 additions and 5,942 deletions.
3 changes: 3 additions & 0 deletions README.md
@@ -125,6 +125,8 @@ For detailed inference benchmarks in more devices and more settings, please refe
<li>Qwen1.5 (0.5B - 110B)</li>
<li>Qwen1.5 - MoE (0.5B - 72B)</li>
<li>Qwen2 (0.5B - 72B)</li>
<li>Qwen2-MoE (57BA14B)</li>
<li>Qwen2.5 (0.5B - 32B)</li>
<li>Baichuan (7B)</li>
<li>Baichuan2 (7B-13B)</li>
<li>Code Llama (7B - 34B)</li>
@@ -136,6 +138,7 @@ For detailed inference benchmarks in more devices and more settings, please refe
<li>Mistral (7B)</li>
<li>DeepSeek-MoE (16B)</li>
<li>DeepSeek-V2 (16B, 236B)</li>
<li>DeepSeek-V2.5 (236B)</li>
<li>Mixtral (8x7B, 8x22B)</li>
<li>Gemma (2B - 7B)</li>
<li>Dbrx (132B)</li>
3 changes: 3 additions & 0 deletions README_ja.md
@@ -122,6 +122,8 @@ LMDeploy TurboMindエンジンは卓越した推論能力を持ち、さまざ
<li>Qwen1.5 (0.5B - 110B)</li>
<li>Qwen1.5 - MoE (0.5B - 72B)</li>
<li>Qwen2 (0.5B - 72B)</li>
<li>Qwen2-MoE (57BA14B)</li>
<li>Qwen2.5 (0.5B - 32B)</li>
<li>Baichuan (7B)</li>
<li>Baichuan2 (7B-13B)</li>
<li>Code Llama (7B - 34B)</li>
@@ -133,6 +135,7 @@ LMDeploy TurboMindエンジンは卓越した推論能力を持ち、さまざ
<li>Mistral (7B)</li>
<li>DeepSeek-MoE (16B)</li>
<li>DeepSeek-V2 (16B, 236B)</li>
<li>DeepSeek-V2.5 (236B)</li>
<li>Mixtral (8x7B, 8x22B)</li>
<li>Gemma (2B - 7B)</li>
<li>Dbrx (132B)</li>
3 changes: 3 additions & 0 deletions README_zh-CN.md
@@ -126,6 +126,8 @@ LMDeploy TurboMind 引擎拥有卓越的推理能力,在各种规模的模型
<li>Qwen1.5 (0.5B - 110B)</li>
<li>Qwen1.5 - MoE (0.5B - 72B)</li>
<li>Qwen2 (0.5B - 72B)</li>
<li>Qwen2-MoE (57BA14B)</li>
<li>Qwen2.5 (0.5B - 32B)</li>
<li>Baichuan (7B)</li>
<li>Baichuan2 (7B-13B)</li>
<li>Code Llama (7B - 34B)</li>
@@ -137,6 +139,7 @@ LMDeploy TurboMind 引擎拥有卓越的推理能力,在各种规模的模型
<li>Mistral (7B)</li>
<li>DeepSeek-MoE (16B)</li>
<li>DeepSeek-V2 (16B, 236B)</li>
<li>DeepSeek-V2.5 (236B)</li>
<li>Mixtral (8x7B, 8x22B)</li>
<li>Gemma (2B - 7B)</li>
<li>Dbrx (132B)</li>
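All three README variants gain the same new entries (Qwen2-MoE, Qwen2.5 and DeepSeek-V2.5). For reference, a newly listed model can be exercised with the standard pipeline API; the following is a minimal sketch and the model id is purely illustrative.

```python
from lmdeploy import pipeline

# Any of the newly listed models can be loaded the same way; Qwen2.5 is used
# here only as an example model id.
pipe = pipeline('Qwen/Qwen2.5-7B-Instruct')
print(pipe(['What is the capital of France?']))
```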
38 changes: 19 additions & 19 deletions benchmark/profile_throughput.py
@@ -1,12 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import asyncio
import csv
import json
import os
import random
import time
from queue import Queue
from threading import Thread
from typing import List, Tuple, Union

import numpy as np
@@ -86,23 +86,23 @@ def __init__(self, model_path: str,
self.csv = csv
self.pbar = None

def _inference(self, req_queue: Queue, res_queue: Queue, session_id: int,
temperature: float, top_p: float, top_k: int,
stream_output: bool):
async def _inference(self, req_queue: Queue, res_queue: Queue,
session_id: int, temperature: float, top_p: float,
top_k: int, stream_output: bool):
model_inst = self.tm_model.create_instance()
stats = []
# get each generated token's latency
per_token_latency_stats = []
for prompt, input_seqlen, output_seqlen in iter(
req_queue.get, [None, None, None]):
req_queue.get_nowait, [None, None, None]):
_per_token_latency_stats = [0] * (output_seqlen + 1)
prev = time.perf_counter()
n_prev_token = 0

input_ids = self.tokenizer(prompt).input_ids
state = DetokenizeState(len(input_ids))

for outputs in model_inst.stream_infer(
async for outputs in model_inst.async_stream_infer(
session_id,
input_ids=input_ids,
gen_config=GenerationConfig(max_new_tokens=output_seqlen,
@@ -123,7 +123,7 @@ def _inference(self, req_queue: Queue, res_queue: Queue, session_id: int,
prev = now
# for pytorch engine to restart a session
if isinstance(model_inst, EngineInstance):
model_inst.end(session_id)
await model_inst.async_end(session_id)
assert output_seqlen <= n_token <= output_seqlen + 1, \
f'Error. session_id({session_id}) request {output_seqlen} ' \
f'tokens, but generate {n_token} tokens.\n' \
@@ -139,13 +139,12 @@ def _inference(self, req_queue: Queue, res_queue: Queue, session_id: int,
# skip the first token latency
per_token_latency_stats.append(_per_token_latency_stats[1:])
self.pbar.update(1)
res_queue.put((session_id, stats, per_token_latency_stats))
res_queue.put_nowait((session_id, stats, per_token_latency_stats))

def process_request(self, requests, concurrency, temperature, top_p, top_k,
stream_output):
res_queue = Queue()
req_queue = Queue()
threads = []

self.pbar = tqdm(total=len(requests))

@@ -157,18 +156,20 @@ def process_request(self, requests, concurrency, temperature, top_p, top_k,

start = time.time()

event_loop = asyncio.new_event_loop()
asyncio.set_event_loop(event_loop)

# start threads
tasks = []
for i in range(concurrency):
t = Thread(target=self._inference,
args=(req_queue, res_queue, i, temperature, top_p,
top_k, stream_output),
daemon=True)
t.start()
threads.append(t)
task = self._inference(req_queue, res_queue, i, temperature, top_p,
top_k, stream_output)
tasks.append(task)

async def _gather_tasks(tasks):
return await asyncio.gather(*tasks)

# wait for finish
for t in threads:
t.join()
event_loop.run_until_complete(_gather_tasks(tasks))

elapsed_time = time.time() - start

@@ -333,7 +334,6 @@ def main():
block_size=args.cache_block_seq_len,
max_batch_size=args.concurrency,
tp=args.tp,
thread_safe=True,
eager_mode=args.eager_mode,
enable_prefix_caching=args.enable_prefix_caching,
quant_policy=args.quant_policy,
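The diff above moves the benchmark from one thread per concurrent session to one coroutine per session, all driven by a single event loop, which is also why `thread_safe=True` is dropped from the engine config. Below is a minimal, self-contained sketch of the same pattern; the `asyncio.sleep` call stands in for `model_inst.async_stream_infer(...)`, and the request tuples and helper names are illustrative rather than LMDeploy APIs.

```python
import asyncio
import time
from typing import List, Tuple


async def worker(req_queue: asyncio.Queue, res_queue: asyncio.Queue,
                 session_id: int) -> None:
    """Drain pre-filled requests until the [None, None, None] sentinel,
    mirroring the structure of the rewritten _inference coroutine."""
    stats = []
    for prompt, input_seqlen, output_seqlen in iter(
            req_queue.get_nowait, [None, None, None]):
        start = time.perf_counter()
        # Placeholder for `async for ... in model_inst.async_stream_infer(...)`.
        await asyncio.sleep(0.01)
        stats.append((input_seqlen, output_seqlen,
                      time.perf_counter() - start))
    res_queue.put_nowait((session_id, stats))


def process_requests(requests: List[Tuple[str, int, int]], concurrency: int):
    req_queue: asyncio.Queue = asyncio.Queue()
    res_queue: asyncio.Queue = asyncio.Queue()
    for req in requests:
        req_queue.put_nowait(req)
    for _ in range(concurrency):
        req_queue.put_nowait([None, None, None])  # one sentinel per worker

    # One event loop drives all "sessions" as coroutines, replacing threads.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    tasks = [worker(req_queue, res_queue, i) for i in range(concurrency)]

    async def _gather(tasks):
        return await asyncio.gather(*tasks)

    loop.run_until_complete(_gather(tasks))
    return [res_queue.get_nowait() for _ in range(concurrency)]


if __name__ == '__main__':
    dummy = [('hello', 8, 16)] * 8
    for session_id, stats in process_requests(dummy, concurrency=4):
        print(session_id, stats)
```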
2 changes: 1 addition & 1 deletion docker/Dockerfile
@@ -13,7 +13,7 @@ ARG PYTHON_VERSION=3.10
ARG TORCH_VERSION=2.3.0
ARG TORCHVISION_VERSION=0.18.0

RUN apt-get update -y && apt-get install -y software-properties-common wget vim git curl &&\
RUN apt-get update -y && apt-get install -y software-properties-common wget vim git curl openssh-server ssh sudo &&\
curl https://sh.rustup.rs -sSf | sh -s -- -y &&\
add-apt-repository ppa:deadsnakes/ppa -y && apt-get update -y && apt-get install -y --no-install-recommends \
ninja-build rapidjson-dev libgoogle-glog-dev gdb python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
6 changes: 6 additions & 0 deletions docs/en/get_started/ascend/get_started.md
@@ -136,3 +136,9 @@ lmdeploy lite auto_awq $HF_MODEL --work-dir $WORK_DIR --device npu
```

Please check [supported_models](../../supported_models/supported_models.md) before using this feature.

### int8 KV-cache Quantization

The Ascend backend supports offline int8 KV-cache quantization in eager mode.

Please refer to this [doc](https://github.com/DeepLink-org/dlinfer/blob/main/docs/quant/ascend_kv_quant.md) for details.
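For orientation, the following is a minimal sketch of configuring the Ascend backend from Python. `device_type='ascend'` and `eager_mode=True` come from this guide; treating `quant_policy=8` in `PytorchEngineConfig` as the switch that activates the int8 KV cache is an assumption, and the offline calibration workflow in the linked doc remains the authoritative reference. The model id is illustrative.

```python
from lmdeploy import pipeline, PytorchEngineConfig

# Assumed configuration: device_type/eager_mode follow this guide, while
# quant_policy=8 (LMDeploy's generic int8 KV-cache setting) is presumed to be
# the relevant switch after the offline calibration step in the dlinfer doc.
pipe = pipeline('internlm/internlm2_5-7b-chat',
                backend_config=PytorchEngineConfig(device_type='ascend',
                                                   eager_mode=True,
                                                   quant_policy=8))
print(pipe(['Hello from the Ascend backend!']))
```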
59 changes: 33 additions & 26 deletions docs/en/supported_models/supported_models.md
@@ -10,17 +10,21 @@ The following tables detail the models supported by LMDeploy's TurboMind engine
| Llama2 | 7B - 70B | LLM | Yes | Yes | Yes | Yes |
| Llama3 | 8B, 70B | LLM | Yes | Yes | Yes | Yes |
| Llama3.1 | 8B, 70B | LLM | Yes | Yes | Yes | Yes |
| Llama3.2 | 1B, 3B | LLM | Yes | Yes | Yes | Yes |
| Llama3.2 | 1B, 3B | LLM | Yes | Yes\* | Yes\* | Yes |
| InternLM | 7B - 20B | LLM | Yes | Yes | Yes | Yes |
| InternLM2 | 7B - 20B | LLM | Yes | Yes | Yes | Yes |
| InternLM2.5 | 7B | LLM | Yes | Yes | Yes | Yes |
| InternLM-XComposer2 | 7B, 4khd-7B | MLLM | Yes | Yes | Yes | Yes |
| InternLM-XComposer2.5 | 7B | MLLM | Yes | Yes | Yes | Yes |
| Qwen | 1.8B - 72B | LLM | Yes | Yes | Yes | Yes |
| Qwen1.5 | 1.8B - 110B | LLM | Yes | Yes | Yes | Yes |
| Qwen2 | 0.5B - 72B | LLM | Yes | Yes | Yes | Yes |
| Qwen2 | 0.5B - 72B | LLM | Yes | Yes\* | Yes\* | Yes |
| Qwen2-MoE | 57BA14B | LLM | Yes | Yes | Yes | Yes |
| Qwen2.5 | 0.5B - 72B | LLM | Yes | Yes | Yes | Yes |
| Mistral | 7B | LLM | Yes | Yes | Yes | No |
| Mixtral | 8x7B, 8x22B | LLM | Yes | Yes | Yes | Yes |
| DeepSeek-V2 | 16B, 236B | LLM | Yes | Yes | Yes | No |
| DeepSeek-V2.5 | 236B | LLM | Yes | Yes | Yes | No |
| Qwen-VL | 7B | MLLM | Yes | Yes | Yes | Yes |
| DeepSeek-VL | 7B | MLLM | Yes | Yes | Yes | Yes |
| Baichuan | 7B | LLM | Yes | Yes | Yes | Yes |
@@ -29,7 +33,7 @@ The following tables detail the models supported by LMDeploy's TurboMind engine
| YI | 6B - 34B | LLM | Yes | Yes | Yes | Yes |
| LLaVA(1.5,1.6) | 7B - 34B | MLLM | Yes | Yes | Yes | Yes |
| InternVL | v1.1 - v1.5 | MLLM | Yes | Yes | Yes | Yes |
| InternVL2 | 1-2B, 8B - 76B | MLLM | Yes | Yes | Yes | Yes |
| InternVL2 | 1-2B, 8B - 76B | MLLM | Yes | Yes\* | Yes\* | Yes |
| ChemVLM | 8B - 26B | MLLM | Yes | Yes | Yes | Yes |
| MiniCPM-Llama3-V-2_5 | - | MLLM | Yes | Yes | Yes | Yes |
| MiniCPM-V-2_6 | - | MLLM | Yes | Yes | Yes | Yes |
@@ -41,7 +45,8 @@ The following tables detail the models supported by LMDeploy's TurboMind engine
"-" means not verified yet.

```{note}
The TurboMind engine doesn't support window attention. Therefore, for models that have applied window attention and have the corresponding switch "use_sliding_window" enabled, such as Mistral, Qwen1.5 and etc., please choose the PyTorch engine for inference.
* The TurboMind engine doesn't support window attention. For models that use window attention and have the corresponding switch "use_sliding_window" enabled, such as Mistral and Qwen1.5, please choose the PyTorch engine for inference (a minimal sketch follows this note).
* When a model's head_dim is not 128, as with llama3.2-1B, qwen2-0.5B and internvl2-1B, TurboMind doesn't support 4-bit/8-bit KV-cache quantization or inference for it.
```
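The following is a minimal sketch of forcing the PyTorch engine for one of the models covered by the first note, assuming the standard `lmdeploy.pipeline` API; the model id and session length are illustrative choices, not prescribed values.

```python
from lmdeploy import pipeline, PytorchEngineConfig

# Mistral-style models rely on sliding-window attention, which TurboMind does
# not support, so the PyTorch engine is selected explicitly here.
pipe = pipeline('mistralai/Mistral-7B-Instruct-v0.3',
                backend_config=PytorchEngineConfig(session_len=8192))
print(pipe(['Summarize sliding-window attention in one sentence.']))
```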

## PyTorchEngine on CUDA Platform
@@ -51,47 +56,49 @@ The TurboMind engine doesn't support window attention. Therefore, for models tha
| Llama | 7B - 65B | LLM | Yes | Yes | Yes | Yes | Yes |
| Llama2 | 7B - 70B | LLM | Yes | Yes | Yes | Yes | Yes |
| Llama3 | 8B, 70B | LLM | Yes | Yes | Yes | Yes | Yes |
| Llama3.1 | 8B, 70B | LLM | Yes | Yes | Yes | No | - |
| Llama3.2 | 1B, 3B | LLM | Yes | Yes | Yes | No | - |
| Llama3.2-VL | 11B, 90B | MLLM | Yes | Yes | Yes | No | - |
| InternLM | 7B - 20B | LLM | Yes | Yes | Yes | Yes | - |
| Llama3.1 | 8B, 70B | LLM | Yes | Yes | Yes | Yes | Yes |
| Llama3.2 | 1B, 3B | LLM | Yes | Yes | Yes | Yes | Yes |
| Llama3.2-VL | 11B, 90B | MLLM | Yes | Yes | Yes | - | - |
| InternLM | 7B - 20B | LLM | Yes | Yes | Yes | Yes | Yes |
| InternLM2 | 7B - 20B | LLM | Yes | Yes | Yes | Yes | Yes |
| InternLM2.5 | 7B | LLM | Yes | Yes | Yes | Yes | Yes |
| Baichuan2 | 7B | LLM | Yes | Yes | Yes | Yes | No |
| Baichuan2 | 13B | LLM | Yes | Yes | Yes | No | No |
| ChatGLM2 | 6B | LLM | Yes | Yes | Yes | No | No |
| Falcon | 7B - 180B | LLM | Yes | Yes | Yes | No | No |
| YI | 6B - 34B | LLM | Yes | Yes | Yes | No | Yes |
| Mistral | 7B | LLM | Yes | Yes | Yes | No | No |
| YI | 6B - 34B | LLM | Yes | Yes | Yes | Yes | Yes |
| Mistral | 7B | LLM | Yes | Yes | Yes | Yes | Yes |
| Mixtral | 8x7B, 8x22B | LLM | Yes | Yes | Yes | No | No |
| QWen | 1.8B - 72B | LLM | Yes | Yes | Yes | No | Yes |
| QWen1.5 | 0.5B - 110B | LLM | Yes | Yes | Yes | No | Yes |
| QWen | 1.8B - 72B | LLM | Yes | Yes | Yes | Yes | Yes |
| QWen1.5 | 0.5B - 110B | LLM | Yes | Yes | Yes | Yes | Yes |
| QWen1.5-MoE | A2.7B | LLM | Yes | Yes | Yes | No | No |
| QWen2 | 0.5B - 72B | LLM | Yes | Yes | No | No | Yes |
| QWen2 | 0.5B - 72B | LLM | Yes | Yes | No | Yes | Yes |
| Qwen2.5 | 0.5B - 72B | LLM | Yes | Yes | No | Yes | Yes |
| QWen2-VL | 2B, 7B | MLLM | Yes | Yes | No | No | No |
| DeepSeek-MoE | 16B | LLM | Yes | No | No | No | No |
| DeepSeek-V2 | 16B, 236B | LLM | Yes | No | No | No | No |
| DeepSeek-V2.5 | 236B | LLM | Yes | No | No | No | No |
| MiniCPM3 | 4B | LLM | Yes | Yes | Yes | No | No |
| MiniCPM-V-2_6 | 8B | LLM | Yes | No | No | No | Yes |
| Gemma | 2B-7B | LLM | Yes | Yes | Yes | No | No |
| Dbrx | 132B | LLM | Yes | Yes | Yes | No | No |
| StarCoder2 | 3B-15B | LLM | Yes | Yes | Yes | No | No |
| Phi-3-mini | 3.8B | LLM | Yes | Yes | Yes | No | Yes |
| Phi-3-vision | 4.2B | MLLM | Yes | Yes | Yes | No | - |
| CogVLM-Chat | 17B | MLLM | Yes | Yes | Yes | No | - |
| CogVLM2-Chat | 19B | MLLM | Yes | Yes | Yes | No | - |
| LLaVA(1.5,1.6) | 7B-34B | MLLM | Yes | Yes | Yes | No | - |
| Phi-3-mini | 3.8B | LLM | Yes | Yes | Yes | Yes | Yes |
| Phi-3-vision | 4.2B | MLLM | Yes | Yes | Yes | - | - |
| CogVLM-Chat | 17B | MLLM | Yes | Yes | Yes | - | - |
| CogVLM2-Chat | 19B | MLLM | Yes | Yes | Yes | - | - |
| LLaVA(1.5,1.6) | 7B-34B | MLLM | Yes | Yes | Yes | - | - |
| InternVL(v1.5) | 2B-26B | MLLM | Yes | Yes | Yes | No | Yes |
| InternVL2 | 1B-40B | MLLM | Yes | Yes | Yes | No | - |
| Mono-InternVL | 2B | MLLM | Yes\* | Yes | Yes | No | - |
| ChemVLM | 8B-26B | MLLM | Yes | Yes | No | No | - |
| Gemma2 | 9B-27B | LLM | Yes | Yes | Yes | No | - |
| InternVL2 | 1B-40B | MLLM | Yes | Yes | Yes | - | - |
| Mono-InternVL | 2B | MLLM | Yes\* | Yes | Yes | - | - |
| ChemVLM | 8B-26B | MLLM | Yes | Yes | No | - | - |
| Gemma2 | 9B-27B | LLM | Yes | Yes | Yes | - | - |
| GLM4 | 9B | LLM | Yes | Yes | Yes | No | No |
| GLM-4V | 9B | MLLM | Yes | Yes | Yes | No | No |
| CodeGeeX4 | 9B | LLM | Yes | Yes | Yes | No | - |
| Phi-3.5-mini | 3.8B | LLM | Yes | Yes | No | No | - |
| Phi-3.5-MoE | 16x3.8B | LLM | Yes | Yes | No | No | - |
| Phi-3.5-vision | 4.2B | MLLM | Yes | Yes | No | No | - |
| CodeGeeX4 | 9B | LLM | Yes | Yes | Yes | - | - |
| Phi-3.5-mini | 3.8B | LLM | Yes | Yes | No | - | - |
| Phi-3.5-MoE | 16x3.8B | LLM | Yes | Yes | No | - | - |
| Phi-3.5-vision | 4.2B | MLLM | Yes | Yes | No | - | - |

```{note}
* Currently Mono-InternVL does not support FP16 due to numerical instability. Please use BF16 instead.
6 changes: 6 additions & 0 deletions docs/zh_cn/get_started/ascend/get_started.md
@@ -133,3 +133,9 @@ lmdeploy lite auto_awq $HF_MODEL --work-dir $WORK_DIR --device npu
```

Please refer to [supported models](../../supported_models/supported_models.md) for the list of supported models.

### int8 KV-cache Quantization

The Ascend backend now supports offline int8 KV-cache quantization in eager mode.

Please refer to this [doc](https://github.com/DeepLink-org/dlinfer/blob/main/docs/quant/ascend_kv_quant.md) for detailed usage.