From f0dabeeaed98109e561212ee837f68f8b4e9f8e2 Mon Sep 17 00:00:00 2001 From: Lyu Han Date: Sat, 2 Mar 2024 20:34:48 +0800 Subject: [PATCH 01/17] Fix `None` session_len (#1230) --- lmdeploy/serve/async_engine.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/lmdeploy/serve/async_engine.py b/lmdeploy/serve/async_engine.py index 53402b133c..84b588c531 100644 --- a/lmdeploy/serve/async_engine.py +++ b/lmdeploy/serve/async_engine.py @@ -59,7 +59,7 @@ def _config_model_name(config): raise ArgumentError(None, f'Please set model_name for {model_path}') else: - logger.warning(f'Best matched chat template name: {model_name}') + logger.info(f'matched chat template name: {model_name}') return model_name @@ -111,9 +111,10 @@ def __init__(self, chat_template_config: Optional[ChatTemplateConfig] = None, tp: int = 1, **kwargs) -> None: - logger.info(f'AsyncEngine init with backend={backend}, backend_config' - f'={backend_config}, chat_template_config=' - f'{chat_template_config}') + logger.info( + f'input backend={backend}, backend_config={backend_config}') + logger.info(f'input chat_template_config={chat_template_config}') + self.model_name = deduce_a_name(model_path, model_name, backend_config, chat_template_config) # build chat template config @@ -122,6 +123,7 @@ def __init__(self, elif chat_template_config.model_name is None: chat_template_config.model_name = self.model_name self.chat_template = chat_template_config.chat_template + # prevent bc for k in list(kwargs.keys()): if hasattr(chat_template_config, k): @@ -129,26 +131,26 @@ def __init__(self, 'chat_template_config instead') v = kwargs.pop(k) setattr(chat_template_config, k, v) + logger.info(f'updated chat_template_onfig={chat_template_config}') # build backend engine if backend == 'turbomind': - logger.info('Running turbomind engine for pipeline.') self._build_turbomind(model_path=model_path, backend_config=backend_config, chat_template_config=chat_template_config, tp=tp, **kwargs) elif backend == 'pytorch': - logger.info('Running pytorch engine for pipeline.') self._build_pytorch(model_path=model_path, backend_config=backend_config, **kwargs) else: raise ValueError(f'unsupported backend {backend}') + logger.info(f'updated backend_config={self.backend_config}') + # parameters for member functions - self.session_len = backend_config.session_len - self.backend_config = backend_config + self.session_len = self.backend_config.session_len self.stop_words = _stop_words(self.chat_template.stop_words, self.engine.tokenizer) if self.stop_words is not None: @@ -187,6 +189,7 @@ def _build_turbomind( engine_config=backend_config, chat_template_config=chat_template_config, **kwargs) + self.backend_config = backend_config def _build_pytorch( self, @@ -205,6 +208,7 @@ def _build_pytorch( backend_config.session_len = self.chat_template.session_len self.engine = Engine(model_path=model_path, engine_config=backend_config) + self.backend_config = backend_config def __call__(self, prompts: Union[List[str], str, List[Dict], List[List[Dict]]], From 79ac87b9fae4d2696d875444e4ed0b615309b8c0 Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Sun, 3 Mar 2024 21:07:11 +0800 Subject: [PATCH 02/17] fix multinomial sampling (#1228) * fix * fix repe penal --------- Co-authored-by: grimoire --- lmdeploy/pytorch/engine/logits_process.py | 2 ++ lmdeploy/pytorch/kernels/multinomial_sampling.py | 6 +++--- tests/pytorch/kernel/test_multinomial_sampling.py | 9 ++++++++- 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/lmdeploy/pytorch/engine/logits_process.py b/lmdeploy/pytorch/engine/logits_process.py index 4493dc432f..bb2895bb5e 100644 --- a/lmdeploy/pytorch/engine/logits_process.py +++ b/lmdeploy/pytorch/engine/logits_process.py @@ -12,6 +12,7 @@ def _process_temperature(scores: torch.Tensor, temperature: torch.Tensor, inplace: bool = True): """process temperature.""" + temperature = temperature.to(scores.dtype) if not inplace: scores = scores / temperature[:, None] else: @@ -42,6 +43,7 @@ def _process_repetition_penalty(scores: torch.Tensor, inplace: bool = True): """process repetition penalty.""" score = torch.gather(scores, 1, input_ids) + penalty = penalty.to(score.dtype) score = torch.where(score < 0, score * penalty[:, None], score / penalty[:, None]) if not inplace: diff --git a/lmdeploy/pytorch/kernels/multinomial_sampling.py b/lmdeploy/pytorch/kernels/multinomial_sampling.py index 9ad29c773b..476787fb8c 100644 --- a/lmdeploy/pytorch/kernels/multinomial_sampling.py +++ b/lmdeploy/pytorch/kernels/multinomial_sampling.py @@ -22,7 +22,7 @@ def _multinomial_sampling_kernel(Scores, Seeds, Offsets, Indices, Outputs, samp = tl.rand(seed, offset)[:, None] acc = tl.zeros((BLOCK, ), dtype=Scores.dtype.element_ty) - output = tl.full((BLOCK, ), -1, dtype=tl.int64) + output = tl.full((BLOCK, ), -1, dtype=Outputs.dtype.element_ty) for b_idx in range(0, num_tokens, BLOCK_N): s_off = b_idx + n_off @@ -31,8 +31,8 @@ def _multinomial_sampling_kernel(Scores, Seeds, Offsets, Indices, Outputs, s_off[None, :] * stride_st, mask=s_mask, other=0.0) - cum_scores = acc[:, None] + tl.cumsum(scores, 1) - acc += tl.sum(scores, 1) + cum_scores = acc[:, None] + tl.cumsum(scores, 1).to(acc.dtype) + acc += tl.sum(scores, 1).to(acc.dtype) pre_cum_scores = cum_scores - scores valid_mask = (samp > pre_cum_scores) & (samp <= cum_scores) diff --git a/tests/pytorch/kernel/test_multinomial_sampling.py b/tests/pytorch/kernel/test_multinomial_sampling.py index 4512a34879..f0f594dde6 100644 --- a/tests/pytorch/kernel/test_multinomial_sampling.py +++ b/tests/pytorch/kernel/test_multinomial_sampling.py @@ -19,10 +19,15 @@ def batch_size(self, select_ids): yield len(select_ids) @pytest.fixture - def scores(self, num_tokens, batch_size, select_ids): + def dtype(self, request): + yield request.param + + @pytest.fixture + def scores(self, num_tokens, batch_size, select_ids, dtype): ret = torch.zeros(batch_size, num_tokens).cuda() batch_ids = torch.arange(batch_size).cuda() ret[batch_ids, select_ids] = 1 + ret = ret.to(dtype) yield ret @pytest.fixture @@ -45,6 +50,8 @@ def gt(self, batch_size, select_ids, indices): batch_ids = torch.arange(batch_size).cuda() yield indices[batch_ids, select_ids] + @pytest.mark.parametrize('dtype', + [torch.float32, torch.half, torch.bfloat16]) @pytest.mark.parametrize(['num_tokens', 'select_ids'], [ (8, (4, 2) * 30), (200, (50, 150)), From 4f6bb7803073d5f4b7fef8d9c1c74a983755e3bb Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Sun, 3 Mar 2024 22:42:41 +0800 Subject: [PATCH 03/17] Async torch engine (#1206) * add multinomial sampling kernel * add multinomial ut * update sampling * fix block offsets * solve conflict * not copy bw * async engine * add threadsafe * fix ut * recovery doc * fix run_until_complete * fix chat * fix chat * fix async-engine --- benchmark/profile_generation.py | 3 +- benchmark/profile_throughput.py | 3 +- lmdeploy/messages.py | 2 + lmdeploy/pytorch/engine/engine.py | 431 +++++++++++------- lmdeploy/pytorch/engine/model_agent.py | 113 ++++- lmdeploy/pytorch/engine/request.py | 491 ++++++++++++++++----- lmdeploy/pytorch/messages.py | 2 +- lmdeploy/serve/async_engine.py | 24 +- lmdeploy/serve/gradio/turbomind_coupled.py | 6 +- lmdeploy/serve/openai/api_server.py | 22 +- lmdeploy/serve/qos_engine/qos_engine.py | 4 +- lmdeploy/turbomind/turbomind.py | 10 + tests/pytorch/engine/test_request.py | 50 ++- 13 files changed, 845 insertions(+), 316 deletions(-) diff --git a/benchmark/profile_generation.py b/benchmark/profile_generation.py index 31a9f0364d..788d9ba9fc 100644 --- a/benchmark/profile_generation.py +++ b/benchmark/profile_generation.py @@ -397,7 +397,8 @@ def main(): engine_config = PytorchEngineConfig( cache_max_entry_count=args.cache_max_entry_count, session_len=session_len, - tp=args.tp) + tp=args.tp, + thread_safe=True) gen_config = EngineGenerationConfig( top_k=args.top_k, top_p=args.top_p, diff --git a/benchmark/profile_throughput.py b/benchmark/profile_throughput.py index d96614929c..06c3706239 100644 --- a/benchmark/profile_throughput.py +++ b/benchmark/profile_throughput.py @@ -307,7 +307,8 @@ def main(): session_len=args.session_len, cache_max_entry_count=args.cache_max_entry_count, max_batch_size=args.concurrency, - tp=args.tp) + tp=args.tp, + thread_safe=True) engine = Engine(args.model_path, engine_config, csv=args.csv) diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py index 7792e37aa3..bbf1a9c395 100644 --- a/lmdeploy/messages.py +++ b/lmdeploy/messages.py @@ -158,6 +158,7 @@ class PytorchEngineConfig: would be allocate according to current environment. adapters (dict): The path configs to lora adapters. max_prefill_token_num (int): tokens per iteration. + thread_safe (bool): thread safe engine instance. download_dir (str): Directory to download and load the weights, default to the default cache directory of huggingface. revision (str): The specific model version to use. @@ -176,6 +177,7 @@ class PytorchEngineConfig: num_gpu_blocks: int = 0 adapters: Dict[str, str] = None max_prefill_token_num: int = 8192 + thread_safe: bool = False download_dir: str = None revision: str = None diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index c2956b5804..34a6679458 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -1,9 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. +import asyncio import os -import time from dataclasses import dataclass -from queue import Queue -from threading import Thread from typing import Any, Dict, List import torch @@ -21,7 +19,8 @@ from ..paging import Scheduler from .logits_process import FusedLogitsProcessor, SamplingInputs from .model_agent import AutoModelAgent, ModelInputs -from .request import Request, RequestManager, RequestType, Response +from .request import (Request, RequestManager, RequestSender, RequestType, + Response) logger = get_logger('lmdeploy') @@ -47,21 +46,6 @@ class InferOutput: logits: torch.Tensor = None -def _check_resp(resp: Response, state: ResponseType, warning_msg: str = None): - """check if response has state.""" - if isinstance(state, ResponseType): - state = [state] - ret = resp.type in state - if not ret and warning_msg is not None: - logger.warning(warning_msg) - return ret - - -def _check_resp_success(resp: Response, warning_msg: str = None): - """check if response success.""" - return _check_resp(resp, ResponseType.SUCCESS, warning_msg) - - def _paging_adapters(adapters: dict, model_agent: AutoModelAgent, scheduler: Scheduler): adapters = adapters or dict() @@ -89,6 +73,79 @@ def _get_adapter_ids(seqs: SeqList, adapters: AdapterList): return adapter_ids +def _check_resp(resp: Response, state: ResponseType, warning_msg: str = None): + """check if response has state.""" + if isinstance(state, ResponseType): + state = [state] + ret = resp.type in state + if not ret and warning_msg is not None: + logger.warning(warning_msg) + return ret + + +def _check_resp_success(resp: Response, warning_msg: str = None): + """check if response success.""" + return _check_resp(resp, ResponseType.SUCCESS, warning_msg) + + +async def async_try_add_session(req_sender: RequestSender, session_id: int): + """Add new session. + + Args: + session_id (int): The session id to add. + """ + resp = await req_sender.async_send(RequestType.ADD_SESSION, + dict(session_id=session_id)) + _check_resp(resp, [ResponseType.SUCCESS, ResponseType.SESSION_REPEAT], + (f'Can not add session {session_id} ' + f'with error: {resp.type}')) + + +async def async_end(req_sender: RequestSender, session_id: int): + """End the given session.""" + resp = await req_sender.async_send(RequestType.END_SESSION, + dict(session_id=session_id)) + _check_resp_success(resp, (f'Failed to end session: {session_id}. ' + f'Error: {resp.type}.')) + + +async def async_cancel(req_sender: RequestSender, session_id: int): + """Stop current streaming inference.""" + resp = await req_sender.async_send(RequestType.STOP_SESSION, + dict(session_id=session_id)) + _check_resp_success(resp, (f'Failed to cancel session: {session_id}. ' + f'Error: {resp.type}.')) + + +def try_add_session(req_sender: RequestSender, session_id: int): + """Add new session. + + Args: + session_id (int): The session id to add. + """ + resp = req_sender.send(RequestType.ADD_SESSION, + dict(session_id=session_id)) + _check_resp(resp, [ResponseType.SUCCESS, ResponseType.SESSION_REPEAT], + (f'Can not add session {session_id} ' + f'with error: {resp.type}')) + + +def end(req_sender: RequestSender, session_id: int): + """End the given session.""" + resp = req_sender.send(RequestType.END_SESSION, + dict(session_id=session_id)) + _check_resp_success(resp, (f'Failed to end session: {session_id}. ' + f'Error: {resp.type}.')) + + +def cancel(req_sender: RequestSender, session_id: int): + """Stop current streaming inference.""" + resp = req_sender.send(RequestType.STOP_SESSION, + dict(session_id=session_id)) + _check_resp_success(resp, (f'Failed to cancel session: {session_id}. ' + f'Error: {resp.type}.')) + + class Engine: """The inference engine of lmdeploy pytorch. @@ -153,8 +210,8 @@ def __init__(self, self.req_manager = self._bind_request_manager() # create main thread - self.loop_threads = self._start_loop() - self.req_sender = self.req_manager.build_sender(self.loop_threads) + self._start_loop() + self.req_sender = self.req_manager.build_sender() self._create_buffers() self.tokenizer = Tokenizer(model_path) @@ -199,7 +256,7 @@ def _create_buffers(self): def _bind_request_manager(self): """bind request manager.""" - req_manager = RequestManager() + req_manager = RequestManager(self.engine_config.thread_safe) req_manager.bind_func(RequestType.ADD_SESSION, self._on_add_session) req_manager.bind_func(RequestType.STOP_SESSION, self._on_stop_session) req_manager.bind_func(RequestType.END_SESSION, self._on_end_session) @@ -208,9 +265,7 @@ def _bind_request_manager(self): def _start_loop(self): """start loop.""" - loop_threads = Thread(target=self.loop, daemon=True) - loop_threads.start() - return loop_threads + return self.req_manager.start_loop(self.async_loop) def _on_add_session(self, reqs: Request, **kwargs): """on add session callback.""" @@ -320,27 +375,29 @@ def create_instance(self, cuda_stream_id=0): """ return EngineInstance(self) + async def async_add_session(self, session_id: int): + """Add new session.""" + return await async_try_add_session(self.req_sender, session_id) + def add_session(self, session_id: int): """Add new session.""" - resp = self.req_sender.send(RequestType.ADD_SESSION, - dict(session_id=session_id)) - _check_resp(resp, [ResponseType.SUCCESS, ResponseType.SESSION_REPEAT], - (f'Can not add session {session_id} ' - f'with error: {resp.type}')) + return try_add_session(self.req_sender, session_id) - def stop_session(self, session_id: int): + async def async_stop_session(self, session_id: int): """Stop the given session.""" - resp = self.req_sender.send(RequestType.STOP_SESSION, - dict(session_id=session_id)) - _check_resp_success(resp, (f'Failed to cancel session: {session_id}. ' - f'Error: {resp.type}.')) + return await async_cancel(self.req_sender, session_id) - def end_session(self, session_id: int): + def stop_session(self, session_id: int): + """Add new session.""" + return cancel(self.req_sender, session_id) + + async def async_end_session(self, session_id: int): """End the given session.""" - resp = self.req_sender.send(RequestType.END_SESSION, - dict(session_id=session_id)) - _check_resp_success(resp, (f'Failed to end session: {session_id}. ' - f'Error: {resp.type}.')) + return await async_end(self.req_sender, session_id) + + def end_session(self, session_id: int): + """Add new session.""" + return end(self.req_sender, session_id) @torch.inference_mode() def create_model_inputs(self, messages: SeqList, adapters: AdapterList): @@ -465,8 +522,8 @@ def _check_session_len(msg, max_session_len): return True return False - def sampling_logits(self, logits: torch.Tensor, running: SeqList, - inputs: ModelInputs): + async def async_sampling_logits(self, logits: torch.Tensor, + running: SeqList, inputs: ModelInputs): """sampling logits.""" def _gather_history(seqs: SeqList, device: torch.device): @@ -499,8 +556,12 @@ def _gather_history(seqs: SeqList, device: torch.device): if sampling_inputs.repetition_penalty is not None: input_ids = _gather_history(running, split_logits.device) logits_processor = FusedLogitsProcessor(sampling_inputs) - logits = logits_processor(input_ids, split_logits) - next_token_ids = logits_processor.sampling(logits).cpu() + + with torch.inference_mode(), torch.cuda.stream(self.stream): + logits = logits_processor(input_ids, split_logits) + next_token_ids = logits_processor.sampling(logits) + self.stream.synchronize() + next_token_ids = next_token_ids.cpu() return next_token_ids, split_logits @@ -527,8 +588,8 @@ def _can_output_token(self, token: torch.Tensor, msg: SchedulerSequence): return True - def _model_forward(self, inputs: ModelInputs, swap_in_map: Dict, - swap_out_map: Dict): + async def _async_model_forward(self, inputs: ModelInputs, + swap_in_map: Dict, swap_out_map: Dict): """model forward.""" max_prefill_token_num = self.scheduler_config.max_prefill_token_num swap_done = False @@ -562,20 +623,18 @@ def get_logits(self): torch.cuda.synchronize() return self._out_logits - def __forward(inputs): + async def __forward(inputs): """forward.""" nonlocal swap_done, swap_in_map, swap_out_map if swap_done: - return self.model_agent.forward(inputs, - swap_in_map=dict(), - swap_out_map=dict()) + return await self.model_agent.async_forward( + inputs, swap_in_map=dict(), swap_out_map=dict()) else: swap_done = True - return self.model_agent.forward(inputs, - swap_in_map=swap_in_map, - swap_out_map=swap_out_map) + return await self.model_agent.async_forward( + inputs, swap_in_map=swap_in_map, swap_out_map=swap_out_map) - def __long_context_single_forward(inputs, index): + async def __long_context_single_forward(inputs, index): """one large sequence.""" new_input = inputs.slice(index, index + 1) max_seq_len = new_input.seq_length[0] @@ -584,17 +643,17 @@ def __long_context_single_forward(inputs, index): logits_gather = _LogitsGather(max_seq_len) for inp in new_inputs: - tmp_out = __forward(inp) + tmp_out = await __forward(inp) logits_gather.gather(tmp_out) tmp_out['logits'] = logits_gather.get_logits() return tmp_out - def __long_context_batched_forward(inputs, start, end): + async def __long_context_batched_forward(inputs, start, end): """batched.""" new_inputs = inputs.slice(start, end) - return __forward(new_inputs) + return await __forward(new_inputs) - def __long_context_forward(inputs): + async def __long_context_forward(inputs): """forward for long context.""" seq_len = inputs.seq_length max_seq_len = inputs.input_ids.size(1) @@ -607,11 +666,11 @@ def __long_context_forward(inputs): while idx < batch_size: slen = seq_len[idx] if token_count == 0 and slen > max_prefill_token_num: - tmp_out = __long_context_single_forward(inputs, idx) + tmp_out = await __long_context_single_forward(inputs, idx) logits_gather.gather(tmp_out) idx += 1 elif token_count + slen > max_prefill_token_num: - tmp_out = __long_context_batched_forward( + tmp_out = await __long_context_batched_forward( inputs, indices[0], idx) logits_gather.gather(tmp_out) indices = [] @@ -622,18 +681,18 @@ def __long_context_forward(inputs): idx += 1 if token_count > 0: - tmp_out = __long_context_batched_forward( + tmp_out = await __long_context_batched_forward( inputs, indices[0], idx) logits_gather.gather(tmp_out) tmp_out['logits'] = logits_gather.get_logits() return tmp_out if inputs.input_ids.numel() < max_prefill_token_num: - return __forward(inputs) + return await __forward(inputs) else: - return __long_context_forward(inputs) + return await __long_context_forward(inputs) - def step(self, is_prefill: bool, return_logits: bool = False): + async def async_step(self, is_prefill: bool, return_logits: bool = False): """one step inference. Used to perform streaming chat. Args: @@ -642,7 +701,6 @@ def step(self, is_prefill: bool, return_logits: bool = False): Returns: Dict[int, InferOutput]: The output of each session. """ - # schedule schedule_output = self.scheduler.schedule(is_prefill=is_prefill) @@ -656,14 +714,14 @@ def step(self, is_prefill: bool, return_logits: bool = False): inputs = self.create_model_inputs(running, adapters) # inference - output = self._model_forward(inputs, - swap_in_map=swap_in_map, - swap_out_map=swap_out_map) + output = await self._async_model_forward(inputs, + swap_in_map=swap_in_map, + swap_out_map=swap_out_map) custom_outputs = output['custom_outputs'] logits = output['logits'] logits = logits[0] # [bs, seq, prob] -> [seq, prob] - next_token_ids, split_logits = self.sampling_logits( + next_token_ids, split_logits = await self.async_sampling_logits( logits, running, inputs) self.update_running(running, next_token_ids, custom_outputs) @@ -691,12 +749,12 @@ def step(self, is_prefill: bool, return_logits: bool = False): outputs[msg.session_id].logits = msg_logit return outputs - def batched_infer(self, - session_ids: List[int], - token_ids: List[List[int]] = None, - gen_config: EngineGenerationConfig = None, - adapter_names: List[str] = None, - keep_cache: bool = False): + async def async_batched_infer(self, + session_ids: List[int], + token_ids: List[List[int]] = None, + gen_config: EngineGenerationConfig = None, + adapter_names: List[str] = None, + keep_cache: bool = False): """Send inference request. Args: @@ -718,11 +776,11 @@ def batched_infer(self, else: adapter_names = [None for _ in range(batch_size)] - def _add_sessions(session_ids): + async def _add_sessions(session_ids): for session_id in session_ids: - self.add_session(session_id) + await self.async_add_session(session_id) - def _add_messages(session_ids, token_ids): + async def _add_messages(session_ids, token_ids): add_msgs = [] sampling_param = SamplingParam.from_gen_config(gen_config) for session_id, token_id, adapter_name in zip( @@ -733,12 +791,12 @@ def _add_messages(session_ids, token_ids): adapter_name=adapter_name) add_msgs.append(msg) req_types = [RequestType.ADD_MESSAGE] * batch_size - req_ids = self.req_sender.batched_send_async(req_types, - data=add_msgs) + req_ids = await self.req_sender.async_batched_send_async( + req_types, data=add_msgs) return req_ids - _add_sessions(session_ids) - req_ids = _add_messages(session_ids, token_ids) + await _add_sessions(session_ids) + req_ids = await _add_messages(session_ids, token_ids) # receive messages req_idx_map = dict(zip(req_ids, range(len(req_ids)))) @@ -746,12 +804,12 @@ def _add_messages(session_ids, token_ids): status = 0 finish_count = batch_size while finish_count: - if not self.loop_threads.is_alive(): + if not self.req_manager.is_loop_alive(): logger.error('Engine loop is not alive.') status = 1 break - resp = self.req_sender.recv_any() + resp = await self.req_sender.async_recv_any() if resp.req_id not in req_ids: continue idx = req_idx_map[resp.req_id] @@ -762,7 +820,7 @@ def _add_messages(session_ids, token_ids): token_ids += resp.data['token_ids'] if not keep_cache: session_id = session_ids[idx] - self.end_session(session_id=session_id) + await self.async_end_session(session_id=session_id) finish_count -= 1 else: logger.error(f'Unexpected response: {resp.type}') @@ -772,6 +830,33 @@ def _add_messages(session_ids, token_ids): output_token_len = [len(token_ids) for token_ids in output_token_ids] return (status, output_token_ids, output_token_len) + def batched_infer(self, + session_ids: List[int], + token_ids: List[List[int]] = None, + gen_config: EngineGenerationConfig = None, + adapter_names: List[str] = None, + keep_cache: bool = False): + """Send inference request. + + Args: + session_ids (List[int]): The session id. + token_ids (List[int]): The input token ids. + gen_config (EngineGenerationConfig): The sampling parameters. + adapter_names (List[str]): The name of the adapters. + keep_cache (bool): Keep kv cache after infer. + + Returns: + int: Error flags. 0 if success. + List[int]: The streaming output tokens. + int: The number of the output tokens. + """ + coro = self.async_batched_infer(session_ids=session_ids, + token_ids=token_ids, + gen_config=gen_config, + adapter_names=adapter_names, + keep_cache=keep_cache) + return self.req_sender.run_until_complete(coro) + def decode(self, prompt_token_ids: List[List[int]]): """Perform one step inference and get logits. @@ -815,42 +900,37 @@ def decode(self, prompt_token_ids: List[List[int]]): return logits - def loop(self): + async def async_loop(self): """Main loop of the engine. Each engine instance would communicate with the engine by queue. """ - send_resp_que = Queue() - def _send_resp(): + def _send_resp(step_tokens): """send response callback.""" - while True: - step_tokens = send_resp_que.get() - time.sleep(0.02) - for _, out in step_tokens.items(): - if out.finish: - resp_type = ResponseType.FINISH - else: - resp_type = ResponseType.SUCCESS - self.req_manager.response( - Response( - type=resp_type, - sender_id=out.sender_id, - req_id=out.req_id, - data=dict(token_ids=out.token_ids), - )) - - send_thread = Thread(target=_send_resp, daemon=True) - send_thread.start() + for _, out in step_tokens.items(): + if out.finish: + resp_type = ResponseType.FINISH + else: + resp_type = ResponseType.SUCCESS + self.req_manager.response( + Response( + type=resp_type, + sender_id=out.sender_id, + req_id=out.req_id, + data=dict(token_ids=out.token_ids), + )) + prefill_interval = self.scheduler_config.prefill_interval prefill_counter = prefill_interval while True: if not self.req_manager.has_requests( ) and not self.scheduler.has_unfinished(): - time.sleep(0.01) + await asyncio.sleep(0.01) continue + logger.debug('async_loop: RequestManager Step.') self.req_manager.step() # forward @@ -859,13 +939,17 @@ def _send_resp(): is_prefill = not prefill_counter or not has_running if is_prefill: prefill_counter = prefill_interval + logger.debug('async_loop: Engine Step - ' + f'prefilling: {is_prefill}') with torch.inference_mode(): - step_tokens: Dict[int, InferOutput] = self.step( - is_prefill=is_prefill) + step_tokens: Dict[int, + InferOutput] = await self.async_step( + is_prefill=is_prefill) prefill_counter -= 1 # send response - send_resp_que.put(step_tokens) + logger.debug('async_loop: Response.') + _send_resp(step_tokens) class EngineInstance: @@ -877,23 +961,27 @@ class EngineInstance: def __init__(self, engine: Engine): self.engine = engine - self.req_sender = engine.req_manager.build_sender(engine.loop_threads) + self.req_sender = engine.req_manager.build_sender() def __del__(self): """Destructor.""" self.engine.req_manager.senders.pop(self.req_sender.sender_id) + async def _async_try_add_session(self, session_id: int): + """Add new session. + + Args: + session_id (int): The session id to add. + """ + return await async_try_add_session(self.req_sender, session_id) + def _try_add_session(self, session_id: int): """Add new session. Args: session_id (int): The session id to add. """ - resp = self.req_sender.send(RequestType.ADD_SESSION, - dict(session_id=session_id)) - _check_resp(resp, [ResponseType.SUCCESS, ResponseType.SESSION_REPEAT], - (f'Can not add session {session_id} ' - f'with error: {resp.type}')) + return try_add_session(self.req_sender, session_id) async def async_stream_infer(self, session_id: int, @@ -914,46 +1002,67 @@ async def async_stream_infer(self, List[int]: The streaming output tokens. int: The number of the output tokens. """ - import asyncio gen_config = gen_config or EngineGenerationConfig() - request_output_len = gen_config.max_new_tokens sampling_param = SamplingParam.from_gen_config(gen_config=gen_config) - self._try_add_session(session_id) + await async_try_add_session(self.req_sender, session_id) msg = dict( token_ids=input_ids, session_id=session_id, - max_request_output_len=request_output_len, sampling_param=sampling_param, adapter_name=adapter_name, ) - - req_id = self.req_sender.send_async(RequestType.ADD_MESSAGE, msg) + req_id = await self.req_sender.async_send_async( + RequestType.ADD_MESSAGE, msg) token_ids = [] while True: - if not self.engine.loop_threads.is_alive(): + if not self.req_sender.is_loop_alive(): yield (ResponseType.ENGINE_STOP_ERROR, [], 0) break - resps = self.req_sender.recv_all(req_id) - if len(resps) == 0: - await asyncio.sleep(0.1) - continue + resp = await self.req_sender.async_recv(req_id) - resp_type = ResponseType.SUCCESS - for resp in resps: - resp_type = resp.type - if resp.type == ResponseType.SUCCESS: - token_ids += resp.data['token_ids'] - elif resp.type == ResponseType.FINISH: - token_ids += resp.data['token_ids'] - break - else: - token_ids = [] - break - yield (resp_type, token_ids, len(token_ids)) - if resp_type != ResponseType.SUCCESS: + if resp.req_id != req_id: + continue + if resp.type == ResponseType.SUCCESS: + token_ids += resp.data['token_ids'] + yield (resp.type, token_ids, len(token_ids)) + elif resp.type == ResponseType.FINISH: + token_ids += resp.data['token_ids'] + yield (resp.type, token_ids, len(token_ids)) break + else: + yield (resp.type, [], 0) + break + + async def async_infer(self, + session_id: int, + input_ids: List[int] = None, + gen_config: EngineGenerationConfig = None, + **kwargs): + """Send inference request. + + Args: + session_id (int): The session id. + input_ids (List[int]): The input token ids. + gen_config (EngineGenerationConfig): The sampling parameters. + + Returns: + int: Error flags. 0 if success. + List[int]: The streaming output tokens. + int: The number of the output tokens. + """ + token_ids = [] + async for outputs in self.async_stream_infer(session_id, + input_ids, + gen_config=gen_config, + **kwargs): + status, tmp_ids, _ = outputs + if status not in [ResponseType.SUCCESS, ResponseType.FINISH]: + return (status, token_ids, len(token_ids)) + token_ids = tmp_ids + + return (0, token_ids, len(token_ids)) def stream_infer(self, session_id: int, @@ -975,10 +1084,25 @@ def stream_infer(self, int: The number of the output tokens. """ - # TODO: support input embedding, step + def __call_async(): + """call async.""" + coro_gen = self.async_stream_infer(session_id, input_ids, + gen_config, adapter_name, + **kwargs) + while True: + try: + yield self.req_sender.run_until_complete( + coro_gen.__anext__()) + except StopAsyncIteration: + break + + if not self.req_sender.is_thread_safe(): + yield from __call_async() + return + gen_config = gen_config or EngineGenerationConfig() sampling_param = SamplingParam.from_gen_config(gen_config=gen_config) - self._try_add_session(session_id) + try_add_session(self.req_sender, session_id) msg = dict( token_ids=input_ids, session_id=session_id, @@ -989,11 +1113,12 @@ def stream_infer(self, token_ids = [] while True: - if not self.engine.loop_threads.is_alive(): + if not self.req_sender.is_loop_alive(): yield (ResponseType.ENGINE_STOP_ERROR, [], 0) break + resp = self.req_sender.recv(req_id) - # avoid token decoding and scheduling simultaneously + if resp.req_id != req_id: continue if resp.type == ResponseType.SUCCESS: @@ -1036,19 +1161,21 @@ def infer(self, return (0, token_ids, len(token_ids)) + async def async_end(self, session_id: int): + """End the given session.""" + return await async_end(self.req_sender, session_id) + def end(self, session_id: int): """End the given session.""" - resp = self.req_sender.send(RequestType.END_SESSION, - dict(session_id=session_id)) - _check_resp_success(resp, (f'Failed to end session: {session_id}. ' - f'Error: {resp.type}.')) + return end(self.req_sender, session_id) + + async def async_cancel(self, session_id: int): + """Stop current streaming inference.""" + return await async_cancel(self.req_sender, session_id) def cancel(self, session_id: int): """Stop current streaming inference.""" - resp = self.req_sender.send(RequestType.STOP_SESSION, - dict(session_id=session_id)) - _check_resp_success(resp, (f'Failed to cancel session: {session_id}. ' - f'Error: {resp.type}.')) + return cancel(self.req_sender, session_id) def decode(self, prompt_token_ids: List[List[int]]): """Return logits of context decoding. diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index 008b1da2b1..d947f7afe6 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +import asyncio import os from dataclasses import asdict, dataclass, field from typing import Any, Callable, Dict, List, Union @@ -343,7 +344,6 @@ def model_forward( use_origin=False, context=context, ) - stream.synchronize() return dict(logits=output['logits'], custom_outputs=context._outputs) @@ -405,6 +405,17 @@ def paging_adapters(self, weight_maps: List[AdapterWeightMap]): """paging adapter.""" raise NotImplementedError('Not implemented.') + async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, + swap_out_map: SwapMap): + """model forward. + + Args: + inputs (Dict): The input data comes from _make_inputs. + swap_in_map (SwapMap): Cache maps to swap in. + swap_out_map (SwapMap): Cache maps to swap out. + """ + raise NotImplementedError('Not implemented.') + def forward(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap): """model forward. @@ -507,16 +518,8 @@ def paging_adapters(self, weight_maps: List[AdapterWeightMap]): weight_map.cache_adapter(lora_linears, cpu_caches) update_lora_linears(lora_linears, weight_maps, device='cuda') - def forward(self, inputs: ModelInputs, swap_in_map: SwapMap, - swap_out_map: SwapMap): - """model forward. - - Args: - inputs (Dict): The input data comes from _make_inputs. - swap_in_map (SwapMap): Cache maps to swap in. - swap_out_map (SwapMap): Cache maps to swap out. - """ - + def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap, + swap_out_map: SwapMap): cache_swapping(self.cache_engine, swap_in_map=swap_in_map, swap_out_map=swap_out_map) @@ -530,6 +533,37 @@ def forward(self, inputs: ModelInputs, swap_in_map: SwapMap, ) return output + def forward(self, inputs: ModelInputs, swap_in_map: SwapMap, + swap_out_map: SwapMap): + """model forward. + + Args: + inputs (Dict): The input data comes from _make_inputs. + swap_in_map (SwapMap): Cache maps to swap in. + swap_out_map (SwapMap): Cache maps to swap out. + """ + output = self._forward_impl(inputs, + swap_in_map=swap_in_map, + swap_out_map=swap_out_map) + self.stream.synchronize() + return output + + async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, + swap_out_map: SwapMap): + """model forward. + + Args: + inputs (Dict): The input data comes from _make_inputs. + swap_in_map (SwapMap): Cache maps to swap in. + swap_out_map (SwapMap): Cache maps to swap out. + """ + output = self._forward_impl(inputs, + swap_in_map=swap_in_map, + swap_out_map=swap_out_map) + await asyncio.get_event_loop().run_in_executor(None, + self.stream.synchronize) + return output + @dataclass class TPResponse: @@ -886,6 +920,7 @@ def _tp_model_loop( world_size=world_size, stream=stream, ) + stream.synchronize() if rank == 0: resp_output = output out_que.put(TPResponse(0, None, resp_output)) @@ -922,25 +957,45 @@ def _start_tp_process(rank: int, raise e +def _check_context_alive(mp_context: mp.ProcessContext): + """check context alive.""" + procs = mp_context.processes + for idx, p in enumerate(procs): + if not p.is_alive(): + raise RuntimeError(f'Rank[{idx}] failed.') + + def _queue_get_response(que: mp.Queue, mp_context: mp.ProcessContext, interval: float = 1.0): """get response.""" from multiprocessing.queues import Empty + while True: + try: + return que.get(timeout=interval) + except Empty: + _check_context_alive(mp_context) - def __check_context_alive(): - """check context alive.""" - procs = mp_context.processes - for idx, p in enumerate(procs): - if not p.is_alive(): - raise RuntimeError(f'Rank[{idx}] failed.') - while True: +async def _async_queue_get_response(que: mp.Queue, + mp_context: mp.ProcessContext, + interval: float = 1.0): + """get response.""" + from multiprocessing.queues import Empty + + def __try_que_get(): + """try que get.""" try: return que.get(timeout=interval) except Empty: - pass - __check_context_alive() + return None + + while True: + ret = await asyncio.get_event_loop().run_in_executor( + None, __try_que_get) + if ret is not None: + return ret + _check_context_alive(mp_context) class TPModelAgent(AutoModelAgent): @@ -1059,6 +1114,24 @@ def forward(self, inputs: ModelInputs, swap_in_map: SwapMap, return resp.data + async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, + swap_out_map: SwapMap): + """model forward. + + Args: + inputs (Dict): The input data comes from _make_inputs. + swap_in_map (Dict[int, int]): Cache maps to swap in. + swap_out_map (Dict[int, int]): Cache maps to swap out. + """ + with torch.no_grad(): + self.tp_model_in_que.put((inputs, swap_in_map, swap_out_map)) + + resp: TPResponse = await _async_queue_get_response( + self.tp_model_out_que, self.mp_context) + if resp.ret_code != 0: + raise RuntimeError('tp forward failed.') + return resp.data + def build_model_agent(model_path: str, cache_config: CacheConfig, diff --git a/lmdeploy/pytorch/engine/request.py b/lmdeploy/pytorch/engine/request.py index 047ced1043..f71c8a8028 100644 --- a/lmdeploy/pytorch/engine/request.py +++ b/lmdeploy/pytorch/engine/request.py @@ -4,7 +4,7 @@ from dataclasses import dataclass, field from queue import Empty, Queue from threading import Lock, Thread -from typing import Any, Callable, ClassVar, Dict, List +from typing import Any, Awaitable, Callable, Dict, List from lmdeploy.messages import ResponseType from lmdeploy.utils import get_logger @@ -12,6 +12,26 @@ logger = get_logger('lmdeploy') +def _raise_exception_on_finish(task: asyncio.Task) -> None: + msg = ('Engine loop failed!') + try: + task.result() + except asyncio.CancelledError: + return + except Exception as exc: + raise RuntimeError(msg) from exc + + +def _ignore_exception_on_finish(task: asyncio.Task) -> None: + try: + task.result() + except asyncio.CancelledError: + return + except Exception as exc: + logger.info(f'task: {task.get_name()} ended.') + logger.debug(f'task: {task.get_name()} exception: {exc}') + + class RequestType(enum.Enum): """Request type.""" @@ -44,6 +64,21 @@ class Response: err_msg: str = '' +ReqList = List[Request] + + +def _run_until_complete(future: Awaitable): + """run untile complete.""" + try: + event_loop = asyncio.get_event_loop() + except Exception: + logger.warning('Can not found event loop in current thread.' + ' Create a new event loop.') + event_loop = asyncio.new_event_loop() + asyncio.set_event_loop(event_loop) + return event_loop.run_until_complete(future) + + @dataclass class RequestSender: """Request sender. @@ -53,53 +88,133 @@ class RequestSender: """ sender_id: int - req_que: Queue - resp_que: Queue = field(default_factory=Queue) + manager: 'RequestManager' resp_dict: Dict[int, List[Response]] = field(default_factory=dict) - THREAD_ALIVE_INTERVAL: ClassVar[float] = 1.0 _next_req_id: int = 0 - _thread: Thread = None + _resp_que: asyncio.Queue = None + _resp_thread_que: Queue = None @classmethod - def new(cls, sender_id: int, req_que: Queue, thread: Thread): - """new sender.""" - return cls(sender_id=sender_id, req_que=req_que, _thread=thread) - - def _resp_que_get(self, block: bool = True, timeout: float = None): - """warp of resp_que.get.""" - if not block: - return self.resp_que.get_nowait() - timeout_counter = timeout or float(1 << 30) - while timeout_counter > self.THREAD_ALIVE_INTERVAL: + def new(cls, sender_id: int, manager: 'RequestManager'): + """new.""" + return cls(sender_id=sender_id, manager=manager) + + @property + def resp_que(self): + """response queue.""" + if self.is_thread_safe(): + return self.manager.responses + if self.manager._loop_task is None and not self.is_thread_safe(): + self.manager.create_loop_task() + if self._resp_que is None: + self._resp_que = asyncio.Queue() + return self._resp_que + + @property + def req_que(self): + """request queue.""" + return self.manager.requests + + @property + def resp_thread_que(self): + """response threadsafe queue.""" + if self._resp_thread_que is None: + self._resp_thread_que = Queue() + return self._resp_thread_que + + @property + def req_thread_que(self): + """request threadsafe queue.""" + return self.manager.thread_requests + + @property + def event_loop(self): + """get event loop.""" + return self.manager.event_loop + + def is_thread_safe(self): + """is thread safe.""" + return self.manager.is_thread_safe() + + def is_loop_alive(self): + """is loop alive.""" + return self.manager.is_loop_alive() + + def run_until_complete(self, future: Awaitable): + """run untile complete.""" + return self.manager.run_until_complete(future) + + def _resp_get(self): + """resp_que.get.""" + timeout = 1 + + while True: + if not self.manager.is_loop_alive(): + logger.debug('Engine loop is not alive.') + exit(1) try: - return self.resp_que.get(timeout=self.THREAD_ALIVE_INTERVAL) + ret = self.resp_thread_que.get(timeout=timeout) + return ret except Empty: - timeout_counter -= self.THREAD_ALIVE_INTERVAL - if self._thread and not self._thread.is_alive(): - logger.error('Engine main loop stopped.') - exit(1) + continue + except Exception as e: + raise e + + async def _async_resp_get(self): + """get resp. + + Different behavior in threadsafe mode. + """ + timeout = 1 + + async def __no_threadsafe_get(): + while True: + if not self.manager.is_loop_alive(): + logger.debug('Engine loop is not alive.') + exit(1) + try: + return await asyncio.wait_for(self.resp_que.get(), timeout) + except asyncio.TimeoutError: + continue + except Exception as e: + raise e + + if self.is_thread_safe(): + ret = self._resp_get() + await asyncio.sleep(0) + return ret + else: + return await __no_threadsafe_get() - return self.resp_que.get(timeout=timeout_counter) - - async def _async_resp_que_get(self, - block: bool = True, - timeout: float = None): - """warp of resp_que.get.""" - if not block: - return self.resp_que.get_nowait() - timeout_counter = timeout or float(1 << 30) - while timeout_counter > self.THREAD_ALIVE_INTERVAL: - if self.resp_que.qsize() == 0: - await asyncio.sleep(self.THREAD_ALIVE_INTERVAL) - timeout_counter -= self.THREAD_ALIVE_INTERVAL - else: - return self.resp_que.get(block=False) - if self._thread and not self._thread.is_alive(): - logger.error('Engine main loop stopped.') - exit(1) + def _req_put(self, reqs: Any): + """req put.""" + self.req_thread_que.put(reqs) + + async def _async_req_put(self, reqs: Any): + """async rq_que put. - await asyncio.sleep(self.THREAD_ALIVE_INTERVAL) - return self.resp_que.get(block=False) + Different behavior in threadsafe mode. + """ + if self.is_thread_safe(): + self._req_put(reqs) + await asyncio.sleep(0) + else: + await self.req_que.put(reqs) + + def _prefetch_resps(self): + """prefetch from resp que. + + Different behavior in threadsafe mode. + """ + if self.is_thread_safe(): + resp_que = self.resp_thread_que + else: + resp_que = self.resp_que + num_resps = resp_que.qsize() + for _ in range(num_resps): + resp: Response = resp_que.get_nowait() + req_id = resp.req_id + self._push_resp(req_id, resp) def _push_resp(self, req_id: int, resp: Response): """push response.""" @@ -116,22 +231,11 @@ def _pop_resp(self, req_id: int, default: Any = None): self.resp_dict.pop(req_id) return ret - def _prefetch_resps(self): - """prefetch from resp que.""" - num_resps = self.resp_que.qsize() - for _ in range(num_resps): - resp: Response = self._resp_que_get(block=False) - req_id = resp.req_id - self._push_resp(req_id, resp) - - def is_thread_alive(self): - """is thread alive.""" - return self._thread and self._thread.is_alive() - - def batched_send_async(self, req_types: List[RequestType], - data: List[Any]) -> List[int]: - """Batched send request asynchronize.""" - if self._thread and not self._thread.is_alive(): + def _gather_request(self, req_types: List[RequestType], data: List[Any]): + """gather requests.""" + if self.manager._loop_task is None and not self.is_thread_safe(): + self.manager.create_loop_task() + if not self.is_loop_alive(): logger.error('Engine main loop stopped.') exit(1) assert len(req_types) == len(data) @@ -148,79 +252,120 @@ def batched_send_async(self, req_types: List[RequestType], data=rdata) for req_id, rtype, rdata in zip(req_ids, req_types, data) ] - self.req_que.put(reqs) + return req_ids, reqs + async def async_batched_send_async(self, req_types: List[RequestType], + data: List[Any]): + """Batched send request asynchronize.""" + req_ids, reqs = self._gather_request(req_types, data) + await self._async_req_put(reqs) + return req_ids + + async def async_send_async(self, req_type: RequestType, data: Any): + """send request asynchronize.""" + return (await self.async_batched_send_async(req_types=[req_type], + data=[data]))[0] + + def batched_send_async(self, req_types: List[RequestType], + data: List[Any]) -> List[int]: + """Batched send request asynchronize. + + Different behavior in threadsafe mode. + """ + if not self.is_thread_safe(): + coro = self.async_batched_send_async(req_types, data) + return self.run_until_complete(coro) + + req_ids, reqs = self._gather_request(req_types, data) + self._req_put(reqs) return req_ids def send_async(self, req_type: RequestType, data: Any) -> int: """send request asynchronize.""" return self.batched_send_async(req_types=[req_type], data=[data])[0] - def recv_any(self, que_timeout: float = None) -> Response: + async def async_recv_any(self, que_timeout: float = None) -> Response: """receive any response.""" - # check resp dict self._prefetch_resps() for req_id in self.resp_dict: ret = self._pop_resp(req_id, default=None) if ret is not None: return ret + return await self._async_resp_get() - # check resp que - return self._resp_que_get(timeout=que_timeout) + def recv_any(self, que_timeout: float = None) -> Response: + """receive any response.""" + coro = self.async_recv_any(que_timeout) + return self.run_until_complete(coro) - def recv_all(self, req_id: int): + def recv_all(self, req_id: int, block: bool = True): """revceive all response with req_id.""" self._prefetch_resps() resps = self.resp_dict.pop(req_id, []) return resps - def recv(self, req_id: int, que_timeout: float = None) -> Response: - """receive response of given request id.""" - # check resp dict + async def async_recv(self, + req_id: int, + que_timeout: float = None) -> Response: + """receive response of given request id async.""" ret = self._pop_resp(req_id, default=None) if ret is not None: return ret # check resp que while True: - resp: Response = self._resp_que_get(timeout=que_timeout) + resp: Response = await self._async_resp_get() if resp.req_id != req_id: self._push_resp(req_id, resp) else: return resp - async def async_recv(self, - req_id: int, - que_timeout: float = None) -> Response: - """receive response of given request id async.""" + def recv(self, req_id: int, que_timeout: float = None) -> Response: + """receive response of given request id. + + Different behavior in threadsafe mode. + """ + if not self.is_thread_safe(): + coro = self.async_recv(req_id, que_timeout) + return self.run_until_complete(coro) + ret = self._pop_resp(req_id, default=None) if ret is not None: return ret # check resp que while True: - resp: Response = await self._async_resp_que_get(timeout=que_timeout - ) + resp: Response = self._resp_get() if resp.req_id != req_id: self._push_resp(req_id, resp) else: return resp + async def async_send(self, + req_type: RequestType, + data: Any, + que_timeout: float = None): + """send and receive synchronize.""" + req_id = await self.async_send_async(req_type, data) + return await self.async_recv(req_id, que_timeout=que_timeout) + def send(self, req_type: RequestType, data: Any, que_timeout: float = None) -> Response: """send and receive synchronize.""" req_id = self.send_async(req_type, data) - return self.recv(req_id, que_timeout=que_timeout) + def response_callback(self, resp: Response): + """response callback.""" + self.resp_que.put_nowait(resp) + class RequestManager: """Request manager.""" - def __init__(self): - self._next_sender_id = 0 + def __init__(self, thread_safe: bool = False): self.senders: Dict[int, RequestSender] = dict() self.callbacks: Dict[RequestType, Callable] = dict() self.request_priority: List[RequestType] = [ @@ -228,47 +373,158 @@ def __init__(self): RequestType.END_SESSION, RequestType.ADD_SESSION, RequestType.ADD_MESSAGE ] - self.requests = Queue() - self.mutex = Lock() - - def build_sender(self, thread: Thread = None): + self.requests: asyncio.Queue = None + self._loop_task: asyncio.Future = None + self._loop_coro: Callable = None + self._thread_safe = thread_safe + self._next_sender_id = 0 + self._mutex = Lock() + self._loop_thread: Thread = None + + self.thread_requests: Queue = None + # every sender has it's own responses, this responses is + # only used in thread safe mode. + self.responses: asyncio.Queue = None + if thread_safe: + self.thread_requests = Queue() + + def create_loop_task(self): + """create coro task.""" + logger.debug('creating engine loop task.') + event_loop = asyncio.get_event_loop() + assert self._loop_coro is not None, ( + 'Please set loop task with manager.start_loop') + loop_unshielded = event_loop.create_task(self._loop_coro(), + name='EngineMainLoop') + loop_unshielded.add_done_callback(_raise_exception_on_finish) + self._loop_task = asyncio.shield(loop_unshielded) + self.requests = asyncio.Queue() + return self._loop_task + + @property + def event_loop(self): + """get event loop.""" + if self._loop_task is None: + return None + else: + return self._loop_task.get_loop() + + def is_thread_safe(self): + """is thread safe.""" + return self._thread_safe + + def start_loop(self, loop: asyncio.Task): + """start main loop.""" + self._loop_coro = loop + + def __get_thread_reqs(): + """get thread reqs.""" + num_reqs = self.thread_requests.qsize() + reqs = [] + for _ in range(num_reqs): + tmp_reqs = self.thread_requests.get_nowait() + if isinstance(tmp_reqs, Request): + tmp_reqs = [tmp_reqs] + reqs += tmp_reqs + return reqs + + async def __req_loop(): + """req loop.""" + while True: + # get reqs + reqs = __get_thread_reqs() + + if len(reqs) > 0: + await self.requests.put(reqs) + else: + await asyncio.sleep(0.02) + + def __put_thread_resps(resps: List[Response]): + """put thread resps.""" + for resp in resps: + sender = self.senders.get(resp.sender_id, None) + if sender is None: + continue + sender.resp_thread_que.put_nowait(resp) + + async def __resp_loop(): + """resp loop.""" + while True: + num_resps = self.responses.qsize() + resps = [] + for _ in range(num_resps): + resps.append(self.responses.get_nowait()) + if len(resps) > 0: + __put_thread_resps(resps) + else: + await asyncio.sleep(0.02) + + def __run_forever(event_loop: asyncio.BaseEventLoop): + """run forever.""" + logger.debug('start thread run forever.') + asyncio.set_event_loop(event_loop) + self.create_loop_task() + req_loop = event_loop.create_task(__req_loop(), + name='RunForeverReqLoop') + req_loop.add_done_callback(_ignore_exception_on_finish) + resp_loop = event_loop.create_task(__resp_loop(), + name='RunForeverRespLoop') + resp_loop.add_done_callback(_ignore_exception_on_finish) + self.event_loop.run_forever() + + if self.is_thread_safe(): + event_loop = asyncio.new_event_loop() + self.responses = asyncio.Queue() + self._loop_thread = Thread(target=__run_forever, + args=(event_loop, ), + daemon=True) + self._loop_thread.start() + + def is_loop_alive(self): + """check if main loop is alive.""" + + def __check_threadsafe(): + if self._loop_thread is None: + return False + if not self._loop_thread.is_alive(): + return False + if self._loop_task is None: + return False + return not self._loop_task.done() + + if self.is_thread_safe(): + return __check_threadsafe() + + if self._loop_task is None: + logger.debug('loop task has not been created.') + return False + if self._loop_task.get_loop() != asyncio.get_event_loop(): + logger.warning('Current event loop is different from' + ' the one bound to loop task!') + return False + return not self._loop_task.done() + + def build_sender(self): """create a new sender.""" - with self.mutex: + with self._mutex: sender_id = self._next_sender_id self._next_sender_id += 1 - new_sender = RequestSender.new(sender_id, self.requests, thread) + new_sender = RequestSender.new(sender_id, self) self.senders[sender_id] = new_sender return new_sender - def bind_func(self, req_type: RequestType, callback: Callable): - """bind handler for given request type.""" - self.callbacks[req_type] = callback - - def set_request_priority(self, priority: List[RequestType]): - """set the priority of request type.""" - self.request_priority = priority - def has_requests(self): """has unprocessed request.""" + if self.requests is None: + return False return not self.requests.empty() - def response(self, resp: Response, timeout: float = None): - """send response.""" - if resp.sender_id not in self.senders: - logger.warning(f'sender {resp.sender_id} not exist. ' - f'Send {resp} failed.') - return - resp_que = self.senders[resp.sender_id].resp_que - resp_que.put(resp, timeout=timeout) - def get_all_requests(self) -> Dict[RequestType, Request]: """get all requests in current queue.""" num_reqs = self.requests.qsize() - reqs: List[Request] = [] - tmp = num_reqs - while tmp: - tmp -= 1 - elem = self.requests.get() + reqs: ReqList = [] + for _ in range(num_reqs): + elem = self.requests.get_nowait() if isinstance(elem, Request): elem = [elem] reqs += elem @@ -280,7 +536,23 @@ def get_all_requests(self) -> Dict[RequestType, Request]: reqs_by_type[req.type].append(req) return reqs_by_type - def process_request(self, req_type, reqs, **kwargs): + def bind_func(self, req_type: RequestType, callback: Callable): + """bind handler for given request type.""" + self.callbacks[req_type] = callback + + def set_request_priority(self, priority: List[RequestType]): + """set the priority of request type.""" + self.request_priority = priority + + def response(self, resp: Response): + """send response.""" + if resp.sender_id not in self.senders: + logger.warning(f'sender {resp.sender_id} not exist. ' + f'Send {resp} failed.') + return + self.senders[resp.sender_id].response_callback(resp) + + def process_request(self, req_type: RequestType, reqs: ReqList, **kwargs): """process reqs with given req type.""" # get callback func = self.callbacks.get(req_type, None) @@ -297,7 +569,10 @@ def process_request(self, req_type, reqs, **kwargs): self.response(resp) def step(self, **kwargs): - """handle requests.""" + """handle requests. + + Should only be called in loop task. + """ reqs_by_type = self.get_all_requests() # handle requests @@ -306,5 +581,9 @@ def step(self, **kwargs): if req_type not in reqs_by_type or len(reqs_by_type) == 0: continue - reqs: List[Request] = reqs_by_type[req_type] + reqs: ReqList = reqs_by_type[req_type] self.process_request(req_type, reqs, **kwargs) + + def run_until_complete(self, future: Awaitable): + """run untile complete.""" + return _run_until_complete(future) diff --git a/lmdeploy/pytorch/messages.py b/lmdeploy/pytorch/messages.py index b58415b65e..19efe95719 100644 --- a/lmdeploy/pytorch/messages.py +++ b/lmdeploy/pytorch/messages.py @@ -75,7 +75,7 @@ def from_gen_config(self, gen_config: EngineGenerationConfig): if max_new_tokens < 0: logger.warning('`max_new_tokens` has to be a strictly' f' positive value, but is {max_new_tokens}') - max_new_tokens = 0 + max_new_tokens = 512 if min_new_tokens < 0 or min_new_tokens > max_new_tokens: logger.warning('`min_new_tokens` has to be ' 'a int >=0 and <= `max_new_tokens`,' diff --git a/lmdeploy/serve/async_engine.py b/lmdeploy/serve/async_engine.py index 84b588c531..3e622d604c 100644 --- a/lmdeploy/serve/async_engine.py +++ b/lmdeploy/serve/async_engine.py @@ -4,7 +4,7 @@ import os import random from argparse import ArgumentError -from contextlib import contextmanager +from contextlib import asynccontextmanager from queue import Empty, Queue from threading import Thread from typing import Dict, List, Literal, Optional, Union @@ -257,30 +257,30 @@ def __call__(self, do_preprocess=do_preprocess, **kwargs) - def stop_session(self, session_id: int): + async def stop_session(self, session_id: int): """Stop a session by a session_id.""" if str(session_id) in self.id2generator: - self.id2generator[str(session_id)].cancel(session_id) + await self.id2generator[str(session_id)].async_cancel(session_id) self.gens_set.add(self.id2generator[str(session_id)]) self.running_session_ids.discard(session_id) - def end_session(self, session_id: int): + async def end_session(self, session_id: int): """Clear a session by a session_id.""" if str(session_id) in self.id2generator: - self.id2generator[str(session_id)].end(session_id) + await self.id2generator[str(session_id)].async_end(session_id) self.id2step[str(session_id)] = 0 self.gens_set.add(self.id2generator[str(session_id)]) self.running_session_ids.discard(session_id) - @contextmanager - def safe_run(self, session_id: Optional[int] = None): + @asynccontextmanager + async def safe_run(self, session_id: Optional[int] = None): """A context manager to make sure server's safe running.""" try: yield except (Exception, asyncio.CancelledError) as e: # noqa - self.stop_session(session_id) + await self.stop_session(session_id) raise e if str(session_id) in self.id2generator: self.gens_set.add(self.id2generator[str(session_id)]) @@ -292,7 +292,7 @@ async def get_generator(self, stop: bool, session_id: int): return self.engine.create_instance() # waiting no generator is available or the same session_id is running while self.gens_set == set() or session_id in self.running_session_ids: - await asyncio.sleep(0) + await asyncio.sleep(0.1) generator = self.gens_set.pop() self.id2generator[str(session_id)] = generator self.running_session_ids.add(session_id) @@ -493,10 +493,10 @@ async def generate( yield GenOut('', self.id2step[str(session_id)], len(input_ids), 0, finish_reason) if sequence_end is True and sequence_start is False: - self.end_session(session_id) + await self.end_session(session_id) else: generator = await self.get_generator(False, session_id) - with self.safe_run(session_id): + async with self.safe_run(session_id): state = DetokenizeState() async for outputs in generator.async_stream_infer( session_id=session_id, @@ -532,4 +532,4 @@ async def generate( # manually end pytorch session # TODO modify pytorch or turbomind api if self.backend == 'pytorch' and sequence_end: - self.end_session(session_id) + await self.end_session(session_id) diff --git a/lmdeploy/serve/gradio/turbomind_coupled.py b/lmdeploy/serve/gradio/turbomind_coupled.py index 0371c0b11f..be522cb2b3 100644 --- a/lmdeploy/serve/gradio/turbomind_coupled.py +++ b/lmdeploy/serve/gradio/turbomind_coupled.py @@ -78,7 +78,7 @@ async def reset_local_func(instruction_txtbox: gr.Textbox, """ state_chatbot = [] # end the session - InterFace.async_engine.end_session(session_id) + await InterFace.async_engine.end_session(session_id) return (state_chatbot, state_chatbot, gr.Textbox.update(value='')) @@ -94,12 +94,12 @@ async def cancel_local_func(state_chatbot: Sequence, cancel_btn: gr.Button, session_id (int): the session id """ yield (state_chatbot, disable_btn, disable_btn) - InterFace.async_engine.stop_session(session_id) + await InterFace.async_engine.stop_session(session_id) # pytorch backend does not support resume chat history now if InterFace.async_engine.backend == 'pytorch': yield (state_chatbot, disable_btn, enable_btn) else: - InterFace.async_engine.end_session(session_id) + await InterFace.async_engine.end_session(session_id) messages = [] for qa in state_chatbot: messages.append(dict(role='user', content=qa[0])) diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py index 8b8969c45e..2ecd1cdc96 100644 --- a/lmdeploy/serve/openai/api_server.py +++ b/lmdeploy/serve/openai/api_server.py @@ -216,7 +216,8 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: async for res in result_generator: if await raw_request.is_disconnected(): # Abort the request if the client disconnects. - VariableInterface.async_engine.stop_session(request.session_id) + await VariableInterface.async_engine.stop_session( + request.session_id) return create_error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected') final_res = res @@ -376,7 +377,8 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: async for res in result_generator: if await raw_request.is_disconnected(): # Abort the request if the client disconnects. - VariableInterface.async_engine.stop_session(request.session_id) + await VariableInterface.async_engine.stop_session( + request.session_id) return create_error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected') final_res = res @@ -522,7 +524,8 @@ async def _inner_call(i, generator): async for res in generator: if await raw_request.is_disconnected(): # Abort the request if the client disconnects. - VariableInterface.async_engine.stop_session(request.session_id) + await VariableInterface.async_engine.stop_session( + request.session_id) return create_error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected') final_res = res @@ -688,7 +691,8 @@ async def _inner_call(i, generator): async for res in generator: if await raw_request.is_disconnected(): # Abort the request if the client disconnects. - VariableInterface.async_engine.stop_session(request.session_id) + await VariableInterface.async_engine.stop_session( + request.session_id) return create_error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected') final_res = res @@ -825,7 +829,8 @@ async def stream_results() -> AsyncGenerator[bytes, None]: async for out in generation: if await raw_request.is_disconnected(): # Abort the request if the client disconnects. - VariableInterface.qos_engine.stop_session(request.session_id) + await VariableInterface.qos_engine.stop_session( + request.session_id) return create_error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected') text += out.response @@ -870,7 +875,8 @@ async def chat_interactive_v1(request: GenerateRequest, """ if request.cancel: if request.session_id != -1: - VariableInterface.async_engine.stop_session(request.session_id) + await VariableInterface.async_engine.stop_session( + request.session_id) return { 'text': '', 'tokens': 0, @@ -879,7 +885,7 @@ async def chat_interactive_v1(request: GenerateRequest, 'finish_reason': 'stop' } else: - create_error_response( + return create_error_response( HTTPStatus.BAD_REQUEST, 'please set a session_id to cancel a request') if request.session_id == -1: @@ -931,7 +937,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]: async for out in generation: if await raw_request.is_disconnected(): # Abort the request if the client disconnects. - async_engine.stop_session(request.session_id) + await async_engine.stop_session(request.session_id) return create_error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected') text += out.response diff --git a/lmdeploy/serve/qos_engine/qos_engine.py b/lmdeploy/serve/qos_engine/qos_engine.py index bd6c8bf8a7..1a311bc0f2 100644 --- a/lmdeploy/serve/qos_engine/qos_engine.py +++ b/lmdeploy/serve/qos_engine/qos_engine.py @@ -65,9 +65,9 @@ def is_qos_enabled(self): """check while qos engine is enabled.""" return self.qos_config.is_qos_enabled - def stop_session(self, session_id: int): + async def stop_session(self, session_id: int): """Stop a session by a session_id.""" - self.engine.stop_session(session_id) + await self.engine.stop_session(session_id) async def generate(self, request): """entry of qos engine generate for three api.""" diff --git a/lmdeploy/turbomind/turbomind.py b/lmdeploy/turbomind/turbomind.py index cc4281f723..77032e61d3 100644 --- a/lmdeploy/turbomind/turbomind.py +++ b/lmdeploy/turbomind/turbomind.py @@ -566,6 +566,11 @@ def end(self, session_id: int): sequence_end=True): pass + async def async_end(self, session_id: int): + """End the given session.""" + self.end(session_id) + await asyncio.sleep(0.002) + def cancel(self, session_id: int): """Stop current streaming inference.""" input_ids = [self.tm_model.tokenizer.eos_token_id] @@ -578,6 +583,11 @@ def cancel(self, session_id: int): stop=True): pass + async def async_cancel(self, session_id: int): + """End the given session.""" + self.cancel(session_id) + await asyncio.sleep(0.002) + def prepare_inputs(self, session_id, input_ids, diff --git a/tests/pytorch/engine/test_request.py b/tests/pytorch/engine/test_request.py index 66dbd55783..813a30e8e7 100644 --- a/tests/pytorch/engine/test_request.py +++ b/tests/pytorch/engine/test_request.py @@ -1,29 +1,59 @@ +import asyncio + import pytest from lmdeploy.pytorch.engine.request import (RequestManager, RequestType, - ResponseType) + Response, ResponseType) class TestRequestHander: @pytest.fixture - def manager(self): - yield RequestManager() + def event_loop(self): + old_loop = asyncio.get_event_loop() + new_loop = asyncio.new_event_loop() + yield new_loop + new_loop.stop() + asyncio.set_event_loop(old_loop) - def _stop_engine_callback(self, reqs, **kwargs): - raise RuntimeError('stop_engine') + @pytest.fixture + def thread_safe(self, request): + yield request.param - def test_bind(self, manager): + @pytest.fixture + def manager(self, thread_safe): + yield RequestManager(thread_safe=thread_safe) + + @pytest.mark.parametrize('thread_safe', [True, False]) + def test_bind(self, manager, event_loop): + + def __stop_engine_callback(reqs, **kwargs): + for req in reqs: + manager.response( + Response(type=ResponseType.SUCCESS, + sender_id=req.sender_id, + req_id=req.req_id, + data=f'{req.data} success')) + + async def __dummy_loop(): + while True: + manager.step() + await asyncio.sleep(0.1) + + asyncio.set_event_loop(event_loop) sender = manager.build_sender() + manager.start_loop(__dummy_loop) # test not bind req_id = sender.send_async(RequestType.STOP_ENGINE, None) - manager.step() resp = sender.recv(req_id) assert resp.type == ResponseType.HANDLER_NOT_EXIST + assert manager.is_loop_alive() + # test bind success sender.send_async(RequestType.STOP_ENGINE, None) - manager.bind_func(RequestType.STOP_ENGINE, self._stop_engine_callback) - with pytest.raises(RuntimeError): - manager.step() + manager.bind_func(RequestType.STOP_ENGINE, __stop_engine_callback) + req_id = sender.send_async(RequestType.STOP_ENGINE, 'test') + resp = sender.recv(req_id) + assert resp.data == 'test success' From 5aaeb5c349154d4d8c1d2e0fa504319847aa04c9 Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Mon, 4 Mar 2024 12:20:09 +0800 Subject: [PATCH 04/17] fix returning logits in prefill phase of pytorch engine (#1209) * fix decode * fix * fix lint * end on start * fix --- lmdeploy/pytorch/engine/engine.py | 169 ++++++++++++++++-------------- lmdeploy/pytorch/messages.py | 12 ++- 2 files changed, 100 insertions(+), 81 deletions(-) diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index 34a6679458..31403d7d63 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -14,8 +14,7 @@ from ..adapter.adapter import ADAPTER_MANAGER, SchedulerAdapter from ..check_env import check_env, check_model from ..config import CacheConfig, SchedulerConfig -from ..messages import (MessageStatus, SamplingParam, SchedulerSequence, - SchedulerSession) +from ..messages import MessageStatus, SamplingParam, SchedulerSequence from ..paging import Scheduler from .logits_process import FusedLogitsProcessor, SamplingInputs from .model_agent import AutoModelAgent, ModelInputs @@ -157,11 +156,14 @@ class Engine: def __init__(self, model_path: str, - engine_config: PytorchEngineConfig, + engine_config: PytorchEngineConfig = None, trust_remote_code: bool = True) -> None: check_env() check_model(model_path) + if engine_config is None: + engine_config = PytorchEngineConfig() + self.engine_config = engine_config model_name = engine_config.model_name tp = engine_config.tp @@ -219,7 +221,7 @@ def __init__(self, @classmethod def from_pretrained(cls, pretrained_model_name_or_path: str, - engine_config: PytorchEngineConfig, + engine_config: PytorchEngineConfig = None, trust_remote_code: bool = True, **kwargs): """lmdeploy python inference engine. @@ -336,7 +338,9 @@ def __update_bad_words(msg): req.data['token_ids']) > 0, ('Empty input is not allowed.') sess.add_sequence(req.data['token_ids'], sampling_param=req.data['sampling_param'], - adapter_name=req.data['adapter_name']) + adapter_name=req.data['adapter_name'], + return_logits=req.data.get( + 'return_logits', False)) msg = next(iter(sess.sequences.values())) __update_bad_words(msg) self.scheduler.add_sequence(msg) @@ -345,6 +349,7 @@ def __update_bad_words(msg): msg.update_token_ids(req.data['token_ids']) msg.num_new_tokens = 0 msg.sampling_param = req.data['sampling_param'] + msg.return_logits = req.data.get('return_logits', False) msg.status = MessageStatus.WAITING __update_bad_words(msg) @@ -695,9 +700,6 @@ async def __long_context_forward(inputs): async def async_step(self, is_prefill: bool, return_logits: bool = False): """one step inference. Used to perform streaming chat. - Args: - return_logits (bool): Whether to return the output logits. - Returns: Dict[int, InferOutput]: The output of each session. """ @@ -721,7 +723,7 @@ async def async_step(self, is_prefill: bool, return_logits: bool = False): logits = output['logits'] logits = logits[0] # [bs, seq, prob] -> [seq, prob] - next_token_ids, split_logits = await self.async_sampling_logits( + next_token_ids, _ = await self.async_sampling_logits( logits, running, inputs) self.update_running(running, next_token_ids, custom_outputs) @@ -729,7 +731,8 @@ async def async_step(self, is_prefill: bool, return_logits: bool = False): # generate output outputs: Dict[int, InferOutput] = dict() - for msg, next_id in zip(running, next_token_ids): + for idx, msg in enumerate(running): + next_id = next_token_ids[idx] session_id = msg.session_id if self._can_output_token(next_id, msg): out_token_ids = [next_id.item()] @@ -744,9 +747,10 @@ async def async_step(self, is_prefill: bool, return_logits: bool = False): ) outputs[session_id] = out - if return_logits: - for msg, msg_logit in zip(running, split_logits): - outputs[msg.session_id].logits = msg_logit + if msg.return_logits: + start = inputs.q_start_loc[idx] + seqlen = inputs.seq_length[idx] + outputs[msg.session_id].logits = logits[start:start + seqlen] return outputs async def async_batched_infer(self, @@ -830,75 +834,77 @@ async def _add_messages(session_ids, token_ids): output_token_len = [len(token_ids) for token_ids in output_token_ids] return (status, output_token_ids, output_token_len) - def batched_infer(self, - session_ids: List[int], - token_ids: List[List[int]] = None, - gen_config: EngineGenerationConfig = None, - adapter_names: List[str] = None, - keep_cache: bool = False): - """Send inference request. + def decode(self, + input_ids, + steps: List[int] = None, + sequence_start: bool = True, + sequence_end: bool = True, + adapter_names: List[str] = None): + """Perform context decode on input tokens. Args: - session_ids (List[int]): The session id. - token_ids (List[int]): The input token ids. - gen_config (EngineGenerationConfig): The sampling parameters. + input_ids (numpy.ndarray): the batch of input token ids + steps (List[int]): the offset of the k/v cache + sequence_start (bool): indicator for starting a sequence + sequence_end (bool): indicator for ending a sequence adapter_names (List[str]): The name of the adapters. - keep_cache (bool): Keep kv cache after infer. - - Returns: - int: Error flags. 0 if success. - List[int]: The streaming output tokens. - int: The number of the output tokens. """ - coro = self.async_batched_infer(session_ids=session_ids, - token_ids=token_ids, - gen_config=gen_config, - adapter_names=adapter_names, - keep_cache=keep_cache) - return self.req_sender.run_until_complete(coro) + from torch.nn.utils.rnn import pad_sequence + logger.debug('Decoding logits.') + batch_size = len(input_ids) - def decode(self, prompt_token_ids: List[List[int]]): - """Perform one step inference and get logits. - - Args: - prompt_token_ids (List[List[int]]): Input prompts. + def __add_messages(session_ids, input_ids, adapter_names): + add_msgs = [] + sampling_param = SamplingParam(max_new_tokens=0) + for session_id, token_id, adapter_name in zip( + session_ids, input_ids, adapter_names): + msg = dict(token_ids=token_id, + session_id=session_id, + sampling_param=sampling_param, + adapter_name=adapter_name, + return_logits=True) + add_msgs.append(msg) + req_types = [RequestType.ADD_MESSAGE] * batch_size + req_ids = self.req_sender.batched_send_async(req_types, + data=add_msgs) + return req_ids - Returns: - List[Tensor]: The logits. - """ - assert not self.scheduler.has_unfinished() + if steps is not None: + assert batch_size == len(steps) - if len(self.scheduler.sessions) > 0: - logger.warning( - 'Unreleased session might leads to low performance.') + if adapter_names is None: + adapter_names = [None] * batch_size + assert batch_size == len(adapter_names) - session_id = 1 - sessions: List[SchedulerSession] = [] - while len(sessions) < len(prompt_token_ids): - while session_id in self.scheduler.sessions: - session_id += 1 - sess = SchedulerSession(session_id) - sessions.append(sess) - self.add_session(sess) + session_ids = tuple(range(batch_size)) + if sequence_start: + for sid in session_ids: + self.req_sender.send(RequestType.END_SESSION, + dict(session_id=sid)) + self.add_session(sid) - msgs: SeqList = [] - for token_ids, sess in zip(prompt_token_ids, sessions): - msg = sess.add_sequence(token_ids=token_ids) - msgs.append(msg) - self.scheduler.add_sequence(msg) + req_ids = __add_messages(session_ids, input_ids, adapter_names) + req_idx_map = dict(zip(req_ids, range(len(req_ids)))) - outputs = self.step(return_logits=True) + finish_count = batch_size + ret = [None] * batch_size + while finish_count > 0: + resp = self.req_sender.recv_any() + if resp.req_id not in req_ids: + continue - logits = dict((k, out.logits) for k, out in outputs.items()) + assert resp.type == ResponseType.FINISH + idx = req_idx_map[resp.req_id] + ret[idx] = resp.data['logits'] + finish_count -= 1 - for sess in sessions: - self.end_session(sess.session_id) + ret = pad_sequence(ret, True) - split_logits = [logits[sess.session_id] for sess in sessions] - pad_sequence = torch.nn.utils.rnn.pad_sequence - logits = pad_sequence(split_logits, True) + if sequence_end: + for sid in session_ids: + self.end_session(sid) - return logits + return ret async def async_loop(self): """Main loop of the engine. @@ -918,7 +924,7 @@ def _send_resp(step_tokens): type=resp_type, sender_id=out.sender_id, req_id=out.req_id, - data=dict(token_ids=out.token_ids), + data=dict(token_ids=out.token_ids, logits=out.logits), )) prefill_interval = self.scheduler_config.prefill_interval @@ -1177,14 +1183,23 @@ def cancel(self, session_id: int): """Stop current streaming inference.""" return cancel(self.req_sender, session_id) - def decode(self, prompt_token_ids: List[List[int]]): - """Return logits of context decoding. + def decode(self, + input_ids, + steps: List[int] = None, + sequence_start: bool = True, + sequence_end: bool = True, + adapter_names: List[str] = None): + """Perform context decode on input tokens. Args: - prompt_token_ids: token ids of a batch prompts. - - Returns: - logits (numpy.ndarray) with shape - [batch, n_max_token_of_the_batch, vocab_size] + input_ids (numpy.ndarray): the batch of input token ids + steps (List[int]): the offset of the k/v cache + sequence_start (bool): indicator for starting a sequence + sequence_end (bool): indicator for ending a sequence + adapter_names (List[str]): The name of the adapters. """ - return self.engine.decode(prompt_token_ids) + return self.engine.decode(input_ids, + steps=steps, + sequence_start=sequence_start, + sequence_end=sequence_end, + adapter_names=adapter_names) diff --git a/lmdeploy/pytorch/messages.py b/lmdeploy/pytorch/messages.py index 19efe95719..4bae383740 100644 --- a/lmdeploy/pytorch/messages.py +++ b/lmdeploy/pytorch/messages.py @@ -25,8 +25,8 @@ class SamplingParam: repetition_penalty: float = 1.0 ignore_eos: bool = False random_seed: int = None - stop_words: List[int] = None - bad_words: List[int] = None + stop_words: List[int] = field(default_factory=list) + bad_words: List[int] = field(default_factory=list) max_new_tokens: int = 512 min_new_tokens: int = 0 @@ -126,7 +126,8 @@ def __init__(self, session_id: int, block_size: int) -> None: def add_sequence(self, token_ids: Tensor, sampling_param: SamplingParam = None, - adapter_name: str = None) -> 'SchedulerSequence': + adapter_name: str = None, + return_logits: bool = False) -> 'SchedulerSequence': """Add a new message.""" if not isinstance(token_ids, Tensor): token_ids = torch.tensor(token_ids) @@ -143,7 +144,8 @@ def add_sequence(self, num_new_tokens=0, sampling_param=sampling_param, adapter_name=adapter_name, - arrive_time=time.time()) + arrive_time=time.time(), + return_logits=return_logits) self.sequences[seq.seq_id] = seq return seq @@ -174,6 +176,7 @@ def fork_sequence( adapter_name=seq.adapter_name, arrive_time=time.time(), meta=deepcopy(seq.meta), + return_logits=seq.return_logits, random_offsets=seq.random_offsets + 1) self.sequences[new_msg.seq_id] = new_msg @@ -198,6 +201,7 @@ class SchedulerSequence: adapter_name: str = None arrive_time: float = 0.0 meta: Any = None + return_logits: bool = False random_offsets: int = 0 @property From 003f3c547af31f3e10782fc1e63da4de7f3e8795 Mon Sep 17 00:00:00 2001 From: zhulinJulia24 <145004780+zhulinJulia24@users.noreply.github.com> Date: Mon, 4 Mar 2024 13:09:05 +0800 Subject: [PATCH 05/17] downgrade pytest to v8.0.2 to resolve PR workflow (#1236) --- requirements/test.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/test.txt b/requirements/test.txt index d06440a9d7..bcdf679df3 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -1,7 +1,7 @@ allure-pytest coverage pynvml -pytest +pytest==8.0.2 pytest-assume pytest-order pytest-rerunfailures From a6e81882779efb4ffe725967c7c09d03bc95e77e Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Mon, 4 Mar 2024 13:44:09 +0800 Subject: [PATCH 06/17] optimize pytorch engine inference with falcon model (#1234) * optimize falcon * checkout trust_remote_code --- lmdeploy/pytorch/check_env/__init__.py | 8 +- lmdeploy/pytorch/engine/cache_engine.py | 5 +- lmdeploy/pytorch/engine/engine.py | 2 +- lmdeploy/pytorch/engine/model_agent.py | 2 +- lmdeploy/pytorch/models/falcon.py | 295 +++++++----------------- lmdeploy/pytorch/models/gemma.py | 2 +- lmdeploy/pytorch/models/llama.py | 2 +- lmdeploy/pytorch/models/module_map.py | 2 - lmdeploy/pytorch/models/patch.py | 2 + 9 files changed, 102 insertions(+), 218 deletions(-) diff --git a/lmdeploy/pytorch/check_env/__init__.py b/lmdeploy/pytorch/check_env/__init__.py index 5d12bacb0a..62ba7f11f7 100644 --- a/lmdeploy/pytorch/check_env/__init__.py +++ b/lmdeploy/pytorch/check_env/__init__.py @@ -61,7 +61,8 @@ def check_env(): check_env_triton() -def check_transformers_version(model_path: str): +def check_transformers_version(model_path: str, + trust_remote_code: bool = True): """check transformers version.""" from packaging import version logger = get_logger('lmdeploy') @@ -77,7 +78,8 @@ def check_transformers_version(model_path: str): model_trans_version = None try: from transformers import AutoConfig - config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) + config = AutoConfig.from_pretrained( + model_path, trust_remote_code=trust_remote_code) model_trans_version = getattr(config, 'transformers_version') except Exception as e: message = ( @@ -97,7 +99,7 @@ def check_transformers_version(model_path: str): _handle_exception(e, 'transformers', logger, message=message) -def check_model(model_path: str): +def check_model(model_path: str, trust_remote_code: bool = True): """check model requirements.""" logger = get_logger('lmdeploy') logger.info('Checking model.') diff --git a/lmdeploy/pytorch/engine/cache_engine.py b/lmdeploy/pytorch/engine/cache_engine.py index 8df2cbdc27..20f516b2e0 100644 --- a/lmdeploy/pytorch/engine/cache_engine.py +++ b/lmdeploy/pytorch/engine/cache_engine.py @@ -201,7 +201,8 @@ def swap_out(self, src_to_dst: Dict[int, int]) -> None: @staticmethod def get_cache_block_size(block_size: int, - model_config: ModelConfig) -> int: + model_config: ModelConfig, + world_size: int = 1) -> int: """Get the required cache size of the model. Args: @@ -214,6 +215,8 @@ def get_cache_block_size(block_size: int, head_size = model_config.get_head_size() num_layers = model_config.num_layers num_heads = model_config.num_key_value_heads + if not model_config.multi_query_attention: + num_heads = num_heads // world_size key_cache_block = block_size * num_heads * head_size value_cache_block = key_cache_block diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index 31403d7d63..fe5bf7e0c2 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -159,7 +159,7 @@ def __init__(self, engine_config: PytorchEngineConfig = None, trust_remote_code: bool = True) -> None: check_env() - check_model(model_path) + check_model(model_path, trust_remote_code) if engine_config is None: engine_config = PytorchEngineConfig() diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index d947f7afe6..537d17c072 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -59,7 +59,7 @@ def _update_cache_config(model_config: ModelConfig, gpu_mem = gpu_mem_physical_free * cache_config.cache_max_entry_count cpu_mem = host_mem_size cache_block_size = CacheEngine.get_cache_block_size( - cache_config.block_size, model_config) // world_size + cache_config.block_size, model_config, world_size) if cache_config.num_cpu_blocks == 0: cache_config.num_cpu_blocks = int(cpu_mem / cache_block_size) if cache_config.num_gpu_blocks == 0: diff --git a/lmdeploy/pytorch/models/falcon.py b/lmdeploy/pytorch/models/falcon.py index 60f663e485..aa63fb2625 100644 --- a/lmdeploy/pytorch/models/falcon.py +++ b/lmdeploy/pytorch/models/falcon.py @@ -12,71 +12,11 @@ from torch.distributed._tensor import DeviceMesh from transformers.modeling_outputs import \ BaseModelOutputWithPastAndCrossAttentions -from transformers.models.falcon.modeling_falcon import build_alibi_tensor from ..dist_utils import (colwise_parallelize_linear_fn, rowwise_parallelize_linear_fn) -from ..kernels import (alibi_paged_attention_fwd, fill_kv_cache, - paged_attention_fwd) - - -# rotary pos emb helpers -# (torch.jit.script does not seem to support staticmethod...) -def rotate_half(x): - x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:] - return torch.cat((-x2, x1), dim=-1) - - -class PatchedFalconRotaryEmbedding(nn.Module): - """Implementation adapted from Huggingface transformers.""" - - def _patched_set_cos_sin_cache(self, seq_len, device, dtype): - self.seq_len_cached = seq_len - t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) - freqs = torch.einsum('i,j->ij', t, self.inv_freq) - emb = torch.cat((freqs, freqs), dim=-1).to(device) - - if dtype in [torch.float16, torch.bfloat16]: - emb = emb.float() - - self.cos_cached = emb.cos()[None, :, :] - self.sin_cached = emb.sin()[None, :, :] - - self.cos_cached = self.cos_cached.type(dtype) - self.sin_cached = self.sin_cached.type(dtype) - - def patched_cos_sin(self, - position_ids_1d: torch.Tensor, - device='cpu', - dtype=torch.bfloat16) -> torch.Tensor: - total_length = int(position_ids_1d.max().item()) + 1 - - if (self.seq_len_cached is None) or (total_length > - self.seq_len_cached): - self._patched_set_cos_sin_cache(total_length, device, dtype) - # position_ids.shape == [1, packed_seq_len] - # position_ids_1d.shape == [packed_seq_len] - return ( - self.cos_cached[:, position_ids_1d, None, :], - self.sin_cached[:, position_ids_1d, None, :], - ) - - def _contiguous_batching_forward(self, query: torch.Tensor, - key: torch.Tensor, - position_ids_1d: torch.Tensor): - # batch, seq_len, *_ = query.shape - cos, sin = self.patched_cos_sin(position_ids_1d, - device=query.device, - dtype=query.dtype) - return ( - (query * cos) + (rotate_half(query) * sin), - (key * cos) + (rotate_half(key) * sin), - ) - - def forward(self, query, key, position_ids_or_past_key_values_length=0): - """forward.""" - return self._contiguous_batching_forward( - query, key, position_ids_or_past_key_values_length) +from ..kernels import (alibi_paged_attention_fwd, apply_rotary_pos_emb, + fill_kv_cache, fused_rotary_emb, paged_attention_fwd) class PatchedFalconAttention(nn.Module): @@ -93,7 +33,7 @@ def _distribute_partition_fn(self, mod_name: str, mod: nn.Module, # e.g. 40b-instruct, GQA # split qkv across groups # no finer-grained partitioning - weight = mod.weight.reshape( + mod.weight.data = mod.weight.reshape( -1, # num groups (self.num_heads + self.num_kv_heads * 2) * self.head_dim, self.hidden_size, @@ -101,31 +41,22 @@ def _distribute_partition_fn(self, mod_name: str, mod: nn.Module, elif self.multi_query: # e.g. 7b-instruct, MQA # split to q, copy kv - weight = mod.weight.reshape( - -1, - self.head_dim, - self.hidden_size, - ) + weight = mod.weight.unflatten(0, (-1, self.head_dim)) q_weight = weight[:self.num_heads] - k_weight = weight[self.num_heads:self.num_heads + 1] - v_weight = weight[self.num_heads + 1:self.num_heads + 2] - q_weight_shards = torch.tensor_split(q_weight, - world_size, - dim=0) + kv_weight = weight[-2:] + q_weight_shards = q_weight.chunk(world_size, 0) weight_shards = [] for q in q_weight_shards: # only shard q heads but # copy single k/v head to all ranks weight_shards.append(q) - weight_shards.append(k_weight) - weight_shards.append(v_weight) + weight_shards.append(kv_weight) mod.weight.data = torch.cat(weight_shards, dim=0) # here we keep the weight to be 3D, # so that column parallel will split it # into integer-numbered heads # no bias for 7b-instruct and 40b-instruct - colwise_parallelize_linear_fn(mod, device_mesh=device_mesh, to_local=True) @@ -137,7 +68,7 @@ def _distribute_partition_fn(self, mod_name: str, mod: nn.Module, elif mod_name in ['dense']: if self.new_decoder_architecture: # e.g. 40b-instruct, GQA - weight = mod.weight.reshape( + mod.weight.data = mod.weight.reshape( self.hidden_size, -1, # num groups self.num_heads * self.head_dim, @@ -202,79 +133,80 @@ def _split_heads( 2, :] else: # e.g. 7b-instruct model - batch_size, seq_length, three_times_hidden_size = fused_qkv.shape - if not dist.is_initialized(): - num_head = self.num_heads - else: - # this trick will, for example, split 11 into [4, 4, 3] - # following the way column parallel linear splitting - # non-dividable dims - num_head = self.num_heads - dist.get_rank() - 1 - num_head = 1 + num_head // dist.get_world_size() - fused_qkv = fused_qkv.view(batch_size, seq_length, num_head + 2, - self.head_dim) - return fused_qkv[..., :-2, :], fused_qkv[..., [-2], :], fused_qkv[ - ..., [-1], :] + fused_qkv = fused_qkv.unflatten(-1, (-1, self.head_dim)) + split_shape = (fused_qkv.size(-2) - 2, 1, 1) + return fused_qkv.split(split_shape, dim=-2) def _contiguous_batching_forward( self, hidden_states: torch.Tensor, - position_ids: torch.Tensor, alibi: Optional[torch.Tensor], - attention_mask: torch.Tensor, layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - head_mask: Optional[torch.Tensor] = None, - use_cache: bool = False, output_attentions: bool = False, ): # prepare inputs for continuous batch forwarding context = self.context.context - - history_lengths = context.history_lengths q_start_loc = context.q_start_loc q_seq_length = context.q_seq_length - history_lengths = q_seq_length.new_tensor(history_lengths) - kv_seq_length = q_seq_length + history_lengths - max_seq_len = q_seq_length.max().item() - - fused_qkv = self.query_key_value( - hidden_states) # [batch_size, seq_length, 3 x hidden_size] + history_lengths = context.history_lengths + kv_seq_length = context.kv_seq_length + max_seq_length = context.max_seq_length + block_offsets = context.block_offsets + position_ids_1d = context.position_ids_1d + + def __maybe_rotary_fn(query_states, key_states, value_states): + scaling_factor = 1.0 + inv_freq = self.maybe_rotary.inv_freq + query_states, key_states = fused_rotary_emb( + query_states[None], + key_states[None], + position_ids_1d[None], + inv_freq=inv_freq, + scaling_factor=scaling_factor, + out_q=query_states[None], + out_k=key_states[None]) + return query_states[0], key_states[0], value_states + + def __rotary_emb_fn(query_states, key_states, value_states): + """rotary embedding func.""" + kv_seq_len = max_seq_length + max(history_lengths) + + cos, sin = self.rotary_emb(value_states.transpose(0, 1), + kv_seq_len) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, context.position_ids, + position_ids_1d) + return query_states, key_states, value_states + + fused_qkv = self.query_key_value(hidden_states) # 3 x [batch_size, seq_length, num_heads, head_dim] (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) - batch_size, query_length, _, _ = query_layer.shape - - if isinstance(self.maybe_rotary, nn.Module): - position_ids_1d = self.context.context.position_ids_1d - query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, - position_ids_1d) - - if layer_past is not None: - past_key, past_value = layer_past - # concatenate along seq_length dimension: - # - key: [batch_size * self.num_heads, kv_length, head_dim] - # - value: [batch_size * self.num_heads, kv_length, head_dim] - - fill_kv_cache(key_layer.contiguous(), - value_layer.contiguous(), - past_key, - past_value, - q_start_loc, - q_seq_length, - block_offsets=context.block_offsets, - history_lengths=context.history_lengths, - context=context) - - if use_cache: - present = (key_layer, value_layer) - else: - present = None - - attn_output = torch.empty_like(query_layer) - block_offsets = context.block_offsets - - if alibi is None: + query_layer = query_layer.flatten(0, 1) + key_layer = key_layer.flatten(0, 1) + value_layer = value_layer.flatten(0, 1) + if hasattr(self, 'maybe_rotary'): + query_layer, key_layer, value_layer = __maybe_rotary_fn( + query_layer, key_layer, value_layer) + elif hasattr(self, 'rotary_emb'): + query_layer, key_layer, value_layer = __rotary_emb_fn( + query_layer, key_layer, value_layer) + + past_key, past_value = layer_past + fill_kv_cache(key_layer.contiguous(), + value_layer.contiguous(), + past_key, + past_value, + q_start_loc, + q_seq_length, + block_offsets=block_offsets, + history_lengths=history_lengths, + context=context) + + attn_output = query_layer + + if not alibi: paged_attention_fwd(q=query_layer, k=past_key, v=past_value, @@ -283,9 +215,15 @@ def _contiguous_batching_forward( q_start_loc=q_start_loc, q_seqlens=q_seq_length, kv_seqlens=kv_seq_length, - max_seqlen=max_seq_len) + max_seqlen=max_seq_length) else: + num_heads_full = self.num_heads + head_offset = 0 + if dist.is_initialized(): + world_size = dist.get_world_size() + rank = dist.get_rank() + head_offset = self.num_heads // world_size * rank alibi_paged_attention_fwd(q=query_layer, k=past_key, v=past_value, @@ -294,17 +232,18 @@ def _contiguous_batching_forward( b_start_loc=q_start_loc, b_seq_len=q_seq_length, b_kv_seq_len=kv_seq_length, - max_input_len=max_seq_len, + max_input_len=max_seq_length, + head_offset=head_offset, + num_heads=num_heads_full, alibi_scale=self.inv_norm_factor) - attn_output = attn_output.reshape(batch_size, query_length, -1) - + attn_output = attn_output[None].flatten(-2, -1) output_tensor = self.dense(attn_output) if output_attentions: - return output_tensor, present, None + return output_tensor, layer_past, None else: - return output_tensor, present + return output_tensor, layer_past def forward( self, @@ -316,11 +255,8 @@ def forward( use_cache: bool = False, output_attentions: bool = False, ): - position_ids = self.context.context.position_ids - return self._contiguous_batching_forward(hidden_states, position_ids, - alibi, attention_mask, - layer_past, head_mask, - use_cache, output_attentions) + return self._contiguous_batching_forward(hidden_states, alibi, + layer_past) class PatchedFalconMLP(nn.Module): @@ -350,41 +286,17 @@ class PatchedFalconModel(nn.Module): def _contiguous_batching_forward( self, input_ids: Optional[torch.LongTensor] = None, - # position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, - attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: - # history_lengths = self.context.context.history_lengths - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions # noqa - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else - self.config.output_hidden_states) - use_cache = use_cache if use_cache is not None else self.config.use_cache # noqa - return_dict = return_dict if return_dict is not None else self.config.use_return_dict # noqa + output_attentions = False + use_cache = True use_alibi = getattr(self, 'use_alibi', getattr(self, 'alibi', False)) - if input_ids is not None and inputs_embeds is not None: - raise ValueError( - 'You cannot specify both input_ids and inputs_embeds at the same time' # noqa - ) - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError( - 'You have to specify either input_ids or inputs_embeds') - # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape batch_size x num_heads x N x N @@ -397,63 +309,35 @@ def _contiguous_batching_forward( hidden_states = inputs_embeds - # presents = () if use_cache else None - all_self_attentions = () if output_attentions else None - all_hidden_states = () if output_hidden_states else None - # Compute alibi tensor: check build_alibi_tensor documentation - if use_alibi: - alibi = build_alibi_tensor(attention_mask, - self.num_heads, - dtype=hidden_states.dtype) - else: - alibi = None for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states, ) outputs = block( hidden_states, - # position_ids=position_ids, layer_past=layer_past, attention_mask=None, head_mask=head_mask[i], use_cache=use_cache, output_attentions=output_attentions, - alibi=alibi, + alibi=use_alibi, ) hidden_states = outputs[0] - if output_attentions: - all_self_attentions = all_self_attentions + ( - outputs[2 if use_cache else 1], ) - # Add last hidden state - hidden_states = self.ln_f(hidden_states) - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states, ) - - if not return_dict: - return tuple(v for v in [ - hidden_states, past_key_values, all_hidden_states, - all_self_attentions - ] if v is not None) - return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, past_key_values=past_key_values, - hidden_states=all_hidden_states, - attentions=all_self_attentions, + hidden_states=None, + attentions=None, ) def forward( self, input_ids: Optional[torch.LongTensor] = None, - # position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, attention_mask: Optional[torch.Tensor] = None, @@ -466,12 +350,7 @@ def forward( ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: return self._contiguous_batching_forward( - input_ids=input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict) + input_ids=input_ids, past_key_values=past_key_values) class PatchedFalconForCausalLM(nn.Module): diff --git a/lmdeploy/pytorch/models/gemma.py b/lmdeploy/pytorch/models/gemma.py index b40f142279..67366a2cce 100644 --- a/lmdeploy/pytorch/models/gemma.py +++ b/lmdeploy/pytorch/models/gemma.py @@ -100,7 +100,7 @@ def __rotary_emb_fn(query_states, key_states, value_states): scaling_factor=scaling_factor, out_q=query_states[None], out_k=key_states[None]) - return query_states, key_states, value_states + return query_states[0], key_states[0], value_states query_states, key_states, value_states = __qkv_proj(hidden_states) diff --git a/lmdeploy/pytorch/models/llama.py b/lmdeploy/pytorch/models/llama.py index 79037818a9..20e01402cd 100644 --- a/lmdeploy/pytorch/models/llama.py +++ b/lmdeploy/pytorch/models/llama.py @@ -245,7 +245,7 @@ def __rotary_emb_fn_438_fused(query_states, key_states, value_states): scaling_factor=scaling_factor, out_q=query_states[None], out_k=key_states[None]) - return query_states, key_states, value_states + return query_states[0], key_states[0], value_states def __rotary_emb_fn_438(query_states, key_states, value_states): rotary_name = type(self.rotary_emb).__name__ diff --git a/lmdeploy/pytorch/models/module_map.py b/lmdeploy/pytorch/models/module_map.py index ec12db19ac..30f463dccd 100644 --- a/lmdeploy/pytorch/models/module_map.py +++ b/lmdeploy/pytorch/models/module_map.py @@ -30,8 +30,6 @@ f'{LMDEPLOY_PYTORCH_MODEL_PATH}.falcon.PatchedFalconAttention', 'modeling_falcon.FalconModel': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.falcon.PatchedFalconModel', - 'modeling_falcon.FalconRotaryEmbedding': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.falcon.PatchedFalconRotaryEmbedding', 'modeling_falcon.FalconMLP': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.falcon.PatchedFalconMLP', 'modeling_falcon.FalconForCausalLM': diff --git a/lmdeploy/pytorch/models/patch.py b/lmdeploy/pytorch/models/patch.py index fb7c66607e..b4e9573718 100644 --- a/lmdeploy/pytorch/models/patch.py +++ b/lmdeploy/pytorch/models/patch.py @@ -212,6 +212,8 @@ def _register_hooks(): lambda mod, inputs, outputs: output_fn(outputs, device_mesh)) for name, child in model.named_children(): + if rank == 0: + logger.debug(f'Distribute module: <{name}>') new_child = _dist_model(child, rank, device_mesh) if new_child != child: model.register_module(name, child) From 7dd97fdb5b28054e6dba7d84ead265b95e8b0dd6 Mon Sep 17 00:00:00 2001 From: RunningLeon Date: Mon, 4 Mar 2024 14:56:37 +0800 Subject: [PATCH 07/17] Auto backend for pipeline and serve when backend is not set to pytorch explicitly (#1211) * add draft * update to use cfg * fix * enable gemma arch * resolve comments * add is_supported in each backend * add ut * fix ut --- .github/workflows/unit-test.yml | 2 +- lmdeploy/api.py | 16 ++++- lmdeploy/archs.py | 82 +++++++++++++++++++++ lmdeploy/cli/cli.py | 3 +- lmdeploy/cli/serve.py | 24 +++++-- lmdeploy/pytorch/supported_models.py | 92 ++++++++++++++++++++++++ lmdeploy/turbomind/supported_models.py | 88 +++++++++++++++++++++++ tests/test_lmdeploy/test_auto_backend.py | 77 ++++++++++++++++++++ 8 files changed, 374 insertions(+), 10 deletions(-) create mode 100644 lmdeploy/archs.py create mode 100644 lmdeploy/pytorch/supported_models.py create mode 100644 lmdeploy/turbomind/supported_models.py create mode 100644 tests/test_lmdeploy/test_auto_backend.py diff --git a/.github/workflows/unit-test.yml b/.github/workflows/unit-test.yml index c4915e886d..590a019689 100644 --- a/.github/workflows/unit-test.yml +++ b/.github/workflows/unit-test.yml @@ -74,7 +74,7 @@ jobs: make -j$(nproc) && make install - name: Install lmdeploy run: | - python3 -m pip install pynvml packaging protobuf transformers_stream_generator transformers==4.33.0 + python3 -m pip install pynvml packaging protobuf transformers_stream_generator # manually install flash attn python3 -m pip install /root/packages/flash_attn-2.3.6+cu118torch2.1cxx11abiFALSE-cp38-cp38-linux_x86_64.whl python3 -m pip install -r requirements.txt -r requirements/test.txt diff --git a/lmdeploy/api.py b/lmdeploy/api.py index 34f2be9b72..78607bfd73 100644 --- a/lmdeploy/api.py +++ b/lmdeploy/api.py @@ -2,6 +2,7 @@ import os from typing import List, Literal, Optional, Union +from .archs import autoget_backend_config from .messages import PytorchEngineConfig, TurbomindEngineConfig from .model import ChatTemplateConfig @@ -31,7 +32,7 @@ def pipeline(model_path: str, model_name (str): needed when model_path is a pytorch model on huggingface.co, such as "internlm/internlm-chat-7b", "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat" and so on. - backend_config (TurbomindEngineConfig | PytorchEngineConfig): beckend + backend_config (TurbomindEngineConfig | PytorchEngineConfig): backend config instance. Default to None. chat_template_config (ChatTemplateConfig): chat template configuration. Default to None. @@ -49,8 +50,13 @@ def pipeline(model_path: str, from lmdeploy.utils import get_logger logger = get_logger('lmdeploy') logger.setLevel(log_level) + + if type(backend_config) is not PytorchEngineConfig: + # set auto backend mode + backend_config = autoget_backend_config(model_path, backend_config) backend = 'pytorch' if type( backend_config) is PytorchEngineConfig else 'turbomind' + logger.info(f'Using {backend} engine') if 'tp' in kwargs: logger.warning( 'The argument "tp" is deprecated and will be removed soon. ' @@ -101,7 +107,7 @@ def serve(model_path: str, "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat" and so on. backend (str): either `turbomind` or `pytorch` backend. Default to `turbomind` backend. - backend_config (TurbomindEngineConfig | PytorchEngineConfig): beckend + backend_config (TurbomindEngineConfig | PytorchEngineConfig): backend config instance. Default to none. chat_template_config (ChatTemplateConfig): chat template configuration. Default to None. @@ -126,6 +132,12 @@ def serve(model_path: str, from lmdeploy.serve.openai.api_client import APIClient from lmdeploy.serve.openai.api_server import serve + + if type(backend_config) is not PytorchEngineConfig: + # set auto backend mode + backend_config = autoget_backend_config(model_path, backend_config) + backend = 'pytorch' if type( + backend_config) is PytorchEngineConfig else 'turbomind' if 'tp' in kwargs: tp = kwargs['tp'] kwargs.pop('tp') diff --git a/lmdeploy/archs.py b/lmdeploy/archs.py new file mode 100644 index 0000000000..7a8b7f817e --- /dev/null +++ b/lmdeploy/archs.py @@ -0,0 +1,82 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Literal, Optional, Union + +from .messages import PytorchEngineConfig, TurbomindEngineConfig +from .utils import get_logger + +logger = get_logger('lmdeploy') + + +def autoget_backend(model_path: str) -> Union[Literal['turbomind', 'pytorch']]: + """Get backend type in auto backend mode. + + Args: + model_path (str): the path of a model. + It could be one of the following options: + - i) A local directory path of a turbomind model which is + converted by `lmdeploy convert` command or download from + ii) and iii). + - ii) The model_id of a lmdeploy-quantized model hosted + inside a model repo on huggingface.co, such as + "InternLM/internlm-chat-20b-4bit", + "lmdeploy/llama2-chat-70b-4bit", etc. + - iii) The model_id of a model hosted inside a model repo + on huggingface.co, such as "internlm/internlm-chat-7b", + "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat" + and so on. + + Returns: + str: the backend type. + """ + from lmdeploy.pytorch.supported_models import \ + is_supported as is_supported_pytorch + + pytorch_has, turbomind_has = False, False + try: + from lmdeploy.turbomind.supported_models import \ + is_supported as is_supported_turbomind + turbomind_has = is_supported_turbomind(model_path) + except ImportError: + logger.warning( + 'Lmdeploy with turbomind engine is not installed correctly. ' + 'You may need to install lmdeploy from pypi or build from source ' + 'for turbomind engine.') + + pytorch_has = is_supported_pytorch(model_path) + + if not (pytorch_has or turbomind_has): + logger.warning(f'{model_path} is not explicitly supported by lmdeploy.' + f' Try to run with lmdeploy pytorch engine.') + backend = 'turbomind' if turbomind_has else 'pytorch' + return backend + + +def autoget_backend_config( + model_path: str, + backend_config: Optional[Union[PytorchEngineConfig, + TurbomindEngineConfig]] = None +) -> Union[PytorchEngineConfig, TurbomindEngineConfig]: + """Get backend config automatically. + + Args: + model_path (str): The input model path. + backend_config (TurbomindEngineConfig | PytorchEngineConfig): The + input backend config. Default to None. + + Returns: + (PytorchEngineConfig | TurbomindEngineConfig): The auto-determined + backend engine config. + """ + from dataclasses import asdict + + backend = autoget_backend(model_path) + if backend == 'pytorch': + config = PytorchEngineConfig() + else: + config = TurbomindEngineConfig() + if backend_config is not None: + data = asdict(backend_config) + for k, v in data.items(): + if v and hasattr(config, k): + setattr(config, k, v) + return config diff --git a/lmdeploy/cli/cli.py b/lmdeploy/cli/cli.py index 4a4a73046c..e4aa405653 100644 --- a/lmdeploy/cli/cli.py +++ b/lmdeploy/cli/cli.py @@ -107,7 +107,8 @@ def list(args): if engine == 'pytorch': model_names = [ 'llama', 'llama2', 'internlm', 'internlm2', 'baichuan2', - 'chatglm2', 'falcon', 'yi', 'mistral', 'qwen1.5', 'gemma' + 'chatglm2', 'falcon', 'yi', 'mistral', 'mixtral', 'qwen1.5', + 'gemma', 'deepseek' ] elif engine == 'turbomind': from lmdeploy.model import MODELS diff --git a/lmdeploy/cli/serve.py b/lmdeploy/cli/serve.py index 858584ae7d..7b01afcb3e 100644 --- a/lmdeploy/cli/serve.py +++ b/lmdeploy/cli/serve.py @@ -189,10 +189,17 @@ def add_parser_triton_client(): @staticmethod def gradio(args): """Serve LLMs with web UI using gradio.""" + from lmdeploy.archs import autoget_backend + from lmdeploy.messages import (PytorchEngineConfig, + TurbomindEngineConfig) from lmdeploy.model import ChatTemplateConfig from lmdeploy.serve.gradio.app import run - if args.backend == 'pytorch': - from lmdeploy.messages import PytorchEngineConfig + backend = args.backend + + if backend != 'pytorch' and ':' not in args.model_path_or_server: + # set auto backend mode + backend = autoget_backend(args.model_path_or_server) + if backend == 'pytorch': backend_config = PytorchEngineConfig( tp=args.tp, model_name=args.model_name, @@ -200,7 +207,6 @@ def gradio(args): cache_max_entry_count=args.cache_max_entry_count, session_len=args.session_len) else: - from lmdeploy.messages import TurbomindEngineConfig backend_config = TurbomindEngineConfig( model_name=args.model_name, tp=args.tp, @@ -217,16 +223,22 @@ def gradio(args): run(args.model_path_or_server, server_name=args.server_name, server_port=args.server_port, - backend=args.backend, + backend=backend, backend_config=backend_config, chat_template_config=chat_template_config) @staticmethod def api_server(args): """Serve LLMs with restful api using fastapi.""" + from lmdeploy.archs import autoget_backend from lmdeploy.model import ChatTemplateConfig from lmdeploy.serve.openai.api_server import serve as run_api_server - if args.backend == 'pytorch': + backend = args.backend + if backend != 'pytorch': + # set auto backend mode + backend = autoget_backend(args.model_path) + + if backend == 'pytorch': from lmdeploy.messages import PytorchEngineConfig backend_config = PytorchEngineConfig( tp=args.tp, @@ -250,7 +262,7 @@ def api_server(args): meta_instruction=args.meta_instruction, capability=args.cap) run_api_server(args.model_path, - backend=args.backend, + backend=backend, backend_config=backend_config, chat_template_config=chat_template_config, server_name=args.server_name, diff --git a/lmdeploy/pytorch/supported_models.py b/lmdeploy/pytorch/supported_models.py new file mode 100644 index 0000000000..1941291573 --- /dev/null +++ b/lmdeploy/pytorch/supported_models.py @@ -0,0 +1,92 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from transformers import AutoConfig + +from lmdeploy.utils import get_logger + +logger = get_logger('lmdeploy') + +_SUPPORTED_ARCHS = dict( + # baichuan-7b + BaiChuanForCausalLM=False, + # baichuan2-7b, baichuan-13b, baichuan2-13b + BaichuanForCausalLM=True, + # chatglm2-6b, chatglm3-6b + ChatGLMModel=True, + # deepseek-moe + DeepseekForCausalLM=True, + # falcon-7b + FalconForCausalLM=True, + # gemma-7b + GemmaForCausalLM=True, + # internlm + InternLMForCausalLM=True, + # internlm2 + InternLM2ForCausalLM=True, + # internlm-xcomposer + InternLMXComposerForCausalLM=False, + # internlm2-xcomposer + InternLM2XComposerForCausalLM=False, + # llama, llama2, alpaca, vicuna, codellama, ultracm, yi, + # deepseek-coder, deepseek-llm + LlamaForCausalLM=True, + # Mistral-7B + MistralForCausalLM=True, + # Mixtral-8x7B + MixtralForCausalLM=True, + # Qwen 7B-72B, Qwen-VL-7B + QWenLMHeadModel=False, + # Qwen1.5 7B-72B + Qwen2ForCausalLM=True, +) + + +def is_supported(model_path: str): + """Check whether supported by pytorch engine. + + Args: + model_path (str): the path of a model. + It could be one of the following options: + - i) A local directory path of a turbomind model which is + converted by `lmdeploy convert` command or download from + ii) and iii). + - ii) The model_id of a lmdeploy-quantized model hosted + inside a model repo on huggingface.co, such as + "InternLM/internlm-chat-20b-4bit", + "lmdeploy/llama2-chat-70b-4bit", etc. + - iii) The model_id of a model hosted inside a model repo + on huggingface.co, such as "internlm/internlm-chat-7b", + "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat" + and so on. + Returns: + support_by_torch (bool): Whether input model is supported by pytorch engine + """ # noqa: E501 + import os + + support_by_torch = False + + triton_model_path = os.path.join(model_path, 'triton_models') + if os.path.exists(triton_model_path): + logger.warning(f'{model_path} seems to be a turbomind workspace, ' + 'which can only be ran with turbomind engine.') + else: + cfg = AutoConfig.from_pretrained(model_path, trust_remote_code=True) + + if hasattr(cfg, 'architectures'): + arch = cfg.architectures[0] + elif hasattr(cfg, + 'auto_map') and 'AutoModelForCausalLM' in cfg.auto_map: + arch = cfg.auto_map['AutoModelForCausalLM'].split('.')[-1] + else: + raise RuntimeError( + f'Could not find model architecture from config: {cfg}') + + if arch in _SUPPORTED_ARCHS: + support_by_torch = _SUPPORTED_ARCHS[arch] + # special cases + if arch == 'BaichuanForCausalLM': + # baichuan-13B not supported by pytorch + if cfg.num_attention_heads == 40 and cfg.vocab_size == 64000: + support_by_torch = False + + return support_by_torch diff --git a/lmdeploy/turbomind/supported_models.py b/lmdeploy/turbomind/supported_models.py new file mode 100644 index 0000000000..99dd4400c4 --- /dev/null +++ b/lmdeploy/turbomind/supported_models.py @@ -0,0 +1,88 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from transformers import AutoConfig + +from lmdeploy.utils import get_logger + +logger = get_logger('lmdeploy') + +_SUPPORTED_ARCHS = dict( + # baichuan-7b + BaiChuanForCausalLM=True, + # baichuan2-7b, baichuan-13b, baichuan2-13b + BaichuanForCausalLM=True, + # chatglm2-6b, chatglm3-6b + ChatGLMModel=False, + # deepseek-moe + DeepseekForCausalLM=False, + # falcon-7b + FalconForCausalLM=False, + # gemma-7b + GemmaForCausalLM=False, + # internlm + InternLMForCausalLM=True, + # internlm2 + InternLM2ForCausalLM=True, + # internlm-xcomposer + InternLMXComposerForCausalLM=True, + # internlm2-xcomposer + InternLM2XComposerForCausalLM=False, + # llama, llama2, alpaca, vicuna, codellama, ultracm, yi, + # deepseek-coder, deepseek-llm + LlamaForCausalLM=True, + # Mistral-7B + MistralForCausalLM=False, + # Mixtral-8x7B + MixtralForCausalLM=False, + # Qwen 7B-72B, Qwen-VL-7B + QWenLMHeadModel=True, + # Qwen1.5 7B-72B + Qwen2ForCausalLM=False) + + +def is_supported(model_path: str): + """Check whether supported by turbomind engine. + + Args: + model_path (str): the path of a model. + It could be one of the following options: + - i) A local directory path of a turbomind model which is + converted by `lmdeploy convert` command or download from + ii) and iii). + - ii) The model_id of a lmdeploy-quantized model hosted + inside a model repo on huggingface.co, such as + "InternLM/internlm-chat-20b-4bit", + "lmdeploy/llama2-chat-70b-4bit", etc. + - iii) The model_id of a model hosted inside a model repo + on huggingface.co, such as "internlm/internlm-chat-7b", + "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat" + and so on. + Returns: + support_by_turbomind (bool): Whether input model is supported by turbomind engine + """ # noqa: E501 + import os + + support_by_turbomind = False + triton_model_path = os.path.join(model_path, 'triton_models') + if os.path.exists(triton_model_path): + support_by_turbomind = True + else: + cfg = AutoConfig.from_pretrained(model_path, trust_remote_code=True) + + if hasattr(cfg, 'architectures'): + arch = cfg.architectures[0] + elif hasattr(cfg, + 'auto_map') and 'AutoModelForCausalLM' in cfg.auto_map: + arch = cfg.auto_map['AutoModelForCausalLM'].split('.')[-1] + else: + raise RuntimeError( + f'Could not find model architecture from config: {cfg}') + + if arch in _SUPPORTED_ARCHS: + support_by_turbomind = _SUPPORTED_ARCHS[arch] + # special cases + if arch == 'BaichuanForCausalLM': + num_attn_head = cfg.num_attention_heads + if num_attn_head == 40: + # baichuan-13B, baichuan2-13B not supported by turbomind + support_by_turbomind = False + return support_by_turbomind diff --git a/tests/test_lmdeploy/test_auto_backend.py b/tests/test_lmdeploy/test_auto_backend.py new file mode 100644 index 0000000000..b96ca4e3cb --- /dev/null +++ b/tests/test_lmdeploy/test_auto_backend.py @@ -0,0 +1,77 @@ +import os +import tempfile + +import numpy as np +import pytest + + +class TestAutoBackend: + + @pytest.fixture + def turbomind_workspace(self): + workspace = tempfile.TemporaryDirectory( + 'internlm-chat-7b-turbomind').name + os.makedirs(os.path.join(workspace, 'triton_models'), exist_ok=True) + return workspace + + @pytest.fixture + def models(self): + # example models to test + # format (model_path, is_pytorch_supported, is_turbomind_supported) + models = [ + ('baichuan-inc/Baichuan-7B', False, True), + ('baichuan-inc/Baichuan2-7B-Chat', True, True), + ('baichuan-inc/Baichuan-13B-Chat', False, False), + ('baichuan-inc/Baichuan2-13B-Chat', True, False), + ('internlm/internlm-chat-7b', True, True), + ('internlm/internlm2-chat-7b', True, True), + ('internlm/internlm-xcomposer2-7b', False, False), + ('internlm/internlm-xcomposer-7b', False, True), + ('THUDM/chatglm2-6b', True, False), + ('THUDM/chatglm3-6b', True, False), + ('deepseek-ai/deepseek-moe-16b-chat', True, False), + ('tiiuae/falcon-7b-instruct', True, False), + ('01-ai/Yi-34B-Chat', True, True), + ('codellama/CodeLlama-7b-Instruct-hf', True, True), + ('mistralai/Mistral-7B-Instruct-v0.1', True, False), + ('mistralai/Mixtral-8x7B-Instruct-v0.1', True, False), + ('Qwen/Qwen-7B-Chat', False, True), + ('Qwen/Qwen-VL-Chat', False, True), + ('Qwen/Qwen1.5-4B-Chat', True, False), + ] + return models + + def test_pytorch_is_suppored(self, turbomind_workspace, models): + from lmdeploy.pytorch.supported_models import is_supported + assert is_supported(turbomind_workspace) is False + for m, flag, _ in models: + assert is_supported(m) is flag + + def test_turbomind_is_suppored(self, turbomind_workspace, models): + from lmdeploy.turbomind.supported_models import is_supported + assert is_supported(turbomind_workspace) is True + for m, _, flag in models: + assert is_supported(m) is flag + + def test_autoget_backend(self, turbomind_workspace, models): + from lmdeploy.archs import autoget_backend + assert autoget_backend(turbomind_workspace) == 'turbomind' + n = len(models) + choices = np.random.choice(n, n // 2, replace=False) + for i in choices: + model, is_support_pytorch, is_support_turbomind = models[i] + target = 'turbomind' if is_support_turbomind else 'pytorch' + backend = autoget_backend(model) + assert backend == target + + def test_autoget_backend_config(self, turbomind_workspace): + from lmdeploy.archs import autoget_backend_config + from lmdeploy.messages import (PytorchEngineConfig, + TurbomindEngineConfig) + assert type(autoget_backend_config( + turbomind_workspace)) is TurbomindEngineConfig + assert type(autoget_backend_config( + 'internlm/internlm-chat-7b')) is TurbomindEngineConfig + assert type( + autoget_backend_config( + 'mistralai/Mistral-7B-Instruct-v0.1')) is PytorchEngineConfig From dd44e7f5357e355276d26b1ec2898e057faf7d82 Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Tue, 5 Mar 2024 09:55:16 +0800 Subject: [PATCH 08/17] fix multinomial sampling (#1239) Co-authored-by: grimoire --- lmdeploy/pytorch/kernels/multinomial_sampling.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/lmdeploy/pytorch/kernels/multinomial_sampling.py b/lmdeploy/pytorch/kernels/multinomial_sampling.py index 476787fb8c..44ed759052 100644 --- a/lmdeploy/pytorch/kernels/multinomial_sampling.py +++ b/lmdeploy/pytorch/kernels/multinomial_sampling.py @@ -21,8 +21,8 @@ def _multinomial_sampling_kernel(Scores, Seeds, Offsets, Indices, Outputs, offset = tl.load(Offsets + off, mask=off_mask).to(tl.int32) samp = tl.rand(seed, offset)[:, None] - acc = tl.zeros((BLOCK, ), dtype=Scores.dtype.element_ty) - output = tl.full((BLOCK, ), -1, dtype=Outputs.dtype.element_ty) + acc = tl.zeros((BLOCK, ), dtype=tl.float32) + output = tl.load(Indices + off * stride_ib, mask=off_mask) for b_idx in range(0, num_tokens, BLOCK_N): s_off = b_idx + n_off @@ -30,9 +30,9 @@ def _multinomial_sampling_kernel(Scores, Seeds, Offsets, Indices, Outputs, scores = tl.load(Scores + off[:, None] * stride_sb + s_off[None, :] * stride_st, mask=s_mask, - other=0.0) - cum_scores = acc[:, None] + tl.cumsum(scores, 1).to(acc.dtype) - acc += tl.sum(scores, 1).to(acc.dtype) + other=0.0).to(acc.dtype) + cum_scores = acc[:, None] + tl.cumsum(scores, 1) + acc += tl.sum(scores, 1) pre_cum_scores = cum_scores - scores valid_mask = (samp > pre_cum_scores) & (samp <= cum_scores) @@ -75,7 +75,7 @@ def __kernel_meta(): assert indices.dim() == 2 assert indices.size() == scores.size() - outputs = indices.new_empty(batch_size) + outputs = indices[:, 0].clone() BLOCK = 32 BLOCK_N = 64 @@ -96,5 +96,5 @@ def __kernel_meta(): BLOCK=BLOCK, BLOCK_N=BLOCK_N, **kernel_meta) - torch.cuda.synchronize() + return outputs From d0a2dab67b74435837a45d4a25f54ad177102b5f Mon Sep 17 00:00:00 2001 From: zhyncs Date: Tue, 5 Mar 2024 11:10:54 +0800 Subject: [PATCH 09/17] update doc index (#1241) --- docs/en/index.rst | 1 + docs/zh_cn/index.rst | 1 + 2 files changed, 2 insertions(+) diff --git a/docs/en/index.rst b/docs/en/index.rst index 0ba7b86346..daccaae535 100644 --- a/docs/en/index.rst +++ b/docs/en/index.rst @@ -67,6 +67,7 @@ Welcome to LMDeploy's tutorials! advance/pytorch_new_model.md advance/long_context.md + advance/debug_turbomind.md serving/qos.md .. toctree:: diff --git a/docs/zh_cn/index.rst b/docs/zh_cn/index.rst index 3f8272f01b..e89bfa661b 100644 --- a/docs/zh_cn/index.rst +++ b/docs/zh_cn/index.rst @@ -69,6 +69,7 @@ advance/pytorch_new_model.md advance/long_context.md + advance/debug_turbomind.md serving/qos.md .. toctree:: From c5b0a3143a807f28e8a19321354a2fb94cdf32cc Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Tue, 5 Mar 2024 11:36:29 +0800 Subject: [PATCH 10/17] remove unused kernel in pytorch engine (#1237) --- lmdeploy/pytorch/engine/engine.py | 14 + lmdeploy/pytorch/kernels/__init__.py | 5 +- .../pytorch/kernels/biased_pagedattention.py | 240 ------------------ .../pytorch/kernels/flashattention_nopad.py | 199 --------------- tests/pytorch/kernel/test_paged_attention.py | 30 --- 5 files changed, 15 insertions(+), 473 deletions(-) delete mode 100644 lmdeploy/pytorch/kernels/biased_pagedattention.py delete mode 100644 lmdeploy/pytorch/kernels/flashattention_nopad.py diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index fe5bf7e0c2..c59bfed29b 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -834,6 +834,20 @@ async def _add_messages(session_ids, token_ids): output_token_len = [len(token_ids) for token_ids in output_token_ids] return (status, output_token_ids, output_token_len) + def batched_infer(self, + session_ids: List[int], + token_ids: List[List[int]] = None, + gen_config: EngineGenerationConfig = None, + adapter_names: List[str] = None, + keep_cache: bool = False): + """batched infer.""" + coro = self.async_batched_infer(session_ids, + token_ids, + gen_config=gen_config, + adapter_names=adapter_names, + keep_cache=keep_cache) + return self.req_sender.run_until_complete(coro) + def decode(self, input_ids, steps: List[int] = None, diff --git a/lmdeploy/pytorch/kernels/__init__.py b/lmdeploy/pytorch/kernels/__init__.py index a1e2ead436..31d2e4c0d8 100644 --- a/lmdeploy/pytorch/kernels/__init__.py +++ b/lmdeploy/pytorch/kernels/__init__.py @@ -1,9 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from .alibi_pagedattention import alibi_paged_attention_fwd from .apply_rotary_pos_emb import apply_rotary_pos_emb -from .biased_pagedattention import biased_paged_attention_fwd from .fill_kv_cache import fill_kv_cache -from .flashattention_nopad import context_attention_fwd from .fused_rotary_emb import fused_rotary_emb from .multinomial_sampling import multinomial_sampling from .pagedattention import paged_attention_fwd @@ -11,8 +9,7 @@ from .rms_norm import rms_norm __all__ = [ - 'apply_rotary_pos_emb', 'context_attention_fwd', 'fused_rotary_emb', - 'paged_attention_fwd', 'biased_paged_attention_fwd', + 'apply_rotary_pos_emb', 'fused_rotary_emb', 'paged_attention_fwd', 'alibi_paged_attention_fwd', 'fill_kv_cache', 'multinomial_sampling', 'rms_norm', 'rerope_attention_fwd' ] diff --git a/lmdeploy/pytorch/kernels/biased_pagedattention.py b/lmdeploy/pytorch/kernels/biased_pagedattention.py deleted file mode 100644 index 1270c17e7b..0000000000 --- a/lmdeploy/pytorch/kernels/biased_pagedattention.py +++ /dev/null @@ -1,240 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -# modify from: https://github.com/ModelTC/lightllm -import torch -import triton -import triton.language as tl -from torch import Tensor - -assert triton.__version__ >= '2.1.0' - -_NV_CAP = torch.cuda.get_device_capability() -if _NV_CAP[0] >= 8: - - @triton.jit - def _convert_pv(p, v): - """convert pv.""" - p = p.to(v.dtype) - return p, v -else: - - @triton.jit - def _convert_pv(p, v): - """convert pv.""" - v = v.to(p.dtype) - return p, v - - -@triton.jit -def _fwd_kernel( - Q, - K, - V, - Bias, - sm_scale, - B_Start_Loc, - B_Seqlen, - B_kvlen, - Block_offsets, - Out, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_biasbs, - stride_biash, - stride_biasq, - stride_biask, - stride_obs, - stride_oh, - stride_od, - stride_boffb, - kv_group_num, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, -): - """biased paged attention kernel.""" - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - start_m = tl.program_id(2) - - cur_kv_head = cur_head // kv_group_num - - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_kv_len = tl.load(B_kvlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - - block_start_loc = BLOCK_M * start_m - - # initialize offsets - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + - cur_head * stride_qh + offs_d[None, :] * stride_qd) - off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + - offs_d[:, None] * stride_kd) - off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + - offs_d[None, :] * stride_vd) - off_bias = (cur_batch * stride_biasbs + cur_head * stride_biash + - offs_m[:, None] * stride_biasq + - offs_n[None, :] * stride_biask) - - q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) - - k_ptrs = K + off_k - v_ptrs = V + off_v - bias_ptrs = Bias + off_bias - - block_offset_ptrs = Block_offsets + cur_batch * stride_boffb - - # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf') - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - - block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) - - for start_n in range(0, block_mask * cur_batch_kv_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - - start_block_id = start_n // BLOCK_N - b_offset = tl.load(block_offset_ptrs + start_block_id) - - # -- compute qk ---- - k = tl.load( - k_ptrs + b_offset * BLOCK_N * stride_kbs, - mask=(start_n + offs_n[None, :]) < cur_batch_kv_len, - other=0.0, - ) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - qk *= sm_scale - - bias = tl.load( - bias_ptrs + start_n, - mask=(start_n + offs_n[None, :]) < cur_batch_kv_len - and (offs_m[:, None] < cur_batch_seq_len), - other=-1e30, - ) - qk += bias - - # -- compute p, m_i and l_i - m_i_new = tl.maximum(m_i, tl.max(qk, 1)) - p = tl.exp(qk - m_i_new[:, None]) - alpha = tl.exp(m_i - m_i_new) - l_i_new = alpha * l_i + tl.sum(p, 1) - # -- update output accumulator -- - # scale acc - acc = acc * alpha[:, None] - # update acc - v = tl.load( - v_ptrs + b_offset * BLOCK_N * stride_vbs, - mask=(start_n + offs_n[:, None]) < cur_batch_kv_len, - other=0.0, - ) - - p, v = _convert_pv(p, v) - acc += tl.dot(p, v) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - - acc = acc / l_i[:, None] - # initialize pointers to output - off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + - cur_head * stride_oh + offs_d[None, :] * stride_od) - out_ptrs = Out + off_o - tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) - - -@torch.no_grad() -def biased_paged_attention_fwd( - q: Tensor, - k: Tensor, - v: Tensor, - bias: Tensor, - o: Tensor, - block_offsets: Tensor, - b_start_loc: Tensor, - b_seq_len: Tensor, - b_kv_seq_len: Tensor, - max_input_len: int, - BLOCK: int = 64, -): - """Paged attention forward with custom bias. - - Args: - q (Tensor): Query state. - k (Tensor): Key state caches. - v (Tensor): Value state caches. - bias (Tensor): Bias of the QK. - o (Tensor): Output state. - block_offsets (Tensor): The block offset of key and value. - b_start_loc (Tensor): Start token location of each data in batch. - b_seq_len (Tensor): Query length for each data in batch. - b_kv_seq_len (Tensor): Key/Value length for each data in batch. - max_input_len (int): The max input length. - BLOCK (int): The kernel block size. - """ - # shape constraints - Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] - assert Lq == Lk and Lk == Lv - assert Lk in {16, 32, 64, 128} - assert bias.dtype == torch.float32 - - if bias.dim() == 2: - bias = bias.unsqueeze(0) - - if bias.dim() == 3: - bias = bias.unsqueeze(1) - - sm_scale = 1.0 / (Lq**0.5) # 计算scale系数 - batch, head = b_seq_len.shape[0], q.shape[-2] - kv_group_num = q.shape[-2] // k[0].shape[-2] - - grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head, - bias_head_stride = 0 if bias.size(1) == 1 else bias.stride(-3) - - num_warps = 4 if Lk <= 64 else 8 - _fwd_kernel[grid]( - q, - k, - v, - bias, - sm_scale, - b_start_loc, - b_seq_len, - b_kv_seq_len, - block_offsets, - o, - q.stride(-3), - q.stride(-2), - q.stride(-1), - k.stride(-3), - k.stride(-2), - k.stride(-1), - v.stride(-3), - v.stride(-2), - v.stride(-1), - bias.stride(-4), - bias_head_stride, - bias.stride(-2), - bias.stride(-1), - o.stride(-3), - o.stride(-2), - o.stride(-1), - block_offsets.stride(0), - kv_group_num=kv_group_num, - BLOCK_M=BLOCK, - BLOCK_DMODEL=Lk, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) diff --git a/lmdeploy/pytorch/kernels/flashattention_nopad.py b/lmdeploy/pytorch/kernels/flashattention_nopad.py deleted file mode 100644 index 567a744968..0000000000 --- a/lmdeploy/pytorch/kernels/flashattention_nopad.py +++ /dev/null @@ -1,199 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -# modify from: https://github.com/ModelTC/lightllm -import torch -import triton -import triton.language as tl -from torch import Tensor - -assert triton.__version__ >= '2.1.0' - - -@triton.jit -def _fwd_kernel( - Q, - K, - V, - sm_scale, - B_Start_Loc, - B_Seqlen, - B_kv_start_loc, - B_kvlen, - Out, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - kv_group_num, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, -): - """flash attention forward triton kernel.""" - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - start_m = tl.program_id(2) - - cur_kv_head = cur_head // kv_group_num - - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_kv_len = tl.load(B_kvlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - cur_batch_in_kv_start_index = tl.load(B_kv_start_loc + cur_batch) - history_len = cur_batch_kv_len - cur_batch_seq_len - - block_start_loc = BLOCK_M * start_m - - # initialize offsets - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + - cur_head * stride_qh + offs_d[None, :] * stride_qd) - off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + - offs_d[:, None] * stride_kd) - off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + - offs_d[None, :] * stride_vd) - - q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) - - k_ptrs = K + off_k - v_ptrs = V + off_v - - # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf') - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - - block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) - - for start_n in range(0, block_mask * cur_batch_kv_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load( - k_ptrs + (cur_batch_in_kv_start_index + start_n) * stride_kbs, - mask=(start_n + offs_n[None, :]) < cur_batch_kv_len, - other=0.0, - ) - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - qk *= sm_scale - # NOTE: inf - inf = nan, and nan will leads to error - qk = tl.where( - (history_len + offs_m[:, None]) >= (start_n + offs_n[None, :]), - qk, - float(-1e30), - ) - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - p = tl.exp(qk - m_ij[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - m_i_new = tl.maximum(m_i, m_ij) - alpha = tl.exp(m_i - m_i_new) - beta = tl.exp(m_ij - m_i_new) - l_i_new = alpha * l_i + beta * l_ij - # -- update output accumulator -- - # scale p - p_scale = beta / l_i_new - p = p * p_scale[:, None] - # scale acc - acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v = tl.load( - v_ptrs + (cur_batch_in_kv_start_index + start_n) * stride_vbs, - mask=(start_n + offs_n[:, None]) < cur_batch_kv_len, - other=0.0, - ) - - p = p.to(v.dtype) - acc += tl.dot(p, v) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - # initialize pointers to output - off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + - cur_head * stride_oh + offs_d[None, :] * stride_od) - out_ptrs = Out + off_o - tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) - - -@torch.no_grad() -def context_attention_fwd( - q: Tensor, - k: Tensor, - v: Tensor, - o: Tensor, - b_start_loc: Tensor, - b_seq_len: Tensor, - b_kv_start_loc: Tensor, - b_kv_seq_len: Tensor, - max_input_len: int, - BLOCK: int = 64, -): - """Context Attention forward. - - Args: - q (Tensor): Query state. - k (Tensor): Key state caches. - v (Tensor): Value state caches. - o (Tensor): Output state. - b_start_loc (Tensor): Start token location of each data in batch. - b_seq_len (Tensor): Query length for each data in batch. - b_kv_start_loc (Tensor): Start token location of kv in each data - in batch. - b_kv_seq_len (Tensor): Key/Value length for each data in batch. - max_input_len (int): The max input length. - BLOCK (int): The kernel block size. - """ - # shape constraints - Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] - assert Lq == Lk and Lk == Lv - assert Lk in {16, 32, 64, 128} - - sm_scale = 1.0 / (Lq**0.5) # 计算scale系数 - batch, head = b_seq_len.shape[0], q.shape[1] - kv_group_num = q.shape[1] // k.shape[1] - - grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head, - - num_warps = 4 if Lk <= 64 else 8 - _fwd_kernel[grid]( - q, - k, - v, - sm_scale, - b_start_loc, - b_seq_len, - b_kv_start_loc, - b_kv_seq_len, - o, - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - v.stride(0), - v.stride(1), - v.stride(2), - o.stride(0), - o.stride(1), - o.stride(2), - kv_group_num=kv_group_num, - BLOCK_M=BLOCK, - BLOCK_DMODEL=Lk, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - return diff --git a/tests/pytorch/kernel/test_paged_attention.py b/tests/pytorch/kernel/test_paged_attention.py index 55175aee5f..3cd208f052 100644 --- a/tests/pytorch/kernel/test_paged_attention.py +++ b/tests/pytorch/kernel/test_paged_attention.py @@ -248,36 +248,6 @@ def test_paged_attention(self, conti_q, blocked_kv, block_offsets, max_seqlen=max_seq_len) torch.testing.assert_close(out, conti_gt, atol=1e-3, rtol=1e-5) - @pytest.mark.parametrize(['num_heads_q', 'num_heads_k'], [(4, 2)], - indirect=True) - @pytest.mark.parametrize(['seq_lens', 'history_lens'], - [([30, 50, 70, 90], [50, 40, 30, 20])], - indirect=True) - @pytest.mark.parametrize('block_size', [16], indirect=True) - def test_biased_paged_attention(self, conti_q, blocked_kv, block_offsets, - start_loc, seq_lens, history_lens, - block_size, mask, conti_gt): - from lmdeploy.pytorch.kernels import biased_paged_attention_fwd - kv_seq_lens = seq_lens + history_lens - max_seq_len = seq_lens.max().item() - - blocked_k, blocked_v = blocked_kv - out = torch.empty_like(conti_q) - - biased_paged_attention_fwd(conti_q, - blocked_k, - blocked_v, - mask, - out, - block_offsets=block_offsets, - b_start_loc=start_loc, - b_seq_len=seq_lens, - b_kv_seq_len=kv_seq_lens, - max_input_len=max_seq_len, - BLOCK=block_size) - - torch.testing.assert_close(out, conti_gt) - @pytest.fixture def win_size(self, request): yield request.param From 4bec832028bba10f2216137b251be56b905728d5 Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Tue, 5 Mar 2024 11:56:57 +0800 Subject: [PATCH 11/17] reduce torchengine prefill mem usage (#1240) * reduce mem usage * remove pdb * del to pop --- lmdeploy/pytorch/engine/engine.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index c59bfed29b..e4b242fc7f 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -673,11 +673,13 @@ async def __long_context_forward(inputs): if token_count == 0 and slen > max_prefill_token_num: tmp_out = await __long_context_single_forward(inputs, idx) logits_gather.gather(tmp_out) + tmp_out.pop('logits', None) idx += 1 elif token_count + slen > max_prefill_token_num: tmp_out = await __long_context_batched_forward( inputs, indices[0], idx) logits_gather.gather(tmp_out) + tmp_out.pop('logits', None) indices = [] token_count = 0 else: From c5f4014842fc817a3bdf3a57cee6d76c089fe629 Mon Sep 17 00:00:00 2001 From: Lyu Han Date: Tue, 5 Mar 2024 16:38:51 +0800 Subject: [PATCH 12/17] bump version to v0.2.5 (#1235) --- lmdeploy/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lmdeploy/version.py b/lmdeploy/version.py index d45eba82e4..277e7502ec 100644 --- a/lmdeploy/version.py +++ b/lmdeploy/version.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from typing import Tuple -__version__ = '0.2.4' +__version__ = '0.2.5' short_version = __version__ From 0d027a5806c71c04ccb860bb53953554fbcb5307 Mon Sep 17 00:00:00 2001 From: RunningLeon Date: Tue, 5 Mar 2024 20:12:25 +0800 Subject: [PATCH 13/17] fix config for readthedocs (#1245) * add config * update * move to different docs * update --- .readthedocs.yaml => docs/en/.readthedocs.yaml | 5 +++++ docs/zh_cn/.readthedocs.yaml | 18 ++++++++++++++++++ 2 files changed, 23 insertions(+) rename .readthedocs.yaml => docs/en/.readthedocs.yaml (81%) create mode 100644 docs/zh_cn/.readthedocs.yaml diff --git a/.readthedocs.yaml b/docs/en/.readthedocs.yaml similarity index 81% rename from .readthedocs.yaml rename to docs/en/.readthedocs.yaml index 05ec15cca3..525ef5f7a3 100644 --- a/.readthedocs.yaml +++ b/docs/en/.readthedocs.yaml @@ -7,6 +7,11 @@ build: tools: python: "3.8" + +sphinx: + configuration: docs/en/conf.py + + python: install: - requirements: requirements/docs.txt diff --git a/docs/zh_cn/.readthedocs.yaml b/docs/zh_cn/.readthedocs.yaml new file mode 100644 index 0000000000..9f94662947 --- /dev/null +++ b/docs/zh_cn/.readthedocs.yaml @@ -0,0 +1,18 @@ +version: 2 + +formats: all + +build: + os: "ubuntu-22.04" + tools: + python: "3.8" + + +sphinx: + configuration: docs/zh_cn/conf.py + + +python: + install: + - requirements: requirements/docs.txt + - requirements: requirements/readthedocs.txt From 40b60caa2a135e70974c60d7949592d8b3d4862d Mon Sep 17 00:00:00 2001 From: Lyu Han Date: Wed, 6 Mar 2024 11:06:15 +0800 Subject: [PATCH 14/17] update badges in README (#1243) * update badge in README * update * ignore twitter url checking --- .github/md-link-config.json | 3 +++ README.md | 17 ++++++++++------- README_zh-CN.md | 17 ++++++++++------- 3 files changed, 23 insertions(+), 14 deletions(-) diff --git a/.github/md-link-config.json b/.github/md-link-config.json index 469ac707a6..daf5a28f28 100644 --- a/.github/md-link-config.json +++ b/.github/md-link-config.json @@ -17,6 +17,9 @@ }, { "pattern": "^http://localhost" + }, + { + "pattern": "^https://twitter.com" } ], "httpHeaders": [ diff --git a/README.md b/README.md index 9f4f0d66b5..8fcf63330e 100644 --- a/README.md +++ b/README.md @@ -1,20 +1,23 @@
-[![docs](https://img.shields.io/badge/docs-latest-blue)](https://lmdeploy.readthedocs.io/en/latest/) -[![badge](https://github.com/InternLM/lmdeploy/workflows/lint/badge.svg)](https://github.com/InternLM/lmdeploy/actions) [![PyPI](https://img.shields.io/pypi/v/lmdeploy)](https://pypi.org/project/lmdeploy) +![PyPI - Downloads](https://img.shields.io/pypi/dm/lmdeploy) [![license](https://img.shields.io/github/license/InternLM/lmdeploy.svg)](https://github.com/InternLM/lmdeploy/tree/main/LICENSE) [![issue resolution](https://img.shields.io/github/issues-closed-raw/InternLM/lmdeploy)](https://github.com/InternLM/lmdeploy/issues) [![open issues](https://img.shields.io/github/issues-raw/InternLM/lmdeploy)](https://github.com/InternLM/lmdeploy/issues) +[📘Documentation](https://lmdeploy.readthedocs.io/en/latest/) | +[🛠️Quick Start](https://lmdeploy.readthedocs.io/en/latest/get_started.html) | +[🤔Reporting Issues](https://github.com/InternLM/lmdeploy/issues/new/choose) + English | [简体中文](README_zh-CN.md) -
+👋 join us on [![Static Badge](https://img.shields.io/badge/-grey?style=social&logo=wechat&label=WeChat)](https://r.vansin.top/?r=internwx) +[![Static Badge](https://img.shields.io/badge/-grey?style=social&logo=twitter&label=Twitter)](https://twitter.com/intern_lm) +[![Static Badge](https://img.shields.io/badge/-grey?style=social&logo=discord&label=Discord)](https://discord.gg/xa29JuW87d) -

- 👋 join us on Twitter, Discord and WeChat -

+ ______________________________________________________________________ @@ -22,7 +25,7 @@ ______________________________________________________________________
2024 - +- \[2024/02\] Support Qwen 1.5, Gemma, Mistral, Mixtral, Deepseek-MOE and so on. - \[2024/01\] [OpenAOE](https://github.com/InternLM/OpenAOE) seamless integration with [LMDeploy Serving Service](./docs/en/serving/restful_api.md). - \[2024/01\] Support for multi-model, multi-machine, multi-card inference services. For usage instructions, please refer to [here](./docs/en/serving/proxy_server.md) - \[2024/01\] Support [PyTorch inference engine](./docs/en/inference/pytorch.md), developed entirely in Python, helping to lower the barriers for developers and enable rapid experimentation with new features and technologies. diff --git a/README_zh-CN.md b/README_zh-CN.md index 51155b819a..de56c0cae9 100644 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -1,20 +1,23 @@
-[![docs](https://img.shields.io/badge/docs-latest-blue)](https://lmdeploy.readthedocs.io/zh-cn/latest/) -[![badge](https://github.com/InternLM/lmdeploy/workflows/lint/badge.svg)](https://github.com/InternLM/lmdeploy/actions) [![PyPI](https://img.shields.io/pypi/v/lmdeploy)](https://pypi.org/project/lmdeploy) +![PyPI - Downloads](https://img.shields.io/pypi/dm/lmdeploy) [![license](https://img.shields.io/github/license/InternLM/lmdeploy.svg)](https://github.com/InternLM/lmdeploy/tree/main/LICENSE) [![issue resolution](https://img.shields.io/github/issues-closed-raw/InternLM/lmdeploy)](https://github.com/InternLM/lmdeploy/issues) [![open issues](https://img.shields.io/github/issues-raw/InternLM/lmdeploy)](https://github.com/InternLM/lmdeploy/issues) +[📘Documentation](https://lmdeploy.readthedocs.io/zh-cn/latest/) | +[🛠️Quick Start](https://lmdeploy.readthedocs.io/zh-cn/latest/get_started.html) | +[🤔Reporting Issues](https://github.com/InternLM/lmdeploy/issues/new/choose) + [English](README.md) | 简体中文 -
+👋 join us on [![Static Badge](https://img.shields.io/badge/-grey?style=social&logo=wechat&label=WeChat)](https://r.vansin.top/?r=internwx) +[![Static Badge](https://img.shields.io/badge/-grey?style=social&logo=twitter&label=Twitter)](https://twitter.com/intern_lm) +[![Static Badge](https://img.shields.io/badge/-grey?style=social&logo=discord&label=Discord)](https://discord.gg/xa29JuW87d) -

- 👋 join us on Twitter, Discord and WeChat -

+ ______________________________________________________________________ @@ -22,7 +25,7 @@ ______________________________________________________________________
2024 - +- \[2024/02\] 支持 Qwen 1.5、Gemma、Mistral、Mixtral、Deepseek-MOE 等模型 - \[2024/01\] [OpenAOE](https://github.com/InternLM/OpenAOE) 发布,支持无缝接入[LMDeploy Serving Service](./docs/zh_cn/serving/restful_api.md) - \[2024/01\] 支持多模型、多机、多卡推理服务。使用方法请参考[此处](./docs/zh_cn/serving/proxy_server.md) - \[2024/01\] 增加 [PyTorch 推理引擎](./docs/zh_cn/inference/pytorch.md),作为 TurboMind 引擎的补充。帮助降低开发门槛,和快速实验新特性、新技术 From 278297a55f2b40c206bef311a7fd4229d9e5e789 Mon Sep 17 00:00:00 2001 From: AllentDan <41138331+AllentDan@users.noreply.github.com> Date: Wed, 6 Mar 2024 11:07:21 +0800 Subject: [PATCH 15/17] Hide qos functions from swagger UI if not applied (#1238) --- lmdeploy/serve/openai/api_server.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py index 2ecd1cdc96..97ea2d9939 100644 --- a/lmdeploy/serve/openai/api_server.py +++ b/lmdeploy/serve/openai/api_server.py @@ -1052,6 +1052,11 @@ def serve(model_path: str, VariableInterface.qos_engine.start() except FileNotFoundError: VariableInterface.qos_engine = None + else: + # hide qos functions if not applied + for i in range(len(app.router.routes)): + if 'qos' in app.router.routes[i].path: + app.router.routes[i].include_in_schema = False for i in range(3): print( From 9149a49eae8e02298203c4d3283b57c0f1883886 Mon Sep 17 00:00:00 2001 From: Lyu Han Date: Wed, 6 Mar 2024 11:57:22 +0800 Subject: [PATCH 16/17] Update serving guide including api_server and gradio (#1248) * update gradio.md * update restful_api.md * update get_started.md * change title * change order * change title * fix linting * fix linting * fix linting * fix typo --- .github/md-link-config.json | 6 ++ README.md | 1 + README_zh-CN.md | 1 + docs/en/get_started.md | 16 +-- docs/en/index.rst | 2 +- docs/en/serving/gradio.md | 25 ++++- docs/en/serving/restful_api.md | 171 ++++++++++++++++++------------ docs/zh_cn/get_started.md | 16 +-- docs/zh_cn/index.rst | 2 +- docs/zh_cn/serving/gradio.md | 25 ++++- docs/zh_cn/serving/restful_api.md | 170 +++++++++++++++++------------ 11 files changed, 267 insertions(+), 168 deletions(-) diff --git a/.github/md-link-config.json b/.github/md-link-config.json index daf5a28f28..3b9bca0dcc 100644 --- a/.github/md-link-config.json +++ b/.github/md-link-config.json @@ -20,6 +20,12 @@ }, { "pattern": "^https://twitter.com" + }, + { + "pattern": "^https://platform.openai.com" + }, + { + "pattern": "^http://0.0.0.0" } ], "httpHeaders": [ diff --git a/README.md b/README.md index 8fcf63330e..5cf1699902 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,7 @@ ______________________________________________________________________
2024 + - \[2024/02\] Support Qwen 1.5, Gemma, Mistral, Mixtral, Deepseek-MOE and so on. - \[2024/01\] [OpenAOE](https://github.com/InternLM/OpenAOE) seamless integration with [LMDeploy Serving Service](./docs/en/serving/restful_api.md). - \[2024/01\] Support for multi-model, multi-machine, multi-card inference services. For usage instructions, please refer to [here](./docs/en/serving/proxy_server.md) diff --git a/README_zh-CN.md b/README_zh-CN.md index de56c0cae9..b7a0e61a69 100644 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -25,6 +25,7 @@ ______________________________________________________________________
2024 + - \[2024/02\] 支持 Qwen 1.5、Gemma、Mistral、Mixtral、Deepseek-MOE 等模型 - \[2024/01\] [OpenAOE](https://github.com/InternLM/OpenAOE) 发布,支持无缝接入[LMDeploy Serving Service](./docs/zh_cn/serving/restful_api.md) - \[2024/01\] 支持多模型、多机、多卡推理服务。使用方法请参考[此处](./docs/zh_cn/serving/proxy_server.md) diff --git a/docs/en/get_started.md b/docs/en/get_started.md index 2b09aa3b74..084947095c 100644 --- a/docs/en/get_started.md +++ b/docs/en/get_started.md @@ -31,19 +31,11 @@ For more information on inference pipeline parameters, please refer to [here](./ ## Serving -LMDeploy's `api_server` enables models to be easily packed into services with a single command. The provided RESTful APIs are compatible with OpenAI's interfaces. Below are an example of service startup: +LMDeploy offers various serving methods, choosing one that best meet your requirements. -```shell -lmdeploy serve api_server internlm/internlm-chat-7b -``` - -The default port of `api_server` is `23333`. After the server is launched, you can communicate with server on terminal through `api_client`: - -```shell -lmdeploy serve api_client http://0.0.0.0:23333 -``` - -You can overview and try out `api_server` APIs online by swagger UI at `http://0.0.0.0:23333`, or you can read the API specification from [here](serving/restful_api.md). +- [Serving with openai compatible server](https://lmdeploy.readthedocs.io/en/latest/serving/restful_api.html) +- [Serving with docker](https://lmdeploy.readthedocs.io/en/latest/serving/restful_api.html#option-2-deploying-with-docker) +- [Serving with gradio](https://lmdeploy.readthedocs.io/en/latest/serving/gradio.html) ## Quantization diff --git a/docs/en/index.rst b/docs/en/index.rst index daccaae535..66e9c059b1 100644 --- a/docs/en/index.rst +++ b/docs/en/index.rst @@ -49,8 +49,8 @@ Welcome to LMDeploy's tutorials! :caption: serving serving/restful_api.md - serving/proxy_server.md serving/gradio.md + serving/proxy_server.md .. _quantization: .. toctree:: diff --git a/docs/en/serving/gradio.md b/docs/en/serving/gradio.md index 803dff50f5..7b223565ff 100644 --- a/docs/en/serving/gradio.md +++ b/docs/en/serving/gradio.md @@ -1,10 +1,25 @@ -# Steps to create a huggingface online demo +# Serving with Gradio -## create space +Starting an LLM model's gradio service with LMDeploy and interacting with the model on the WebUI is incredibly simple. + +```shell +pip install lmdeploy[serve] +lmdeploy serve gradio {model_path} +``` + +All it takes is one-line command, with the `{model_path}` replaced by the model ID from huggingface hub, such as `internlm/internlm2-chat-7b`, or the local path to the model. + +For detailed parameters of the command, please turn to `lmdeploy serve gradio -h` for help. + +## Create a huggingface demo + +If you want to create an online demo project for your model on huggingface, please follow the steps below. + +### Step 1: Create space First, register for a Hugging Face account. After successful registration, click on your profile picture in the upper right corner and select “New Space” to create one. Follow the Hugging Face guide to choose the necessary configurations, and you will have a blank demo space ready. -## A demo for LMDeploy +### Step 2: Develop demo's entrypoint `app.py` Replace the content of `app.py` in your space with the following code: @@ -12,7 +27,7 @@ Replace the content of `app.py` in your space with the following code: from lmdeploy.serve.gradio.turbomind_coupled import run_local from lmdeploy.messages import TurbomindEngineConfig -backend_config = TurbomindEngineConfig(max_batch_size=1, cache_max_entry_count=0.05) +backend_config = TurbomindEngineConfig(max_batch_size=8) model_path = 'internlm/internlm2-chat-7b' run_local(model_path, backend_config=backend_config, server_name="huggingface-space") ``` @@ -25,7 +40,7 @@ lmdeploy ## FAQs -- ZeroGPU compatibility issue. ZeroGPU is more suitable for inference methods similar to PyTorch, rather than Turbomind. You can switch to the PyTorch backend or enable standard GPUs. +- ZeroGPU compatibility issue. ZeroGPU is not suitable for LMDeploy turbomind engine. Please use the standard GPUs. Or, you can change the backend config in the above code to `PyTorchEngineConfig` to use the ZeroGPU. - Gradio version issue, versions above 4.0.0 are currently not supported. You can modify this in `app.py`, for example: ```python import os diff --git a/docs/en/serving/restful_api.md b/docs/en/serving/restful_api.md index de1ea9fa44..d092d9b288 100644 --- a/docs/en/serving/restful_api.md +++ b/docs/en/serving/restful_api.md @@ -1,34 +1,29 @@ -# Restful API +# Serving with OpenAI Compatible Server -## Launch Service +This article primarily discusses the deployment of a single LLM model across multiple GPUs on a single node, providing a service that is compatible with the OpenAI interface, as well as the usage of the service API. +For the sake of convenience, we refer to this service as `api_server`. Regarding parallel services with multiple models, please refer to the guide about [Request Distribution Server](./proxy_server.md). -The user can open the http url print by the following command in a browser. +In the following sections, we will first introduce two methods for starting the service, choosing the appropriate one based on your application scenario. -- **Please check the http url for the detailed api usage!!!** -- **Please check the http url for the detailed api usage!!!** -- **Please check the http url for the detailed api usage!!!** +Next, we focus on the definition of the service's RESTful API, explore the various ways to interact with the interface, and demonstrate how to try the service through the Swagger UI or LMDeploy CLI tools. -```shell -lmdeploy serve api_server ./workspace --server-name 0.0.0.0 --server-port ${server_port} --tp 1 -``` +Finally, we showcase how to integrate the service into a WebUI, providing you with a reference to easily set up a demonstration demo. -The parameters supported by api_server can be viewed through the command line `lmdeploy serve api_server -h`. +## Launch Service -We provide some RESTful APIs. Three of them are in OpenAI format. +Take the [internlm2-chat-7b](https://huggingface.co/internlm/internlm2-chat-7b) model hosted on huggingface hub as an example, you can choose one the following methods to start the service. -- /v1/chat/completions -- /v1/models -- /v1/completions +### Option 1: Launching with lmdeploy CLI -However, we recommend users try -our own api `/v1/chat/interactive` which provides more arguments for users to modify. The performance is comparatively better. +```shell +lmdeploy serve api_server internlm/internlm2-chat-7b --server-port 23333 +``` -**Note** please, if you want to launch multiple requests, you'd better set different `session_id` for both -`/v1/chat/completions` and `/v1/chat/interactive` apis. Or, we will set them random values. +The arguments of `api_server` can be viewed through the command `lmdeploy serve api_server -h`, for instance, `--tp` to set tensor parallelism, `--session-len` to specify the max length of the context window, `--cache-max-entry-count` to adjust the GPU mem ratio for k/v cache etc. -## Deploy http service with docker +### Option 2: Deploying with docker -LMDeploy offers [official docker image](https://hub.docker.com/r/openmmlab/lmdeploy/tags) for deployment. The image can be used to run OpenAI compatible server. +With LMDeploy [official docker image](https://hub.docker.com/r/openmmlab/lmdeploy/tags), you can run OpenAI compatible server as follows: ```shell docker run --runtime nvidia --gpus all \ @@ -40,11 +35,60 @@ docker run --runtime nvidia --gpus all \ lmdeploy serve api_server internlm/internlm2-chat-7b ``` -Just like the previous section, user can try the Swagger UI with a web browser. +The parameters of `api_server` are the same with that mentioned in "[option 1](#option-1-launching-with-lmdeploy-cli)" section + +## RESTful API + +LMDeploy's RESTful API is compatible with the following three OpenAI interfaces: + +- /v1/chat/completions +- /v1/models +- /v1/completions + +Additionally, LMDeploy also defines `/v1/chat/interactive` to support interactive inference. The feature of interactive inference is that there's no need to pass the user conversation history as required by `v1/chat/completions`, since the conversation history will be cached on the server side. This method boasts excellent performance during multi-turn long context inference. + +You can overview and try out the offered RESTful APIs by the website `http://0.0.0.0:23333` as shown in the below image after launching the service successfully. + +![swagger_ui](https://github.com/InternLM/lmdeploy/assets/4560679/b891dd90-3ffa-4333-92b2-fb29dffa1459) + +Or, you can use the LMDeploy's built-in CLI tool to verify the service correctness right from the console. + +```shell +# restful_api_url is what printed in api_server.py, e.g. http://localhost:23333 +lmdeploy serve api_client ${api_server_url} +``` + +If you need to integrate the service into your own projects or products, we recommend the following approach: + +### Integrate with `OpenAI` + +Here is an example of interaction with the endpoint `v1/chat/completions` service via the openai package. +Before running it, please install the openai package by `pip install openai` + +```python +from openai import OpenAI +client = OpenAI( + api_key='YOUR_API_KEY', + base_url="http://0.0.0.0:23333/v1" +) + +response = client.chat.completions.create( + model="internlm2-chat-7b", + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": " provide three suggestions about time management"}, + ], + temperature=0.8, + top_p=0.8 +) +print(response) +``` + +You can invoke other OpenAI interfaces using similar methods. For more detailed information, please refer to the [OpenAI API guide](https://platform.openai.com/docs/guides/text-generation) -## python +### Integrate with lmdeploy `APIClient` -We have integrated the client-side functionalities of these services into the `APIClient` class. Below are some examples demonstrating how to invoke the `api_server` service on the client side. +Below are some examples demonstrating how to visit the service through `APIClient` If you want to use the `/v1/chat/completions` endpoint, you can try the following code: @@ -57,7 +101,7 @@ for item in api_client.chat_completions_v1(model=model_name, messages=messages): print(item) ``` -For the `/v1/completions` endpoint. If you want to use the `/v1/completions` endpoint, you can try: +For the `/v1/completions` endpoint, you can try: ```python from lmdeploy.serve.openai.api_client import APIClient @@ -67,23 +111,29 @@ for item in api_client.completions_v1(model=model_name, prompt='hi'): print(item) ``` -Lmdeploy supports maintaining session histories on the server for `/v1/chat/interactive` api. We disable the -feature by default. +As for `/v1/chat/interactive`,we disable the feature by default. Please open it by setting `interactive_mode = True`. If you don't, it falls back to openai compatible interfaces. -- On interactive mode, the chat history is kept on the server. In a multiple rounds of conversation, you should set - `interactive_mode = True` and the same `session_id` (can't be -1, it's the default number) to `/v1/chat/interactive` for requests. -- On normal mode, no chat history is kept on the server. - -The interactive mode can be controlled by the `interactive_mode` boolean parameter. The following is an example of normal mode. If you want to experience the interactive mode, simply pass in `interactive_mode=True`. +Keep in mind that `session_id` indicates an identical sequence and all requests belonging to the same sequence must share the same `session_id`. +For instance, in a sequence with 10 rounds of chatting requests, the `session_id` in each request should be the same. ```python from lmdeploy.serve.openai.api_client import APIClient -api_client = APIClient('http://{server_ip}:{server_port}') -for item in api_client.chat_interactive_v1(prompt='hi'): - print(item) +api_client = APIClient(f'http://{server_ip}:{server_port}') +messages = [ + "hi, what's your name?", + "who developed you?", + "Tell me more about your developers", + "Summarize the information we've talked so far" +] +for message in messages: + for item in api_client.chat_interactive_v1(prompt=message, + session_id=1, + interactive_mode=True, + stream=False): + print(item) ``` -## Java/Golang/Rust +### Integrate with Java/Golang/Rust May use [openapi-generator-cli](https://github.com/OpenAPITools/openapi-generator-cli) to convert `http://{server_ip}:{server_port}/openapi.json` to java/rust/golang client. Here is an example: @@ -102,29 +152,17 @@ rust/src: apis lib.rs models ``` -## cURL +### Integrate with cURL -cURL is a tool for observing the output of the api. +cURL is a tool for observing the output of the RESTful APIs. -List Models: +- list served models `v1/models` ```bash curl http://{server_ip}:{server_port}/v1/models ``` -Interactive Chat: - -```bash -curl http://{server_ip}:{server_port}/v1/chat/interactive \ - -H "Content-Type: application/json" \ - -d '{ - "prompt": "Hello! How are you?", - "session_id": 1, - "interactive_mode": true - }' -``` - -Chat Completions: +- chat `v1/chat/completions` ```bash curl http://{server_ip}:{server_port}/v1/chat/completions \ @@ -135,7 +173,7 @@ curl http://{server_ip}:{server_port}/v1/chat/completions \ }' ``` -Text Completions: +- text completions `v1/completions` ```shell curl http://{server_ip}:{server_port}/v1/completions \ @@ -146,18 +184,23 @@ curl http://{server_ip}:{server_port}/v1/completions \ }' ``` -## CLI client - -There is a client script for restful api server. +- interactive chat `v1/chat/interactive` -```shell -# restful_api_url is what printed in api_server.py, e.g. http://localhost:23333 -lmdeploy serve api_client api_server_url +```bash +curl http://{server_ip}:{server_port}/v1/chat/interactive \ + -H "Content-Type: application/json" \ + -d '{ + "prompt": "Hello! How are you?", + "session_id": 1, + "interactive_mode": true + }' ``` -## webui through gradio +## Integrate with WebUI + +LMDeploy utilizes `gradio` or [OpenAOE](https://github.com/InternLM/OpenAOE) to integrate a web ui for `api_server` -You can also test restful-api through webui. +### Option 1: gradio ```shell # api_server_url is what printed in api_server.py, e.g. http://localhost:23333 @@ -166,9 +209,7 @@ You can also test restful-api through webui. lmdeploy serve gradio api_server_url --server-name ${gradio_ui_ip} --server-port ${gradio_ui_port} ``` -## webui through OpenAOE - -You can use [OpenAOE](https://github.com/InternLM/OpenAOE) for seamless integration with LMDeploy. +### Option 2: OpenAOE ```shell pip install -U openaoe @@ -191,7 +232,3 @@ Please refer to the [guidance](https://github.com/InternLM/OpenAOE/blob/main/doc 5. If you need to adjust other default parameters of the session, such as the content of fields like system. You can directly pass in the initialization parameters of the [dialogue template](https://github.com/InternLM/lmdeploy/blob/main/lmdeploy/model.py). For example, for the internlm-chat-7b model, you can set the `--meta-instruction` parameter when starting the `api_server`. 6. Regarding the stop words, we only support characters that encode into a single index. Furthermore, there may be multiple indexes that decode into results containing the stop word. In such cases, if the number of these indexes is too large, we will only use the index encoded by the tokenizer. If you want use a stop symbol that encodes into multiple indexes, you may consider performing string matching on the streaming client side. Once a successful match is found, you can then break out of the streaming loop. - -## request distribution service - -Please refer to our [request distributor server](./proxy_server.md) diff --git a/docs/zh_cn/get_started.md b/docs/zh_cn/get_started.md index f5d6b1c13f..26fd61ab71 100644 --- a/docs/zh_cn/get_started.md +++ b/docs/zh_cn/get_started.md @@ -31,19 +31,11 @@ print(response) ## 推理服务 -LMDeploy `api_server` 支持把模型一键封装为服务,对外提供的 RESTful API 兼容 openai 的接口。以下为服务启动的示例: +LMDeploy 提供了多种部署模型推理服务的方式,总有一款适合你。 -```shell -lmdeploy serve api_server internlm/internlm-chat-7b -``` - -服务默认端口是23333。在 server 启动后,你可以在终端通过`api_client`与server进行对话: - -```shell -lmdeploy serve api_client http://0.0.0.0:23333 -``` - -除了`api_client`,你还可以通过 Swagger UI `http://0.0.0.0:23333` 在线阅读和试用 `api_server` 的各接口,也可直接查阅[文档](serving/restful_api.md),了解各接口的定义和使用方法。 +- [部署类 openai 的服务](https://lmdeploy.readthedocs.io/zh-cn/latest//serving/restful_api.html) +- [通过 docker 部署服务](https://lmdeploy.readthedocs.io/zh-cn/latest/serving/restful_api.html#docker) +- [部署 gradio 服务](https://lmdeploy.readthedocs.io/zh-cn/latest/serving/gradio.html) ## 模型量化 diff --git a/docs/zh_cn/index.rst b/docs/zh_cn/index.rst index e89bfa661b..265fb716d2 100644 --- a/docs/zh_cn/index.rst +++ b/docs/zh_cn/index.rst @@ -50,8 +50,8 @@ :caption: 服务 serving/restful_api.md - serving/proxy_server.md serving/gradio.md + serving/proxy_server.md .. _量化: diff --git a/docs/zh_cn/serving/gradio.md b/docs/zh_cn/serving/gradio.md index fe1e01af3f..3e70f68856 100644 --- a/docs/zh_cn/serving/gradio.md +++ b/docs/zh_cn/serving/gradio.md @@ -1,11 +1,26 @@ -# 从 LMDeploy 创建一个 huggingface 的在线 demo +# 部署 gradio 服务 -## 创建 space +通过 LMDeploy 启动 LLM 模型的 gradio 服务,并在 WebUI 上和模型对话特别简单,一条命令即可。 + +```shell +pip install lmdeploy[serve] +lmdeploy serve gradio {model_path} +``` + +把上面命令中的 `{model_path}` 换成 huggingface hub 上的模型 id,比如 internlm/internlm2-chat-7b,或者换成模型的本地路径就可以了。 + +关于命令的详细参数,请使用 `lmdeploy serve gradio --help` 查阅。 + +## 创建 huggingface demo + +如果想要在 huggingface 上创建模型的在线演示项目,请按以下步骤进行。 + +### 第一步:创建 space 首先,注册一个 huggingface 的账号,注册成功后,可以点击右上角头像,选择 New Space 创建。 根据 huggingface 的引导选择需要的配置,完成后即可得到一个空白的 demo。 -## 使用 LMDeploy 的 demo +### 第二步:编写 demo 入口代码 app.py 以 `internlm/internlm2-chat-7b` 模型为例,将 space 空间中的`app.py`内容填写为: @@ -13,7 +28,7 @@ from lmdeploy.serve.gradio.turbomind_coupled import run_local from lmdeploy.messages import TurbomindEngineConfig -backend_config = TurbomindEngineConfig(max_batch_size=1, cache_max_entry_count=0.05) +backend_config = TurbomindEngineConfig(max_batch_size=8) model_path = 'internlm/internlm2-chat-7b' run_local(model_path, backend_config=backend_config, server_name="huggingface-space") ``` @@ -26,7 +41,7 @@ lmdeploy ## FAQs -- ZeroGPU 适配问题。ZeroGPU 更适合类似 PyTorch 这样的推理方式,而非 Turbomind。可以改用 pytorch 后端,或者启用普通 GPU。 +- ZeroGPU 适配问题。ZeroGPU不适用 LMDeploy Turbomind 引擎,请选择普通 GPU,或者把上述代码中的 backend_config 改成 PyTorchEngineConfig,就可以用 ZeroGPU 了。 - gradio 版本问题,目前不支持 4.0.0 以上版本,可以在 `app.py` 中修改,类似: ```python import os diff --git a/docs/zh_cn/serving/restful_api.md b/docs/zh_cn/serving/restful_api.md index bb0e4da12d..76b76a0a63 100644 --- a/docs/zh_cn/serving/restful_api.md +++ b/docs/zh_cn/serving/restful_api.md @@ -1,31 +1,29 @@ -# Restful API +# 部署类 openai 服务 -## 启动服务 +本文主要介绍单个模型在单机多卡环境下,部署兼容 openai 接口服务的方式,以及服务接口的用法。为行文方便,我们把该服务名称为 `api_server`。对于多模型的并行服务,请阅读[请求分发服务器](./proxy_server.md)一文。 -用户将下面命令输出的 http url 复制到浏览器打开,详细查看所有的 API 及其使用方法。 -请一定查看`http://{server_ip}:{server_port}`!!! -请一定查看`http://{server_ip}:{server_port}`!!! -请一定查看`http://{server_ip}:{server_port}`!!! -重要的事情说三遍。 +在这篇文章中, 我们首先介绍服务启动的两种方法,你可以根据应用场景,选择合适的。 -```shell -lmdeploy serve api_server ./workspace --server-name 0.0.0.0 --server-port ${server_port} --tp 1 -``` +其次,我们重点介绍服务的 RESTful API 定义,以及接口使用的方式,并展示如何通过 Swagger UI、LMDeploy CLI 工具体验服务功能 -api_server 启动时支持的参数可以通过命令行`lmdeploy serve api_server -h`查看。 +最后,向大家演示把服务接入到 WebUI 的方式,你可以参考它简单搭建一个演示 demo。 -我们提供的 restful api,其中三个仿照 OpenAI 的形式。 +## 启动服务 -- /v1/chat/completions -- /v1/models -- /v1/completions +以 huggingface hub 上的 [internlm2-chat-7b](https://huggingface.co/internlm/internlm2-chat-7b) 模型为例,你可以任选以下方式之一,启动推理服务。 -不过,我们建议用户用我们提供的另一个 API: `/v1/chat/interactive`。 -它有更好的性能,提供更多的参数让用户自定义修改。 +### 方式一:使用 lmdeploy cli 工具 -## 用 docker 部署 http 服务 +```shell +lmdeploy serve api_server internlm/internlm2-chat-7b --server-port 23333 +``` + +api_server 启动时的参数可以通过命令行`lmdeploy serve api_server -h`查看。 +比如,`--tp` 设置张量并行,`--session-len` 设置推理的最大上下文窗口长度,`--cache-max-entry-count` 调整 k/v cache 的内存使用比例等等。 -LMDeploy 提供了官方[镜像](https://hub.docker.com/r/openmmlab/lmdeploy/tags)。使用这个镜像,可以运行兼容 OpenAI 的服务。下面是使用示例: +### 方式二:使用 docker + +使用 LMDeploy 官方[镜像](https://hub.docker.com/r/openmmlab/lmdeploy/tags),可以运行兼容 OpenAI 的服务。下面是使用示例: ```shell docker run --runtime nvidia --gpus all \ @@ -37,16 +35,64 @@ docker run --runtime nvidia --gpus all \ lmdeploy serve api_server internlm/internlm2-chat-7b ``` -然后像上面一样使用浏览器试用 Swagger UI 即可。 +在这个例子中,`lmdeploy server api_server` 的命令参数与方式一一致。 + +## RESTful API + +LMDeploy 的 RESTful API 兼容了 OpenAI 以下 3 个接口: + +- /v1/chat/completions +- /v1/models +- /v1/completions + +此外,LMDeploy 还定义了 `/v1/chat/interactive`,用来支持交互式推理。交互式推理的特点是不用像`v1/chat/completions`传入用户对话历史,因为对话历史会被缓存在服务端。 +这种方式在多轮次的长序列推理时,拥有很好的性能。 + +服务启动后,你可以在浏览器中打开网页 http://0.0.0.0:23333,通过 Swagger UI 查看接口的详细说明,并且也可以直接在网页上操作,体验每个接口的用法,如下图所示。 + +![swagger_ui](https://github.com/InternLM/lmdeploy/assets/4560679/b891dd90-3ffa-4333-92b2-fb29dffa1459) + +也可以使用 LMDeploy 自带的 CLI 工具,在控制台验证服务的正确性。 + +```shell +# restful_api_url is what printed in api_server.py, e.g. http://localhost:23333 +lmdeploy serve api_client ${api_server_url} +``` + +若需要把服务集成到自己的项目或者产品中,我们推荐以下用法: + +### 使用 openai 接口 + +以下代码是通过 openai 包使用 `v1/chat/completions` 服务的例子。运行之前,请先安装 openai 包: `pip install openai`。 + +```python +from openai import OpenAI +client = OpenAI( + api_key='YOUR_API_KEY', + base_url="http://0.0.0.0:23333/v1" +) + +response = client.chat.completions.create( + model="internlm2-chat-7b", + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": " provide three suggestions about time management"}, + ], + temperature=0.8, + top_p=0.8 +) +print(response) +``` + +关于其他 openai 接口的调用,也可以如法炮制。详情请参考 openai 官方[文档](https://platform.openai.com/docs/guides/text-generation) -## python +### 使用 lmdeploy `APIClient` 接口 -我们将这些服务的客户端功能集成在 `APIClient` 类中。下面是一些例子,展示如何在客户端调用 `api_server` 服务。 如果你想用 `/v1/chat/completions` 接口,你可以尝试下面代码: ```python from lmdeploy.serve.openai.api_client import APIClient -api_client = APIClient('http://{server_ip}:{server_port}') +api_client = APIClient(f'http://{server_ip}:{server_port}') model_name = api_client.available_models[0] messages = [{"role": "user", "content": "Say this is a test!"}] for item in api_client.chat_completions_v1(model=model_name, messages=messages): @@ -57,28 +103,35 @@ for item in api_client.chat_completions_v1(model=model_name, messages=messages): ```python from lmdeploy.serve.openai.api_client import APIClient -api_client = APIClient('http://{server_ip}:{server_port}') +api_client = APIClient(f'http://{server_ip}:{server_port}') model_name = api_client.available_models[0] for item in api_client.completions_v1(model=model_name, prompt='hi'): print(item) ``` -LMDeploy 的 `/v1/chat/interactive` api 支持将对话内容管理在服务端,但是我们默认关闭。如果想尝试,请阅读以下介绍: +关于 `/v1/chat/interactive` 接口,我们默认是关闭的。在使用时,请设置`interactive_mode = True`打开它。否则,它会退化为 openai 接口。 -- 交互模式下,对话历史保存在 server。在一次完整的多轮对话中,所有请求设置`interactive_mode = True`, `session_id`保持相同 (不为 -1,这是缺省值)。 -- 非交互模式下,server 不保存历史记录。 - -交互模式可以通过 `interactive_mode` 布尔量参数控制。下面是一个普通模式的例子, -如果要体验交互模式,将 `interactive_mode=True` 传入即可。 +在交互式推理中,每个对话序列的 id 必须唯一,所有属于该独立的对话请求,必须使用相同的 id。这里的 id 对应与接口中的 `session_id`。 +比如,一个对话序列中,有 10 轮对话请求,那么每轮对话请求中的 `session_id` 都要相同。 ```python from lmdeploy.serve.openai.api_client import APIClient -api_client = APIClient('http://{server_ip}:{server_port}') -for item in api_client.chat_interactive_v1(prompt='hi'): - print(item) +api_client = APIClient(f'http://{server_ip}:{server_port}') +messages = [ + "hi, what's your name?", + "who developed you?", + "Tell me more about your developers", + "Summarize the information we've talked so far" +] +for message in messages: + for item in api_client.chat_interactive_v1(prompt=message, + session_id=1, + interactive_mode=True, + stream=False): + print(item) ``` -## Java/Golang/Rust +### 使用 Java/Golang/Rust 可以使用代码生成工具 [openapi-generator-cli](https://github.com/OpenAPITools/openapi-generator-cli) 将 `http://{server_ip}:{server_port}/openapi.json` 转成 java/rust/golang 客户端。 下面是一个使用示例: @@ -97,29 +150,17 @@ rust/src: apis lib.rs models ``` -## cURL +### 使用 cURL cURL 也可以用于查看 API 的输出结果 -查看模型列表: +- 查看模型列表 `v1/models` ```bash curl http://{server_ip}:{server_port}/v1/models ``` -Interactive Chat: - -```bash -curl http://{server_ip}:{server_port}/v1/chat/interactive \ - -H "Content-Type: application/json" \ - -d '{ - "prompt": "Hello! How are you?", - "session_id": 1, - "interactive_mode": true - }' -``` - -Chat Completions: +- 对话 `v1/chat/completions` ```bash curl http://{server_ip}:{server_port}/v1/chat/completions \ @@ -130,7 +171,7 @@ curl http://{server_ip}:{server_port}/v1/chat/completions \ }' ``` -Text Completions: +- 文本补全 `v1/completions` ```shell curl http://{server_ip}:{server_port}/v1/completions \ @@ -141,18 +182,23 @@ curl http://{server_ip}:{server_port}/v1/completions \ }' ``` -## CLI client +- 交互式对话 `v1/chat/interactive` -restful api 服务可以通过客户端测试,例如 - -```shell -# restful_api_url is what printed in api_server.py, e.g. http://localhost:23333 -lmdeploy serve api_client api_server_url +```bash +curl http://{server_ip}:{server_port}/v1/chat/interactive \ + -H "Content-Type: application/json" \ + -d '{ + "prompt": "Hello! How are you?", + "session_id": 1, + "interactive_mode": true + }' ``` -## webui through gradio +## 接入 WebUI + +LMDeploy 提供 gradio 和 [OpenAOE](https://github.com/InternLM/OpenAOE) 两种方式,为 api_server 接入 WebUI。 -也可以直接用 webui 测试使用 restful-api。 +### 方式一:通过 gradio 接入 ```shell # api_server_url 就是 api_server 产生的,比如 http://localhost:23333 @@ -161,9 +207,7 @@ lmdeploy serve api_client api_server_url lmdeploy serve gradio api_server_url --server-name ${gradio_ui_ip} --server-port ${gradio_ui_port} ``` -## webui through OpenAOE - -可以使用 [OpenAOE](https://github.com/InternLM/OpenAOE) 无缝接入restful api服务. +### 方式二:通过 OpenAOE 接入 ```shell pip install -U openaoe @@ -185,7 +229,3 @@ openaoe -f /path/to/your/config-template.yaml 5. 如需调整会话默认的其他参数,比如 system 等字段的内容,可以直接将[对话模板](https://github.com/InternLM/lmdeploy/blob/main/lmdeploy/model.py)初始化参数传入。比如 internlm-chat-7b 模型,可以通过启动`api_server`时,设置`--meta-instruction`参数。 6. 关于停止符,我们只支持编码后为单个 index 的字符。此外,可能存在多种 index 都会解码出带有停止符的结果。对于这种情况,如果这些 index 数量太多,我们只会采用 tokenizer 编码出的 index。而如果你想要编码后为多个 index 的停止符,可以考虑在流式客户端做字符串匹配,匹配成功后跳出流式循环即可。 - -## 多机并行服务 - -请参考我们的 [请求分发服务器](./proxy_server.md) From e710c4c607a2003661444469a555aa31da58a1a8 Mon Sep 17 00:00:00 2001 From: zhulinJulia24 <145004780+zhulinJulia24@users.noreply.github.com> Date: Wed, 6 Mar 2024 21:51:36 +0800 Subject: [PATCH 17/17] Parallelize testcase and refactor test workflow (#1254) * refactor test case * refactor test case * refactor testcase * fix cuda allocate * fix cuda-prefix in pr run * Update daily_ete_test.yml * chage internlm2 model coverage to 20b in testcase * chage internlm2 model coverage to 20b in testcase * chage internlm2 model coverage to 20b in testcase * fix mp blocked by allocate cuda * add kvint8 and w4a16 chat cover * modify timeout for each step * fix lint * update prompt and pr trigger * update runner config * Update daily_ete_test.yml * change job name * parallize testcase * parallise testcase * fix dialog * fix chat interactive * Update daily_ete_test.yml * fix * fix condition * fix condition * fix pr test * fix trigger condition * fix workflow * fix workflow * Update daily_ete_test.yml --------- Co-authored-by: zhulin1 --- .github/workflows/daily_ete_test.yml | 106 +++++++++++----- .github/workflows/pr_ete_test.yml | 3 +- autotest/config.yaml | 112 ++++++++--------- .../pipeline/test_pipeline_turbomind_func.py | 16 +-- .../test_pipeline_turbomind_longtext_func.py | 88 ++++++++++++++ autotest/prompt_case.yaml | 4 + .../chat/test_command_chat_hf_pytorch.py | 43 +++++-- .../chat/test_command_chat_hf_turbomind.py | 45 +++++-- .../tools/chat/test_command_chat_workspace.py | 48 ++++++-- autotest/tools/convert/test_convert.py | 45 ++++--- .../tools/pipeline/pipeline_chat_script.py | 5 +- .../pipeline/test_pipeline_chat_pytorch.py | 36 +++++- .../pipeline/test_pipeline_chat_turbomind.py | 30 ++++- .../quantization/test_quantization_kvint8.py | 27 +++-- .../test_quantization_kvint8_w4a16.py | 30 +++-- .../quantization/test_quantization_w4a16.py | 35 +++--- .../quantization/test_quantization_w8a8.py | 29 +++-- .../restful/test_restful_chat_pytorch.py | 114 ++++++++++++------ .../restful/test_restful_chat_turbomind.py | 106 +++++++++++----- autotest/utils/config_utils.py | 72 ++++++++++- autotest/utils/get_run_config.py | 13 +- autotest/utils/pipeline_chat.py | 12 +- autotest/utils/quantization_utils.py | 7 +- autotest/utils/run_client_chat.py | 38 +++--- autotest/utils/run_restful_chat.py | 14 ++- 25 files changed, 780 insertions(+), 298 deletions(-) create mode 100644 autotest/interface/pipeline/test_pipeline_turbomind_longtext_func.py diff --git a/.github/workflows/daily_ete_test.yml b/.github/workflows/daily_ete_test.yml index afc94417f8..f2279536ad 100644 --- a/.github/workflows/daily_ete_test.yml +++ b/.github/workflows/daily_ete_test.yml @@ -2,8 +2,29 @@ name: daily_ete_test on: workflow_dispatch: + inputs: + repo_org: + required: false + description: 'Tested repository organization name. Default is InternLM' + type: string + default: 'InternLM/lmdeploy' + repo_ref: + required: false + description: 'Set branch or tag or commit id. Default is "main"' + type: string + default: 'main' + backend: + required: true + description: 'Set backend testcase filter: turbomind or pytorch or turbomind, pytorch. Default is "["turbomind", "pytorch"]"' + type: string + default: "['turbomind', 'pytorch']" + model: + required: true + description: 'Set testcase module filter: chat, restful, pipeline, quantization. Default contains all models' + type: string + default: "['quantization','convert','pipeline','restful','chat','interface-pipeline']" schedule: - - cron: '00 18 * * *' + - cron: '00 21 * * *' env: HOST_PIP_CACHE_DIR: /nvme/github-actions/pip-cache @@ -13,7 +34,7 @@ env: jobs: test_functions: runs-on: [self-hosted, linux-a100] - timeout-minutes: 420 + timeout-minutes: 240 env: REPORT_DIR: /nvme/qa_test_models/test-reports container: @@ -23,6 +44,7 @@ jobs: - /nvme/github-actions/pip-cache:/root/.cache/pip - /nvme/github-actions/packages:/root/packages - /nvme/qa_test_models:/nvme/qa_test_models + - /mnt/bigdisk/qa_test_models:/mnt/bigdisk/qa_test_models - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro steps: - name: Setup systems @@ -33,7 +55,10 @@ jobs: dpkg -i /root/packages/allure_2.24.1-1_all.deb rm -rf /var/lib/apt/lists/* - name: Clone repository - uses: actions/checkout@v2 + uses: actions/checkout@v3 + with: + repository: ${{ github.event.inputs.repo_org || 'InternLM/lmdeploy' }} + ref: ${{github.event.inputs.repo_ref || 'main'}} - name: Install pytorch run: | python3 -m pip cache dir @@ -68,64 +93,89 @@ jobs: run: | python3 -m pip list lmdeploy check_env + rm -rf allure-results - name: Test lmdeploy - quantization w4a16 continue-on-error: true + if: github.event_name == 'schedule' || (contains(fromJSON(github.event.inputs.backend), 'turbomind') && contains(fromJSON(github.event.inputs.model), 'quantization')) run: | pytest autotest/tools/quantization/test_quantization_w4a16.py -m 'not pr_test' -n 8 --alluredir=allure-results --clean-alluredir - name: Test lmdeploy - quantization kv int8 continue-on-error: true + if: github.event_name == 'schedule' || (contains(fromJSON(github.event.inputs.backend), 'turbomind') && contains(fromJSON(github.event.inputs.model), 'quantization')) run: | pytest autotest/tools/quantization/test_quantization_kvint8.py -n 8 --alluredir=allure-results - name: Test lmdeploy - quantization w8a8 continue-on-error: true + if: github.event_name == 'schedule' || (contains(fromJSON(github.event.inputs.backend), 'pytorch') && contains(fromJSON(github.event.inputs.model), 'quantization')) run: | pytest autotest/tools/quantization/test_quantization_w8a8.py -n 8 --alluredir=allure-results - name: Test lmdeploy - quantization kv int8 and w4a16 continue-on-error: true + if: github.event_name == 'schedule' || (contains(fromJSON(github.event.inputs.backend), 'turbomind') && contains(fromJSON(github.event.inputs.model), 'quantization')) run: | pytest autotest/tools/quantization/test_quantization_kvint8_w4a16.py -n 8 --alluredir=allure-results - name: Test lmdeploy - convert continue-on-error: true + if: github.event_name == 'schedule' || (contains(fromJSON(github.event.inputs.backend), 'turbomind') && contains(fromJSON(github.event.inputs.model), 'convert')) run: | - pytest autotest/tools/convert -m 'not pr_test' -n 6 --alluredir=allure-results --dist loadgroup - - name: Test lmdeploy - interface turbomind case + pytest autotest/tools/convert -m 'not pr_test' -n 8 --alluredir=allure-results + - name: Test lmdeploy - chat workspace continue-on-error: true + if: github.event_name == 'schedule' || (contains(fromJSON(github.event.inputs.backend), 'turbomind') && contains(fromJSON(github.event.inputs.model), 'chat')) timeout-minutes: 20 run: | - pytest autotest/interface/pipeline/test_pipeline_turbomind_func.py -m 'not pr_test' --alluredir=allure-results - - name: Test lmdeploy - pipeline turbomind - continue-on-error: true - timeout-minutes: 45 - run: pytest autotest/tools/pipeline/test_pipeline_chat_turbomind.py -m 'not pr_test' --alluredir=allure-results - - name: Test lmdeploy - pipeline torch + pytest autotest/tools/chat/test_command_chat_workspace.py -m 'gpu_num_1 and not pr_test' -n 8 --alluredir=allure-results + pytest autotest/tools/chat/test_command_chat_workspace.py -m 'gpu_num_2 and not pr_test' -n 4 --alluredir=allure-results + - name: Test lmdeploy - chat hf turbomind continue-on-error: true - timeout-minutes: 75 - run: pytest autotest/tools/pipeline/test_pipeline_chat_pytorch.py -m 'not pr_test' --alluredir=allure-results - - name: Test lmdeploy - restful turbomind + if: github.event_name == 'schedule' || (contains(fromJSON(github.event.inputs.backend), 'turbomind') && contains(fromJSON(github.event.inputs.model), 'chat')) + timeout-minutes: 20 + run: | + pytest autotest/tools/chat/test_command_chat_hf_turbomind.py -m 'gpu_num_1 and not pr_test' -n 8 --alluredir=allure-results + pytest autotest/tools/chat/test_command_chat_hf_turbomind.py -m 'gpu_num_2 and not pr_test' -n 4 --alluredir=allure-results + - name: Test lmdeploy - chat hf torch continue-on-error: true - timeout-minutes: 60 - run: pytest autotest/tools/restful/test_restful_chat_turbomind.py -m 'not pr_test' --alluredir=allure-results - - name: Test lmdeploy - restful torch + if: github.event_name == 'schedule' || (contains(fromJSON(github.event.inputs.backend), 'pytorch') && contains(fromJSON(github.event.inputs.model), 'chat')) + timeout-minutes: 20 + run: | + pytest autotest/tools/chat/test_command_chat_hf_pytorch.py -m 'gpu_num_1 and not pr_test' -n 8 --alluredir=allure-results + pytest autotest/tools/chat/test_command_chat_hf_pytorch.py -m 'gpu_num_2 and not pr_test' -n 4 --alluredir=allure-results + - name: Test lmdeploy - pipeline turbomind continue-on-error: true - timeout-minutes: 80 - run: pytest autotest/tools/restful/test_restful_chat_pytorch.py -m 'not pr_test' --alluredir=allure-results - - name: Test lmdeploy - chat workspace + if: github.event_name == 'schedule' || (contains(fromJSON(github.event.inputs.backend), 'turbomind') && contains(fromJSON(github.event.inputs.model), 'pipeline')) + timeout-minutes: 25 + run: | + pytest autotest/tools/pipeline/test_pipeline_chat_turbomind.py -m 'gpu_num_1 and not pr_test' -n 8 --alluredir=allure-results + pytest autotest/tools/pipeline/test_pipeline_chat_turbomind.py -m 'gpu_num_2 and not pr_test' -n 4 --alluredir=allure-results + - name: Test lmdeploy - restful turbomind continue-on-error: true + if: github.event_name == 'schedule' || (contains(fromJSON(github.event.inputs.backend), 'turbomind') && contains(fromJSON(github.event.inputs.model), 'restful')) timeout-minutes: 30 run: | - pytest autotest/tools/chat/test_command_chat_workspace.py -m 'not pr_test' -n 4 --alluredir=allure-results - - name: Test lmdeploy - chat hf turbomind + pytest autotest/tools/restful/test_restful_chat_turbomind.py -m 'gpu_num_1 and not pr_test' -n 8 --alluredir=allure-results + pytest autotest/tools/restful/test_restful_chat_turbomind.py -m 'gpu_num_2 and not pr_test' -n 4 --alluredir=allure-results + - name: Test lmdeploy - interface pipeline turbomind case continue-on-error: true - timeout-minutes: 45 + if: github.event_name == 'schedule' || (contains(fromJSON(github.event.inputs.backend), 'turbomind') && contains(fromJSON(github.event.inputs.model), 'interface-pipeline')) + timeout-minutes: 20 run: | - pytest autotest/tools/chat/test_command_chat_hf_turbomind.py -m 'not pr_test' -n 4 --alluredir=allure-results - - name: Test lmdeploy - chat hf torch + pytest autotest/interface/pipeline/test_pipeline_turbomind_func.py -m 'not pr_test' --alluredir=allure-results + - name: Test lmdeploy - pipeline torch + continue-on-error: true + if: github.event_name == 'schedule' || (contains(fromJSON(github.event.inputs.backend), 'pytorch') && contains(fromJSON(github.event.inputs.model), 'pipeline')) + timeout-minutes: 25 + run: | + pytest autotest/tools/pipeline/test_pipeline_chat_pytorch.py -m 'gpu_num_1 and not pr_test' -n 8 --alluredir=allure-results + pytest autotest/tools/pipeline/test_pipeline_chat_pytorch.py -m 'gpu_num_2 and not pr_test' -n 4 --alluredir=allure-results + - name: Test lmdeploy - restful torch continue-on-error: true - timeout-minutes: 60 + if: github.event_name == 'schedule' || (contains(fromJSON(github.event.inputs.backend), 'pytorch') && contains(fromJSON(github.event.inputs.model), 'restful')) + timeout-minutes: 40 run: | - pytest autotest/tools/chat/test_command_chat_hf_pytorch.py -m 'not pr_test' -n 4 --alluredir=allure-results + pytest autotest/tools/restful/test_restful_chat_pytorch.py -m 'gpu_num_1 and not pr_test' -n 8 --alluredir=allure-results + pytest autotest/tools/restful/test_restful_chat_pytorch.py -m 'gpu_num_2 and not pr_test' -n 4 --alluredir=allure-results - name: Test lmdeploy - rerun all fail cases - timeout-minutes: 60 + timeout-minutes: 30 run: | pytest autotest --lf --alluredir=allure-results - name: Generate reports diff --git a/.github/workflows/pr_ete_test.yml b/.github/workflows/pr_ete_test.yml index a41e639f30..08bf24b4b7 100644 --- a/.github/workflows/pr_ete_test.yml +++ b/.github/workflows/pr_ete_test.yml @@ -34,6 +34,7 @@ jobs: - /nvme/share_data/github-actions/pip-cache:/root/.cache/pip - /nvme/share_data/github-actions/packages:/root/packages - /nvme/qa_test_models:/nvme/qa_test_models + - /mnt/bigdisk/qa_test_models:/mnt/bigdisk/qa_test_models - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro steps: - name: Setup systems @@ -81,7 +82,7 @@ jobs: lmdeploy check_env - name: Test lmdeploy timeout-minutes: 120 - run: CUDA_VISIBLE_DEVICES=5,6 pytest autotest -m pr_test --alluredir=allure-results --clean-alluredir + run: CUDA_VISIBLE_DEVICES=5,6 pytest autotest -m pr_test -v -s --alluredir=allure-results --clean-alluredir - name: Generate reports if: always() run: | diff --git a/autotest/config.yaml b/autotest/config.yaml index 60100c6fa8..75988f2891 100644 --- a/autotest/config.yaml +++ b/autotest/config.yaml @@ -1,4 +1,4 @@ -model_path: /nvme/qa_test_models +model_path: /mnt/bigdisk/qa_test_models dst_path: /nvme/qa_test_models/autotest_model log_path: /nvme/qa_test_models/autotest_model/log dataset_path: /nvme/qa_test_models/...dataset @@ -13,67 +13,69 @@ tp_config: turbomind_model: - - llama-2-7b-chat - - internlm2-chat-1_8b - - internlm-chat-7b - - internlm-chat-20b - - internlm2-chat-7b - - internlm2-chat-20b - - Qwen-7B-Chat - - Qwen-14B-Chat - - llama2-chat-7b-w4 - - Baichuan2-7B-Chat - - Yi-6B-Chat - - internlm2-1_8b - - internlm2-20b - - CodeLlama-7b-Instruct-hf + - meta-llama/Llama-2-7b-chat + - internlm/internlm2-chat-1_8b + - internlm/internlm-chat-7b + - internlm/internlm-chat-20b + - internlm/internlm2-chat-7b + - internlm/internlm2-chat-20b + - internlm/internlm2-chat-7b-4bits + - internlm/internlm2-chat-20b-4bits + - Qwen/Qwen-7B-Chat + - Qwen/Qwen-14B-Chat + - lmdeploy/llama2-chat-7b-w4 + - baichuan-inc/Baichuan2-7B-Chat + - 01-ai/Yi-6B-Chat + - internlm/internlm2-1_8b + - internlm/internlm2-20b + - codellama/CodeLlama-7b-Instruct-hf pytorch_model: - - llama-2-7b-chat - - internlm-chat-7b - - internlm-chat-20b - - internlm2-chat-7b - - internlm2-chat-20b - - Baichuan2-7B-Chat - - Baichuan2-13B-Chat - - chatglm2-6b - - falcon-7b - - Yi-6B-Chat - - internlm2-1_8b - - internlm2-20b - - Qwen1.5-7B-Chat - - Mistral-7B-Instruct-v0.1 - - Mixtral-8x7B-Instruct-v0.1 - - gemma-7b-it - - deepseek-moe-16b-chat + - meta-llama/Llama-2-7b-chat + - internlm/internlm-chat-7b + - internlm/internlm-chat-20b + - internlm/internlm2-chat-7b + - internlm/internlm2-chat-20b + - baichuan-inc/Baichuan2-7B-Chat + - baichuan-inc/Baichuan2-13B-Chat + - THUDM/chatglm2-6b + - tiiuae/falcon-7b + - 01-ai/Yi-6B-Chat + - internlm/internlm2-1_8b + - internlm/internlm2-20b + - Qwen/Qwen1.5-7B-Chat + - mistralai/Mistral-7B-Instruct-v0.1 + - mistralai/Mixtral-8x7B-Instruct-v0.1 + - google/gemma-7b-it + - deepseek-ai/deepseek-moe-16b-chat quatization_case_config: w4a16: - - llama-2-7b-chat - - internlm-chat-20b - - Qwen-7B-Chat - - Qwen-14B-Chat - - internlm2-chat-20b - - Baichuan2-7B-Chat - - internlm2-20b + - meta-llama/Llama-2-7b-chat + - internlm/internlm-chat-20b + - Qwen/Qwen-7B-Chat + - Qwen/Qwen-14B-Chat + - internlm/internlm2-chat-20b + - baichuan-inc/Baichuan2-7B-Chat + - internlm/internlm2-20b kvint8: # more models are supported kvint8 quantization, but the chat response are not good, already removed - - llama-2-7b-chat - - internlm-chat-20b - - internlm2-chat-20b + - meta-llama/Llama-2-7b-chat + - internlm/internlm-chat-20b + - internlm/internlm2-chat-20b kvint8_w4a16: - - llama-2-7b-chat - - internlm-chat-20b - - internlm2-chat-20b - - internlm2-20b - - Qwen-7B-Chat - - Qwen-14B-Chat - - Baichuan2-7B-Chat + - meta-llama/Llama-2-7b-chat + - internlm/internlm-chat-20b + - internlm/internlm2-chat-20b + - internlm/internlm2-20b + - Qwen/Qwen-7B-Chat + - Qwen/Qwen-14B-Chat + - baichuan-inc/Baichuan2-7B-Chat w8a8: - - llama-2-7b-chat - - internlm-chat-20b - - internlm2-chat-20b - - internlm2-chat-7b - - Yi-6B-Chat - - internlm2-20b + - meta-llama/Llama-2-7b-chat + - internlm/internlm-chat-20b + - internlm/internlm2-chat-20b + - internlm/internlm2-chat-7b + - 01-ai/Yi-6B-Chat + - internlm/internlm2-20b diff --git a/autotest/interface/pipeline/test_pipeline_turbomind_func.py b/autotest/interface/pipeline/test_pipeline_turbomind_func.py index 8251fae347..64a07b3ddb 100644 --- a/autotest/interface/pipeline/test_pipeline_turbomind_func.py +++ b/autotest/interface/pipeline/test_pipeline_turbomind_func.py @@ -10,7 +10,7 @@ @pytest.mark.flaky(reruns=0) class TestPipelineTurbomindFuncRegression: - @pytest.mark.parametrize('model', ['internlm2-chat-20b']) + @pytest.mark.parametrize('model', ['internlm/internlm2-chat-20b']) def test_backend_config_tp(self, config, model): with pytest.raises(AssertionError, match='tp should be 2\\^n'): model_path = '/'.join([config.get('model_path'), model]) @@ -18,7 +18,7 @@ def test_backend_config_tp(self, config, model): pipe = pipeline(model_path, backend_config=backend_config) del pipe - @pytest.mark.parametrize('model', ['internlm2-chat-20b']) + @pytest.mark.parametrize('model', ['internlm/internlm2-chat-20b']) def test_backend_config_session_len(self, config, model): model_path = '/'.join([config.get('model_path'), model]) backend_config = TurbomindEngineConfig(session_len=10) @@ -29,7 +29,7 @@ def test_backend_config_session_len(self, config, model): assert response[i].finish_reason == 'length', str(response[i]) assert response[i].generate_token_len == 0, str(response[i]) - @pytest.mark.parametrize('model', ['internlm2-chat-20b']) + @pytest.mark.parametrize('model', ['internlm/internlm2-chat-20b']) def test_gen_config_test(self, config, model): model_path = '/'.join([config.get('model_path'), model]) pipe = pipeline(model_path) @@ -111,7 +111,7 @@ def test_gen_config_test(self, config, model): del pipe - @pytest.mark.parametrize('model', ['internlm2-chat-20b']) + @pytest.mark.parametrize('model', ['internlm/internlm2-chat-20b']) def future_test_backend_config_cache_max_entry_count(self, config, model): model_path = '/'.join([config.get('model_path'), model]) backend_config = TurbomindEngineConfig(cache_max_entry_count=-1) @@ -122,7 +122,7 @@ def future_test_backend_config_cache_max_entry_count(self, config, model): with assume: assert response[i].finish_reason == 'length', str(response[i]) - @pytest.mark.parametrize('model', ['internlm2-chat-20b']) + @pytest.mark.parametrize('model', ['internlm/internlm2-chat-20b']) def test_backend_config_max_batch_size2(self, config, model): model_path = '/'.join([config.get('model_path'), model]) backend_config = TurbomindEngineConfig(max_batch_size=-1) @@ -140,7 +140,7 @@ def test_backend_config_max_batch_size2(self, config, model): with assume: assert response[i].text == '', str(response[i]) - @pytest.mark.parametrize('model', ['internlm2-chat-20b']) + @pytest.mark.parametrize('model', ['internlm/internlm2-chat-20b']) def test_pipeline_batch_infer(self, config, model): model_path = '/'.join([config.get('model_path'), model]) pipe = pipeline(model_path) @@ -160,7 +160,7 @@ def test_pipeline_batch_infer(self, config, model): with assume: assert response[i].session_id == i - @pytest.mark.parametrize('model', ['internlm2-chat-20b']) + @pytest.mark.parametrize('model', ['internlm/internlm2-chat-20b']) def test_pipeline_stream_infer(self, config, model): model_path = '/'.join([config.get('model_path'), model]) pipe = pipeline(model_path) @@ -207,7 +207,7 @@ def test_pipeline_stream_infer(self, config, model): with assume: assert outputs_list[-1].finish_reason is not None, str(output) - @pytest.mark.parametrize('model', ['internlm2-chat-20b']) + @pytest.mark.parametrize('model', ['internlm/internlm2-chat-20b']) def test_pipeline_stream_infer2(self, config, model): model_path = '/'.join([config.get('model_path'), model]) pipe = pipeline(model_path) diff --git a/autotest/interface/pipeline/test_pipeline_turbomind_longtext_func.py b/autotest/interface/pipeline/test_pipeline_turbomind_longtext_func.py new file mode 100644 index 0000000000..13bfd8aff3 --- /dev/null +++ b/autotest/interface/pipeline/test_pipeline_turbomind_longtext_func.py @@ -0,0 +1,88 @@ +import pytest +from utils.get_run_config import get_tp_num + +from lmdeploy import TurbomindEngineConfig, pipeline + + +@pytest.mark.order(8) +@pytest.mark.pipeline_func +@pytest.mark.timeout(600) +class TestPipelineLongtextFunc: + + def test_long_test_chat_7b(self, config): + model = 'internlm/internlm2-chat-7b' + tp_config = get_tp_num(config, model) + model_path = '/'.join([config.get('model_path'), model]) + + backend_config = TurbomindEngineConfig(rope_scaling_factor=2.0, + session_len=210000, + tp=tp_config) + pipe = pipeline(model_path, backend_config=backend_config) + prompt = '今 天 心 ' * int(200000 / 6) + + # batch infer + pipe(prompt) + + # stream infer + for outputs in pipe.stream_infer(prompt): + continue + + prompts = ['今 天 心 ' * int(200000 / 6)] * 2 + # batch infer + pipe(prompts) + + # stream infer + for outputs in pipe.stream_infer(prompts): + continue + + def test_long_test_chat_20b(self, config): + model = 'internlm/internlm2-chat-20b' + tp_config = get_tp_num(config, model) + model_path = '/'.join([config.get('model_path'), model]) + + backend_config = TurbomindEngineConfig(rope_scaling_factor=2.0, + session_len=210000, + tp=tp_config) + pipe = pipeline(model_path, backend_config=backend_config) + prompt = '今 天 心 ' * int(200000 / 6) + + # batch infer + pipe(prompt) + + # stream infer + for outputs in pipe.stream_infer(prompt): + continue + + prompts = ['今 天 心 ' * int(200000 / 6)] * 2 + # batch infer + pipe(prompts) + + # stream infer + for outputs in pipe.stream_infer(prompts): + continue + + def test_long_test_20b(self, config): + model = 'internlm/internlm2-20b' + tp_config = get_tp_num(config, model) + model_path = '/'.join([config.get('model_path'), model]) + + backend_config = TurbomindEngineConfig(rope_scaling_factor=2.0, + session_len=210000, + tp=tp_config) + pipe = pipeline(model_path, backend_config=backend_config) + prompt = '今 天 心 ' * int(200000 / 6) + + # batch infer + pipe(prompt) + + # stream infer + for outputs in pipe.stream_infer(prompt): + continue + + prompts = ['今 天 心 ' * int(200000 / 6)] * 2 + # batch infer + pipe(prompts) + + # stream infer + for outputs in pipe.stream_infer(prompts): + continue diff --git a/autotest/prompt_case.yaml b/autotest/prompt_case.yaml index e1839ce3f2..ce5d174518 100644 --- a/autotest/prompt_case.yaml +++ b/autotest/prompt_case.yaml @@ -77,6 +77,9 @@ chinese_poem_case: - internlm2-20b: - len_g: 5 + - falcon: + - len_g: + 5 english_poem_case: - write a romantic English poem: - contain: @@ -110,6 +113,7 @@ emoji_case: - \u2714 - 赞 - emoji + - '!' traditional_chinese_case: - 使用繁體介紹香港維多利亞港: - contain: diff --git a/autotest/tools/chat/test_command_chat_hf_pytorch.py b/autotest/tools/chat/test_command_chat_hf_pytorch.py index 5854584122..f0f8e1c8b3 100644 --- a/autotest/tools/chat/test_command_chat_hf_pytorch.py +++ b/autotest/tools/chat/test_command_chat_hf_pytorch.py @@ -1,7 +1,8 @@ import allure import conftest import pytest -from utils.config_utils import get_torch_model_list +from utils.config_utils import (get_cuda_prefix_by_workerid, + get_torch_model_list) from utils.run_client_chat import hf_command_line_test conftest._init_cli_case_list() @@ -15,12 +16,40 @@ def getCaseList(): @pytest.mark.order(10) @pytest.mark.usefixtures('cli_case_config') @pytest.mark.hf_pytorch_chat +@pytest.mark.gpu_num_1 @pytest.mark.parametrize('usercase', getCaseList()) -@pytest.mark.parametrize('model', get_torch_model_list()) -def test_hf_pytorch_chat(config, model, cli_case_config, usercase): - result, chat_log, msg = hf_command_line_test(config, usercase, - cli_case_config.get(usercase), - model, 'torch') +@pytest.mark.parametrize('model', get_torch_model_list(tp_num=1)) +def test_hf_pytorch_chat_tp1(config, model, cli_case_config, usercase, + worker_id): + result, chat_log, msg = hf_command_line_test( + config, + usercase, + cli_case_config.get(usercase), + model, + 'torch', + cuda_prefix=get_cuda_prefix_by_workerid(worker_id)) + if chat_log is not None: + allure.attach.file(chat_log, + attachment_type=allure.attachment_type.TEXT) + + assert result, msg + + +@pytest.mark.order(10) +@pytest.mark.usefixtures('cli_case_config') +@pytest.mark.hf_pytorch_chat +@pytest.mark.gpu_num_2 +@pytest.mark.parametrize('usercase', getCaseList()) +@pytest.mark.parametrize('model', get_torch_model_list(tp_num=2)) +def test_hf_pytorch_chat_tp2(config, model, cli_case_config, usercase, + worker_id): + result, chat_log, msg = hf_command_line_test( + config, + usercase, + cli_case_config.get(usercase), + model, + 'torch', + cuda_prefix=get_cuda_prefix_by_workerid(worker_id, tp_num=2)) if chat_log is not None: allure.attach.file(chat_log, attachment_type=allure.attachment_type.TEXT) @@ -34,7 +63,7 @@ def test_hf_pytorch_chat(config, model, cli_case_config, usercase): @pytest.mark.pr_test @pytest.mark.xdist_group(name='pr_test') @pytest.mark.parametrize('usercase', getCaseList()) -@pytest.mark.parametrize('model', ['internlm2-chat-20b']) +@pytest.mark.parametrize('model', ['internlm/internlm2-chat-20b']) def test_hf_pytorch_chat_pr(config, model, cli_case_config, usercase): result, chat_log, msg = hf_command_line_test( config, diff --git a/autotest/tools/chat/test_command_chat_hf_turbomind.py b/autotest/tools/chat/test_command_chat_hf_turbomind.py index 3c889fd26d..3e763c0ef2 100644 --- a/autotest/tools/chat/test_command_chat_hf_turbomind.py +++ b/autotest/tools/chat/test_command_chat_hf_turbomind.py @@ -1,7 +1,8 @@ import allure import conftest import pytest -from utils.config_utils import get_turbomind_model_list +from utils.config_utils import (get_cuda_prefix_by_workerid, + get_turbomind_model_list) from utils.run_client_chat import hf_command_line_test conftest._init_cli_case_list() @@ -15,12 +16,41 @@ def getCaseList(): @pytest.mark.order(10) @pytest.mark.usefixtures('cli_case_config') @pytest.mark.hf_turbomind_chat +@pytest.mark.gpu_num_1 @pytest.mark.parametrize('usercase', getCaseList()) -@pytest.mark.parametrize('model', get_turbomind_model_list()) -def test_hf_turbomind_chat(config, model, cli_case_config, usercase): - result, chat_log, msg = hf_command_line_test(config, usercase, - cli_case_config.get(usercase), - model, 'turbomind') +@pytest.mark.parametrize('model', get_turbomind_model_list(tp_num=1)) +def test_hf_turbomind_chat_tp1(config, model, cli_case_config, usercase, + worker_id): + result, chat_log, msg = hf_command_line_test( + config, + usercase, + cli_case_config.get(usercase), + model, + 'turbomind', + cuda_prefix=get_cuda_prefix_by_workerid(worker_id)) + + if chat_log is not None: + allure.attach.file(chat_log, + attachment_type=allure.attachment_type.TEXT) + + assert result, msg + + +@pytest.mark.order(10) +@pytest.mark.usefixtures('cli_case_config') +@pytest.mark.hf_turbomind_chat +@pytest.mark.gpu_num_2 +@pytest.mark.parametrize('usercase', getCaseList()) +@pytest.mark.parametrize('model', get_turbomind_model_list(tp_num=2)) +def test_hf_turbomind_chat_tp2(config, model, cli_case_config, usercase, + worker_id): + result, chat_log, msg = hf_command_line_test( + config, + usercase, + cli_case_config.get(usercase), + model, + 'turbomind', + cuda_prefix=get_cuda_prefix_by_workerid(worker_id, tp_num=2)) if chat_log is not None: allure.attach.file(chat_log, @@ -36,7 +66,8 @@ def test_hf_turbomind_chat(config, model, cli_case_config, usercase): @pytest.mark.xdist_group(name='pr_test') @pytest.mark.parametrize('usercase', getCaseList()) @pytest.mark.parametrize( - 'model', ['internlm2-chat-20b', 'internlm2-chat-20b-inner-w4a16']) + 'model', + ['internlm/internlm2-chat-20b', 'internlm/internlm2-chat-20b-inner-w4a16']) def test_hf_turbomind_chat_pr(config, model, cli_case_config, usercase): result, chat_log, msg = hf_command_line_test( config, diff --git a/autotest/tools/chat/test_command_chat_workspace.py b/autotest/tools/chat/test_command_chat_workspace.py index 34f0608783..26afeaf998 100644 --- a/autotest/tools/chat/test_command_chat_workspace.py +++ b/autotest/tools/chat/test_command_chat_workspace.py @@ -1,7 +1,8 @@ import allure import conftest import pytest -from utils.config_utils import get_turbomind_model_list +from utils.config_utils import (get_cuda_prefix_by_workerid, + get_turbomind_model_list) from utils.run_client_chat import command_line_test conftest._init_cli_case_list() @@ -12,9 +13,9 @@ def getPromptCaseList(): return prompt_list -def getModelList(): +def getModelList(tp_num): return [ - item for item in get_turbomind_model_list() + item for item in get_turbomind_model_list(tp_num) if 'kvint8' not in item.lower() ] @@ -22,12 +23,39 @@ def getModelList(): @pytest.mark.order(10) @pytest.mark.usefixtures('cli_case_config') @pytest.mark.command_chat +@pytest.mark.gpu_num_1 @pytest.mark.parametrize('usercase', getPromptCaseList()) -@pytest.mark.parametrize('model', getModelList()) -def test_workspace_chat(config, cli_case_config, usercase, model): - result, chat_log, msg = command_line_test(config, usercase, - cli_case_config.get(usercase), - model, 'turbomind', None) +@pytest.mark.parametrize('model', getModelList(tp_num=1)) +def test_workspace_chat_tp1(config, cli_case_config, usercase, model, + worker_id): + result, chat_log, msg = command_line_test( + config, + usercase, + cli_case_config.get(usercase), + model, + 'turbomind', + cuda_prefix=get_cuda_prefix_by_workerid(worker_id)) + if chat_log is not None: + allure.attach.file(chat_log, + attachment_type=allure.attachment_type.TEXT) + assert result, msg + + +@pytest.mark.order(10) +@pytest.mark.usefixtures('cli_case_config') +@pytest.mark.command_chat +@pytest.mark.gpu_num_2 +@pytest.mark.parametrize('usercase', getPromptCaseList()) +@pytest.mark.parametrize('model', getModelList(tp_num=2)) +def test_workspace_chat_tp2(config, cli_case_config, usercase, model, + worker_id): + result, chat_log, msg = command_line_test( + config, + usercase, + cli_case_config.get(usercase), + model, + 'turbomind', + cuda_prefix=get_cuda_prefix_by_workerid(worker_id, tp_num=2)) if chat_log is not None: allure.attach.file(chat_log, attachment_type=allure.attachment_type.TEXT) @@ -38,10 +66,10 @@ def test_workspace_chat(config, cli_case_config, usercase, model): @pytest.mark.usefixtures('cli_case_config') @pytest.mark.command_chat @pytest.mark.pr_test -@pytest.mark.xdist_group(name='pr_test') @pytest.mark.parametrize('usercase', getPromptCaseList()) @pytest.mark.parametrize( - 'model', ['internlm2-chat-20b', 'internlm2-chat-20b-inner-w4a16']) + 'model', + ['internlm/internlm2-chat-20b', 'internlm/internlm2-chat-20b-inner-w4a16']) def test_workspace_chat_pr(config, cli_case_config, usercase, model): result, chat_log, msg = command_line_test( config, diff --git a/autotest/tools/convert/test_convert.py b/autotest/tools/convert/test_convert.py index 074ed8f93d..8fe8d30949 100644 --- a/autotest/tools/convert/test_convert.py +++ b/autotest/tools/convert/test_convert.py @@ -4,15 +4,16 @@ import allure import pytest -from utils.config_utils import get_turbomind_model_list +from utils.config_utils import (get_cuda_prefix_by_workerid, + get_turbomind_model_list) from utils.get_run_config import get_command_with_extra, get_model_name @pytest.mark.order(5) @pytest.mark.convert @pytest.mark.parametrize('model', get_turbomind_model_list()) -def test_convert(config, model): - convert(config, model) +def test_convert(config, model, worker_id): + convert(config, model, get_cuda_prefix_by_workerid(worker_id)) @pytest.mark.order(5) @@ -20,32 +21,40 @@ def test_convert(config, model): @pytest.mark.pr_test @pytest.mark.xdist_group(name='pr_test') @pytest.mark.parametrize( - 'model', ['internlm2-chat-20b', 'internlm2-chat-20b-inner-w4a16']) + 'model', + ['internlm/internlm2-chat-20b', 'internlm/internlm2-chat-20b-inner-w4a16']) def test_convert_pr(config, model): - convert(config, model) + convert(config, model, 'CUDA_VISIBLE_DEVICES=5') -def convert(config, model_case): +def convert(config, model_case, cuda_prefix): origin_model_path = config.get('model_path') + '/' + model_case dst_path = config.get('dst_path') + '/workspace_' + model_case log_path = config.get('log_path') model_name = get_model_name(model_case) - if 'w4' in model_case: - cmd = get_command_with_extra( - ' '.join([ - 'lmdeploy convert', model_name, origin_model_path, - '--dst-path', dst_path, '--model-format awq --group-size 128' - ]), config, model_name, True) + if 'w4' in model_case or '4bits' in model_case: + cmd = get_command_with_extra(' '.join([ + 'lmdeploy convert', model_name, origin_model_path, '--dst-path', + dst_path, '--model-format awq --group-size 128' + ]), + config, + model_name, + True, + cuda_prefix=cuda_prefix) else: - cmd = get_command_with_extra( - ' '.join([ - 'lmdeploy convert', model_name, origin_model_path, - '--dst-path', dst_path - ]), config, model_name, True) + cmd = get_command_with_extra(' '.join([ + 'lmdeploy convert', model_name, origin_model_path, '--dst-path', + dst_path + ]), + config, + model_name, + True, + cuda_prefix=cuda_prefix) - convert_log = os.path.join(log_path, 'convert_' + model_case + '.log') + convert_log = os.path.join(log_path, + 'convert_' + model_case.split('/')[1] + '.log') print('reproduce command convert: ' + cmd + '\n') with open(convert_log, 'w') as f: # remove existing workspace diff --git a/autotest/tools/pipeline/pipeline_chat_script.py b/autotest/tools/pipeline/pipeline_chat_script.py index 70d4abcb36..f8a92d9b8f 100644 --- a/autotest/tools/pipeline/pipeline_chat_script.py +++ b/autotest/tools/pipeline/pipeline_chat_script.py @@ -30,7 +30,8 @@ def run_pipeline_chat_test(config, cases_info, model_case, tp, type): if 'pytorch' == type: backend_config = PytorchEngineConfig(tp=tp) else: - if 'kvint8' in model_case and 'w4' in model_case: + if 'kvint8' in model_case and ('w4' in model_case + or '4bits' in model_case): backend_config = TurbomindEngineConfig(tp=tp, model_format='awq', quant_policy=4) @@ -38,7 +39,7 @@ def run_pipeline_chat_test(config, cases_info, model_case, tp, type): backend_config = TurbomindEngineConfig(tp=tp, model_format='hf', quant_policy=4) - elif 'w4' in model_case: + elif 'w4' in model_case or '4bits' in model_case: backend_config = TurbomindEngineConfig(tp=tp, model_format='awq') else: backend_config = TurbomindEngineConfig(tp=tp) diff --git a/autotest/tools/pipeline/test_pipeline_chat_pytorch.py b/autotest/tools/pipeline/test_pipeline_chat_pytorch.py index 5014f3a163..7e0318eebd 100644 --- a/autotest/tools/pipeline/test_pipeline_chat_pytorch.py +++ b/autotest/tools/pipeline/test_pipeline_chat_pytorch.py @@ -1,14 +1,15 @@ +import os from multiprocessing import Process import pytest -from utils.config_utils import get_torch_model_list +from utils.config_utils import get_cuda_id_by_workerid, get_torch_model_list from utils.pipeline_chat import (assert_pipeline_chat_log, run_pipeline_chat_test) -def getModelList(): +def getModelList(tp_num): return [ - item for item in get_torch_model_list() + item for item in get_torch_model_list(tp_num) if 'falcon' not in item.lower() and 'chatglm2' not in item.lower() ] @@ -16,9 +17,32 @@ def getModelList(): @pytest.mark.order(6) @pytest.mark.usefixtures('common_case_config') @pytest.mark.pipeline_chat_pytorch +@pytest.mark.gpu_num_1 @pytest.mark.flaky(reruns=0) -@pytest.mark.parametrize('model', getModelList()) -def test_pipeline_chat_pytorch(config, common_case_config, model): +@pytest.mark.parametrize('model', getModelList(tp_num=1)) +def test_pipeline_chat_pytorch_tp1(config, common_case_config, model, + worker_id): + os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id) + p = Process(target=run_pipeline_chat_test, + args=(config, common_case_config, model, 'pytorch')) + p.start() + p.join() + + # assert script + assert_pipeline_chat_log(config, common_case_config, model) + + +@pytest.mark.order(6) +@pytest.mark.usefixtures('common_case_config') +@pytest.mark.pipeline_chat_pytorch +@pytest.mark.gpu_num_2 +@pytest.mark.flaky(reruns=0) +@pytest.mark.parametrize('model', getModelList(tp_num=2)) +def test_pipeline_chat_pytorch_tp2(config, common_case_config, model, + worker_id): + if 'gw' in worker_id: + os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id, + tp_num=2) p = Process(target=run_pipeline_chat_test, args=(config, common_case_config, model, 'pytorch')) p.start() @@ -33,7 +57,7 @@ def test_pipeline_chat_pytorch(config, common_case_config, model): @pytest.mark.pipeline_chat_pytorch @pytest.mark.flaky(reruns=0) @pytest.mark.pr_test -@pytest.mark.parametrize('model', ['internlm2-chat-20b']) +@pytest.mark.parametrize('model', ['internlm/internlm2-chat-20b']) def test_pipeline_chat_pytorch_pr(config, common_case_config, model): p = Process(target=run_pipeline_chat_test, args=(config, common_case_config, model, 'pytorch')) diff --git a/autotest/tools/pipeline/test_pipeline_chat_turbomind.py b/autotest/tools/pipeline/test_pipeline_chat_turbomind.py index 90773d39bc..e12db44c1e 100644 --- a/autotest/tools/pipeline/test_pipeline_chat_turbomind.py +++ b/autotest/tools/pipeline/test_pipeline_chat_turbomind.py @@ -1,7 +1,8 @@ +import os from multiprocessing import Process import pytest -from utils.config_utils import get_turbomind_model_list +from utils.config_utils import get_all_model_list, get_cuda_id_by_workerid from utils.pipeline_chat import (assert_pipeline_chat_log, run_pipeline_chat_test) @@ -9,9 +10,29 @@ @pytest.mark.order(6) @pytest.mark.usefixtures('common_case_config') @pytest.mark.pipeline_chat +@pytest.mark.gpu_num_1 @pytest.mark.flaky(reruns=0) -@pytest.mark.parametrize('model', get_turbomind_model_list()) -def test_pipeline_chat(config, common_case_config, model): +@pytest.mark.parametrize('model', get_all_model_list(tp_num=1)) +def test_pipeline_chat_tp1(config, common_case_config, model, worker_id): + if 'gw' in worker_id: + os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id) + p = Process(target=run_pipeline_chat_test, + args=(config, common_case_config, model, 'turbomind')) + p.start() + p.join() + assert_pipeline_chat_log(config, common_case_config, model) + + +@pytest.mark.order(6) +@pytest.mark.usefixtures('common_case_config') +@pytest.mark.pipeline_chat +@pytest.mark.gpu_num_2 +@pytest.mark.flaky(reruns=0) +@pytest.mark.parametrize('model', get_all_model_list(tp_num=2)) +def test_pipeline_chat_tp2(config, common_case_config, model, worker_id): + if 'gw' in worker_id: + os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id, + tp_num=2) p = Process(target=run_pipeline_chat_test, args=(config, common_case_config, model, 'turbomind')) p.start() @@ -25,7 +46,8 @@ def test_pipeline_chat(config, common_case_config, model): @pytest.mark.flaky(reruns=0) @pytest.mark.pr_test @pytest.mark.parametrize( - 'model', ['internlm2-chat-20b', 'internlm2-chat-20b-inner-w4a16']) + 'model', + ['internlm/internlm2-chat-20b', 'internlm/internlm2-chat-20b-inner-w4a16']) def test_pipeline_chat_pr(config, common_case_config, model): p = Process(target=run_pipeline_chat_test, args=(config, common_case_config, model, 'turbomind')) diff --git a/autotest/tools/quantization/test_quantization_kvint8.py b/autotest/tools/quantization/test_quantization_kvint8.py index 77957a676a..7c57d766e0 100644 --- a/autotest/tools/quantization/test_quantization_kvint8.py +++ b/autotest/tools/quantization/test_quantization_kvint8.py @@ -2,23 +2,23 @@ import allure import pytest +from utils.config_utils import get_cuda_prefix_by_workerid from utils.quantization_utils import quantization -model_list = [('llama-2-7b-chat', 'CUDA_VISIBLE_DEVICES=1'), - ('internlm-chat-20b', 'CUDA_VISIBLE_DEVICES=2'), - ('internlm2-chat-20b', 'CUDA_VISIBLE_DEVICES=3'), - ('Qwen-7B-Chat', 'CUDA_VISIBLE_DEVICES=4'), - ('Qwen-14B-Chat', 'CUDA_VISIBLE_DEVICES=5'), - ('internlm2-20b', 'CUDA_VISIBLE_DEVICES=6'), - ('Baichuan2-7B-Chat', 'CUDA_VISIBLE_DEVICES=7')] +model_list = [ + 'meta-llama/Llama-2-7b-chat', 'internlm/internlm-chat-20b', + 'internlm/internlm2-chat-20b', 'Qwen/Qwen-7B-Chat', 'Qwen/Qwen-14B-Chat', + 'internlm/internlm2-20b', 'baichuan-inc/Baichuan2-7B-Chat' +] @pytest.mark.order(1) @pytest.mark.quantization_kvint8 @pytest.mark.timeout(900) -@pytest.mark.parametrize('model, prefix', model_list) -def test_quantization_kvint8(config, model, prefix): - quantization_kvint8(config, model + '-inner-kvint8', model, prefix) +@pytest.mark.parametrize('model', model_list) +def test_quantization_kvint8(config, model, worker_id): + quantization_kvint8(config, model + '-inner-kvint8', model, + get_cuda_prefix_by_workerid(worker_id)) def quantization_kvint8(config, quantization_model_name, origin_model_name, @@ -29,9 +29,10 @@ def quantization_kvint8(config, quantization_model_name, origin_model_name, cuda_prefix) log_path = config.get('log_path') quantization_log = os.path.join( - log_path, - '_'.join(['quantization', quantization_type, quantization_model_name - ]) + '.log') + log_path, '_'.join([ + 'quantization', quantization_type, + quantization_model_name.split('/')[1] + ]) + '.log') allure.attach.file(quantization_log, attachment_type=allure.attachment_type.TEXT) diff --git a/autotest/tools/quantization/test_quantization_kvint8_w4a16.py b/autotest/tools/quantization/test_quantization_kvint8_w4a16.py index 9a1c5b6555..44dc1751fb 100644 --- a/autotest/tools/quantization/test_quantization_kvint8_w4a16.py +++ b/autotest/tools/quantization/test_quantization_kvint8_w4a16.py @@ -2,23 +2,26 @@ import allure import pytest +from utils.config_utils import get_cuda_prefix_by_workerid from utils.quantization_utils import quantization -model_list = [('llama-2-7b-chat-inner-kvint8', 'CUDA_VISIBLE_DEVICES=1'), - ('internlm-chat-20b-inner-kvint8', 'CUDA_VISIBLE_DEVICES=2'), - ('internlm2-chat-20b-inner-kvint8', 'CUDA_VISIBLE_DEVICES=3'), - ('Qwen-7B-Chat-inner-kvint8', 'CUDA_VISIBLE_DEVICES=4'), - ('Qwen-14B-Chat-inner-kvint8', 'CUDA_VISIBLE_DEVICES=5'), - ('internlm2-20b-inner-kvint8', 'CUDA_VISIBLE_DEVICES=6'), - ('Baichuan2-7B-Chat-inner-kvint8', 'CUDA_VISIBLE_DEVICES=7')] +model_list = [ + 'meta-llama/Llama-2-7b-chat-inner-kvint8', + 'internlm/internlm-chat-20b-inner-kvint8', + 'internlm/internlm2-chat-20b-inner-kvint8', + 'Qwen/Qwen-7B-Chat-inner-kvint8', 'Qwen/Qwen-14B-Chat-inner-kvint8', + 'internlm/internlm2-20b-inner-kvint8', + 'baichuan-inc/Baichuan2-7B-Chat-inner-kvint8' +] @pytest.mark.order(4) @pytest.mark.quantization_kvint8_w4a16 @pytest.mark.timeout(900) -@pytest.mark.parametrize('model, prefix', model_list) -def test_quantization_kvint8_w4a16(config, model, prefix): - quantization_kvint8(config, model + '-w4a16', model, prefix) +@pytest.mark.parametrize('model', model_list) +def test_quantization_kvint8_w4a16(config, model, worker_id): + quantization_kvint8(config, model + '-w4a16', model, + get_cuda_prefix_by_workerid(worker_id)) def quantization_kvint8(config, quantization_model_name, origin_model_name, @@ -29,9 +32,10 @@ def quantization_kvint8(config, quantization_model_name, origin_model_name, cuda_prefix) log_path = config.get('log_path') quantization_log = os.path.join( - log_path, - '_'.join(['quantization', quantization_type, quantization_model_name - ]) + '.log') + log_path, '_'.join([ + 'quantization', quantization_type, + quantization_model_name.split('/')[1] + ]) + '.log') allure.attach.file(quantization_log, attachment_type=allure.attachment_type.TEXT) diff --git a/autotest/tools/quantization/test_quantization_w4a16.py b/autotest/tools/quantization/test_quantization_w4a16.py index 15749ba70b..3bafadd494 100644 --- a/autotest/tools/quantization/test_quantization_w4a16.py +++ b/autotest/tools/quantization/test_quantization_w4a16.py @@ -2,32 +2,34 @@ import allure import pytest +from utils.config_utils import get_cuda_prefix_by_workerid from utils.quantization_utils import quantization -model_list = [('llama-2-7b-chat', 'CUDA_VISIBLE_DEVICES=0'), - ('internlm-chat-20b', 'CUDA_VISIBLE_DEVICES=1'), - ('Qwen-7B-Chat', 'CUDA_VISIBLE_DEVICES=2'), - ('Qwen-14B-Chat', 'CUDA_VISIBLE_DEVICES=3'), - ('Qwen-VL', 'CUDA_VISIBLE_DEVICES=4'), - ('internlm2-chat-20b', 'CUDA_VISIBLE_DEVICES=5'), - ('internlm2-20b', 'CUDA_VISIBLE_DEVICES=6'), - ('Baichuan2-7B-Chat', 'CUDA_VISIBLE_DEVICES=7')] +model_list = [ + 'meta-llama/Llama-2-7b-chat', 'internlm/internlm-chat-20b', + 'Qwen/Qwen-7B-Chat', 'Qwen/Qwen-14B-Chat', 'Qwen/Qwen-VL', + 'internlm/internlm2-chat-20b', 'internlm/internlm2-20b', + 'baichuan-inc/Baichuan2-7B-Chat' +] @pytest.mark.order(3) @pytest.mark.quantization_w4a16 @pytest.mark.timeout(900) -@pytest.mark.parametrize('model, prefix', model_list) -def test_quantization_w4a16(config, model, prefix): - quantization_w4a16(config, model + '-inner-w4a16', model, prefix) +@pytest.mark.parametrize('model', model_list) +def test_quantization_w4a16(config, model, worker_id): + quantization_w4a16(config, model + '-inner-w4a16', model, + get_cuda_prefix_by_workerid(worker_id)) @pytest.mark.order(3) @pytest.mark.quantization_w4a16 @pytest.mark.pr_test +@pytest.mark.flaky(reruns=0) @pytest.mark.timeout(900) -@pytest.mark.parametrize('model, prefix', - [('internlm2-chat-20b', 'CUDA_VISIBLE_DEVICES=5')]) +@pytest.mark.parametrize( + 'model, prefix', + [('internlm/internlm2-chat-20b', 'CUDA_VISIBLE_DEVICES=5')]) def test_quantization_w4a16_pr(config, model, prefix): quantization_w4a16(config, model + '-inner-w4a16', model, prefix) @@ -40,9 +42,10 @@ def quantization_w4a16(config, quantization_model_name, origin_model_name, cuda_prefix) log_path = config.get('log_path') quantization_log = os.path.join( - log_path, - '_'.join(['quantization', quantization_type, quantization_model_name - ]) + '.log') + log_path, '_'.join([ + 'quantization', quantization_type, + quantization_model_name.split('/')[1] + ]) + '.log') allure.attach.file(quantization_log, attachment_type=allure.attachment_type.TEXT) diff --git a/autotest/tools/quantization/test_quantization_w8a8.py b/autotest/tools/quantization/test_quantization_w8a8.py index 7e6690d423..37a198c6d5 100644 --- a/autotest/tools/quantization/test_quantization_w8a8.py +++ b/autotest/tools/quantization/test_quantization_w8a8.py @@ -2,25 +2,23 @@ import allure import pytest +from utils.config_utils import get_cuda_prefix_by_workerid from utils.quantization_utils import quantization -model_list = [('llama-2-7b-chat', 'CUDA_VISIBLE_DEVICES=0'), - ('internlm-chat-20b', 'CUDA_VISIBLE_DEVICES=1'), - ('internlm2-chat-20b', 'CUDA_VISIBLE_DEVICES=2'), - ('internlm2-chat-7b', 'CUDA_VISIBLE_DEVICES=3'), - ('Yi-6B-Chat', 'CUDA_VISIBLE_DEVICES=4'), - ('internlm2-20b', 'CUDA_VISIBLE_DEVICES=5')] - -# ('Baichuan2-7B-Chat', 'CUDA_VISIBLE_DEVICES=6') -# ('Baichuan2-13B-Chat', 'CUDA_VISIBLE_DEVICES=7') +model_list = [ + 'meta-llama/Llama-2-7b-chat', 'internlm/internlm-chat-20b', + 'internlm/internlm2-chat-20b', 'internlm/internlm2-chat-7b', + '01-ai/Yi-6B-Chat', 'internlm/internlm2-20b' +] @pytest.mark.order(2) @pytest.mark.quantization_w8a8 @pytest.mark.timeout(900) -@pytest.mark.parametrize('model, prefix', model_list) -def test_quantization_w8a8(config, model, prefix): - quantization_w8a8(config, model + '-inner-w8a8', model, prefix) +@pytest.mark.parametrize('model', model_list) +def test_quantization_w8a8(config, model, worker_id): + quantization_w8a8(config, model + '-inner-w8a8', model, + get_cuda_prefix_by_workerid(worker_id)) def quantization_w8a8(config, quantization_model_name, origin_model_name, @@ -31,9 +29,10 @@ def quantization_w8a8(config, quantization_model_name, origin_model_name, cuda_prefix) log_path = config.get('log_path') quantization_log = os.path.join( - log_path, - '_'.join(['quantization', quantization_type, quantization_model_name - ]) + '.log') + log_path, '_'.join([ + 'quantization', quantization_type, + quantization_model_name.split('/')[1] + ]) + '.log') allure.attach.file(quantization_log, attachment_type=allure.attachment_type.TEXT) diff --git a/autotest/tools/restful/test_restful_chat_pytorch.py b/autotest/tools/restful/test_restful_chat_pytorch.py index 1b6fe5607a..6c5b33aa3f 100644 --- a/autotest/tools/restful/test_restful_chat_pytorch.py +++ b/autotest/tools/restful/test_restful_chat_pytorch.py @@ -5,35 +5,50 @@ import allure import pytest from pytest import assume -from utils.config_utils import get_torch_model_list +from utils.config_utils import (get_cuda_prefix_by_workerid, + get_torch_model_list, get_workerid) from utils.get_run_config import get_command_with_extra from utils.run_client_chat import command_line_test from utils.run_restful_chat import (get_model, health_check, interactive_test, open_chat_test) -HTTP_URL = 'http://localhost:23333' +BASE_HTTP_URL = 'http://localhost' +DEFAULT_PORT = 23333 @pytest.fixture(scope='function', autouse=True) -def prepare_environment(request, config): +def prepare_environment(request, config, worker_id): model_path = config.get('model_path') log_path = config.get('log_path') - model = request.param + param = request.param + model = param['model'] + cuda_prefix = param['cuda_prefix'] + tp_num = param['tp_num'] - cmd = ['lmdeploy serve api_server ' + model_path + '/' + model] + if cuda_prefix is None: + cuda_prefix = get_cuda_prefix_by_workerid(worker_id, tp_num=tp_num) + + worker_num = get_workerid(worker_id) + if worker_num is None: + port = DEFAULT_PORT + else: + port = DEFAULT_PORT + worker_num cmd = get_command_with_extra('lmdeploy serve api_server ' + model_path + - '/' + model + ' --backend pytorch', + '/' + model + ' --backend pytorch' + + ' --server-port ' + str(port), config, model, need_tp=True) - start_log = os.path.join(log_path, 'start_restful_' + model + '.log') + print('reproduce command restful: ' + cmd) + + start_log = os.path.join(log_path, + 'start_restful_' + model.split('/')[1] + '.log') with open(start_log, 'w') as f: f.writelines('reproduce command restful: ' + cmd + '\n') - print('reproduce command restful: ' + cmd) # convert convertRes = subprocess.Popen([cmd], @@ -45,7 +60,7 @@ def prepare_environment(request, config): pid = convertRes.pid allure.attach.file(start_log, attachment_type=allure.attachment_type.TEXT) - http_url = HTTP_URL + http_url = BASE_HTTP_URL + ':' + str(port) start_time = int(time()) sleep(5) for i in range(120): @@ -58,42 +73,69 @@ def prepare_environment(request, config): yield if pid > 0: - kill_log = os.path.join(log_path, 'kill_' + model + '.log') + kill_log = os.path.join(log_path, + 'kill_' + model.split('/')[1] + '.log') - subprocess.Popen([ - "ps -ef | grep multiprocessing | grep -v grep | awk '{print $2}' " - + '| xargs kill -9' - ], - shell=True, - text=True, - encoding='utf-8') with open(kill_log, 'w') as f: convertRes.kill() allure.attach.file(kill_log, attachment_type=allure.attachment_type.TEXT) -def getModelList(): - return [ - item for item in get_torch_model_list() if 'chat' in item.lower() - and 'falcon' not in item.lower() and 'chatglm2' not in item.lower() - ] +def getModelList(tp_num): + return [{ + 'model': item, + 'cuda_prefix': None, + 'tp_num': tp_num + } for item in get_torch_model_list(tp_num) if 'chat' in item.lower()] @pytest.mark.order(7) @pytest.mark.usefixtures('common_case_config') @pytest.mark.restful_api_pytorch +@pytest.mark.gpu_num_1 @pytest.mark.flaky(reruns=0) -@pytest.mark.parametrize('prepare_environment', getModelList(), indirect=True) -def test_restful_chat(config, common_case_config): - run_all_step(config, common_case_config) +@pytest.mark.parametrize('prepare_environment', + getModelList(tp_num=1), + indirect=True) +def test_restful_chat_tp1(config, common_case_config, worker_id): + if get_workerid(worker_id) is None: + run_all_step(config, common_case_config) + else: + run_all_step(config, + common_case_config, + worker_id=worker_id, + port=DEFAULT_PORT + get_workerid(worker_id)) -def run_all_step(config, cases_info): - http_url = HTTP_URL +@pytest.mark.order(7) +@pytest.mark.usefixtures('common_case_config') +@pytest.mark.restful_api_pytorch +@pytest.mark.gpu_num_2 +@pytest.mark.flaky(reruns=0) +@pytest.mark.parametrize('prepare_environment', + getModelList(tp_num=2), + indirect=True) +def test_restful_chat_tp2(config, common_case_config, worker_id): + if get_workerid(worker_id) is None: + run_all_step(config, common_case_config) + else: + run_all_step(config, + common_case_config, + worker_id=worker_id, + port=DEFAULT_PORT + get_workerid(worker_id)) + + +def run_all_step(config, + cases_info, + worker_id: str = 'default', + port: int = DEFAULT_PORT): + http_url = BASE_HTTP_URL + ':' + str(port) model = get_model(http_url) - print(model) + if model is None: + assert False, 'server not start correctly' + for case in cases_info.keys(): if (case == 'memory_test' or case == 'emoji_case') and 'chat' not in model.lower(): @@ -103,15 +145,17 @@ def run_all_step(config, cases_info): with allure.step(case + ' step1 - command chat regression'): chat_result, chat_log, msg = command_line_test( - config, case, case_info, model, 'api_client', http_url) - allure.attach.file(chat_log, - attachment_type=allure.attachment_type.TEXT) - with assume: - assert chat_result, msg + config, case, case_info, model + worker_id, 'api_client', + http_url) + if chat_log is not None: + allure.attach.file(chat_log, + attachment_type=allure.attachment_type.TEXT) + with assume: + assert chat_result, msg with allure.step(case + ' step2 - restful_test - openai chat'): restful_result, restful_log, msg = open_chat_test( - config, case_info, model, http_url) + config, case_info, model, http_url, worker_id) allure.attach.file(restful_log, attachment_type=allure.attachment_type.TEXT) with assume: @@ -119,7 +163,7 @@ def run_all_step(config, cases_info): with allure.step(case + ' step3 - restful_test - interactive chat'): active_result, interactive_log, msg = interactive_test( - config, case_info, model, http_url) + config, case_info, model, http_url, worker_id) allure.attach.file(interactive_log, attachment_type=allure.attachment_type.TEXT) diff --git a/autotest/tools/restful/test_restful_chat_turbomind.py b/autotest/tools/restful/test_restful_chat_turbomind.py index cad858333a..f442aec10a 100644 --- a/autotest/tools/restful/test_restful_chat_turbomind.py +++ b/autotest/tools/restful/test_restful_chat_turbomind.py @@ -5,41 +5,56 @@ import allure import pytest from pytest import assume -from utils.config_utils import get_turbomind_model_list +from utils.config_utils import (get_all_model_list, + get_cuda_prefix_by_workerid, get_workerid) from utils.get_run_config import get_command_with_extra from utils.run_client_chat import command_line_test from utils.run_restful_chat import (get_model, health_check, interactive_test, open_chat_test) -HTTP_URL = 'http://localhost:23333' +BASE_HTTP_URL = 'http://localhost' +DEFAULT_PORT = 23333 @pytest.fixture(scope='function', autouse=True) -def prepare_environment(request, config): +def prepare_environment(request, config, worker_id): model_path = config.get('model_path') log_path = config.get('log_path') param = request.param model = param['model'] cuda_prefix = param['cuda_prefix'] + tp_num = param['tp_num'] + + if cuda_prefix is None: + cuda_prefix = get_cuda_prefix_by_workerid(worker_id, tp_num=tp_num) + + worker_num = get_workerid(worker_id) + if worker_num is None: + port = DEFAULT_PORT + else: + port = DEFAULT_PORT + worker_num cmd = ['lmdeploy serve api_server ' + model_path + '/' + model] cmd = get_command_with_extra('lmdeploy serve api_server ' + model_path + - '/' + model, + '/' + model + ' --server-port ' + str(port), config, model, need_tp=True, cuda_prefix=cuda_prefix) - if 'kvint8' in model and 'w4' not in model: - cmd += ' --model-format hf --quant-policy 4' - if 'kvint8' in model and 'w4' in model: + if 'kvint8' in model: cmd += ' --quant-policy 4' - if 'w4' in model: + if 'w4' in model or '4bits' in model: + cmd += ' --model-format awq' + else: + cmd += ' --model-format hf' + if 'w4' in model or '4bits' in model: cmd += ' --model-format awq' - start_log = os.path.join(log_path, 'start_restful_' + model + '.log') + start_log = os.path.join(log_path, + 'start_restful_' + model.split('/')[1] + '.log') print('reproduce command restful: ' + cmd) @@ -56,7 +71,7 @@ def prepare_environment(request, config): pid = convertRes.pid allure.attach.file(start_log, attachment_type=allure.attachment_type.TEXT) - http_url = HTTP_URL + http_url = BASE_HTTP_URL + ':' + str(port) start_time = int(time()) sleep(5) for i in range(120): @@ -68,7 +83,8 @@ def prepare_environment(request, config): break yield if pid > 0: - kill_log = os.path.join(log_path, 'kill_' + model + '.log') + kill_log = os.path.join(log_path, + 'kill_' + model.split('/')[1] + '.log') with open(kill_log, 'w') as f: convertRes.kill() @@ -76,45 +92,79 @@ def prepare_environment(request, config): allure.attach.file(kill_log, attachment_type=allure.attachment_type.TEXT) -def getModelList(): +def getModelList(tp_num): return [{ 'model': item, - 'cuda_prefix': None - } for item in get_turbomind_model_list() if 'chat' in item.lower()] + 'cuda_prefix': None, + 'tp_num': tp_num + } for item in get_all_model_list(tp_num) if 'chat' in item.lower()] @pytest.mark.order(7) @pytest.mark.usefixtures('common_case_config') @pytest.mark.restful_api +@pytest.mark.gpu_num_1 @pytest.mark.flaky(reruns=0) -@pytest.mark.parametrize('prepare_environment', getModelList(), indirect=True) -def test_restful_chat(config, common_case_config): - run_all_step(config, common_case_config) +@pytest.mark.parametrize('prepare_environment', + getModelList(tp_num=1), + indirect=True) +def test_restful_chat_tp1(request, config, common_case_config, worker_id): + if get_workerid(worker_id) is None: + run_all_step(config, common_case_config) + else: + run_all_step(config, + common_case_config, + worker_id=worker_id, + port=DEFAULT_PORT + get_workerid(worker_id)) @pytest.mark.order(7) @pytest.mark.usefixtures('common_case_config') @pytest.mark.restful_api +@pytest.mark.gpu_num_2 @pytest.mark.flaky(reruns=0) -@pytest.mark.pr_test @pytest.mark.parametrize('prepare_environment', - [{ - 'model': 'internlm2-chat-20b', - 'cuda_prefix': 'CUDA_VISIBLE_DEVICES=5,6' - }, { - 'model': 'internlm2-chat-20b-inner-w4a16', - 'cuda_prefix': 'CUDA_VISIBLE_DEVICES=5,6' - }], + getModelList(tp_num=2), + indirect=True) +def test_restful_chat_tp2(config, common_case_config, worker_id): + if get_workerid(worker_id) is None: + run_all_step(config, common_case_config) + else: + run_all_step(config, + common_case_config, + worker_id=worker_id, + port=DEFAULT_PORT + get_workerid(worker_id)) + + +@pytest.mark.order(7) +@pytest.mark.usefixtures('common_case_config') +@pytest.mark.restful_api +@pytest.mark.flaky(reruns=0) +@pytest.mark.pr_test +@pytest.mark.parametrize('prepare_environment', [{ + 'model': 'internlm/internlm2-chat-20b', + 'cuda_prefix': 'CUDA_VISIBLE_DEVICES=5,6', + 'tp_num': 2 +}, { + 'model': 'internlm/internlm2-chat-20b-inner-w4a16', + 'cuda_prefix': 'CUDA_VISIBLE_DEVICES=5,6', + 'tp_num': 2 +}], indirect=True) def test_restful_chat_pr(config, common_case_config): run_all_step(config, common_case_config) -def run_all_step(config, cases_info): - http_url = HTTP_URL +def run_all_step(config, + cases_info, + worker_id: str = 'default', + port: int = DEFAULT_PORT): + http_url = BASE_HTTP_URL + ':' + str(port) model = get_model(http_url) - print(model) + + if model is None: + assert False, 'server not start correctly' for case in cases_info.keys(): if (case == 'memory_test' or case == 'emoji_case') and 'chat' not in model.lower(): diff --git a/autotest/utils/config_utils.py b/autotest/utils/config_utils.py index f811e94f22..48905a1063 100644 --- a/autotest/utils/config_utils.py +++ b/autotest/utils/config_utils.py @@ -1,11 +1,11 @@ import os import yaml +from utils.get_run_config import get_tp_num -def get_turbomind_model_list(): +def get_turbomind_model_list(tp_num: int = None): config_path = os.path.join('autotest/config.yaml') - print(config_path) with open(config_path) as f: config = yaml.load(f.read(), Loader=yaml.SafeLoader) @@ -18,12 +18,16 @@ def get_turbomind_model_list(): for key in quatization_case_config.get('kvint8_w4a16'): case_list.append(key + '-inner-kvint8-w4a16') + if tp_num is not None: + return [ + item for item in case_list if get_tp_num(config, item) == tp_num + ] + return case_list -def get_torch_model_list(): +def get_torch_model_list(tp_num: int = None): config_path = os.path.join('autotest/config.yaml') - print(config_path) with open(config_path) as f: config = yaml.load(f.read(), Loader=yaml.SafeLoader) @@ -32,4 +36,64 @@ def get_torch_model_list(): for key in quatization_case_config.get('w8a8'): case_list.append(key + '-inner-w8a8') + if tp_num is not None: + return [ + item for item in case_list if get_tp_num(config, item) == tp_num + ] + return case_list + + +def get_all_model_list(tp_num: int = None): + config_path = os.path.join('autotest/config.yaml') + with open(config_path) as f: + config = yaml.load(f.read(), Loader=yaml.SafeLoader) + + case_list = config.get('turbomind_model') + for key in config.get('pytorch_model'): + if key not in case_list: + case_list.append(key) + quatization_case_config = config.get('quatization_case_config') + for key in quatization_case_config.get('w4a16'): + case_list.append(key + '-inner-w4a16') + for key in quatization_case_config.get('kvint8'): + case_list.append(key + '-inner-kvint8') + for key in quatization_case_config.get('kvint8_w4a16'): + case_list.append(key + '-inner-kvint8-w4a16') + + if tp_num is not None: + return [ + item for item in case_list if get_tp_num(config, item) == tp_num + ] + + return case_list + + +def get_cuda_prefix_by_workerid(worker_id, tp_num: int = 1): + if worker_id is None or 'gw' not in worker_id: + return None + else: + if tp_num == 1: + return 'CUDA_VISIBLE_DEVICES=' + worker_id.replace('gw', '') + elif tp_num == 2: + cuda_num = int(worker_id.replace('gw', '')) * 2 + return 'CUDA_VISIBLE_DEVICES=' + ','.join( + [str(cuda_num), str(cuda_num + 1)]) + + +def get_cuda_id_by_workerid(worker_id, tp_num: int = 1): + if worker_id is None or 'gw' not in worker_id: + return None + else: + if tp_num == 1: + return worker_id.replace('gw', '') + elif tp_num == 2: + cuda_num = int(worker_id.replace('gw', '')) * 2 + return ','.join([str(cuda_num), str(cuda_num + 1)]) + + +def get_workerid(worker_id): + if worker_id is None or 'gw' not in worker_id: + return None + else: + return int(worker_id.replace('gw', '')) diff --git a/autotest/utils/get_run_config.py b/autotest/utils/get_run_config.py index 03c53cf5d5..120446f47d 100644 --- a/autotest/utils/get_run_config.py +++ b/autotest/utils/get_run_config.py @@ -102,12 +102,21 @@ def _get_available_cude(): def _simple_model_name(model): - model_name = model.replace('-inner-w4a16', '') + if '/' in model: + model_name = model.split('/')[1] + else: + model_name = model + model_name = model_name.replace('-inner-w4a16', '') model_name = model_name.replace('-inner-w8a8', '') model_name = model_name.replace('-inner-kvint8', '') model_name = model_name.replace('-w4a16', '') return model_name +def _split_model_name(model): + model_name = model.split('/')[1] + return model_name + + if __name__ == '__main__': - print(_simple_model_name('Baichuan2-7B-Chat-inner-w4a16')) + print(_simple_model_name('baichuan-inc/Baichuan2-7B-Chat-inner-w4a16')) diff --git a/autotest/utils/pipeline_chat.py b/autotest/utils/pipeline_chat.py index 274730d4a1..0d65ad6ae0 100644 --- a/autotest/utils/pipeline_chat.py +++ b/autotest/utils/pipeline_chat.py @@ -27,7 +27,8 @@ def run_pipeline_chat_test(config, cases_info, model_case, type): if 'pytorch' == type: backend_config = PytorchEngineConfig(tp=tp) else: - if 'kvint8' in model_case and 'w4' in model_case: + if 'kvint8' in model_case and ('w4' in model_case + or '4bits' in model_case): backend_config = TurbomindEngineConfig(tp=tp, model_format='awq', quant_policy=4) @@ -35,7 +36,7 @@ def run_pipeline_chat_test(config, cases_info, model_case, type): backend_config = TurbomindEngineConfig(tp=tp, model_format='hf', quant_policy=4) - elif 'w4' in model_case: + elif 'w4' in model_case or '4bits' in model_case: backend_config = TurbomindEngineConfig(tp=tp, model_format='awq') else: backend_config = TurbomindEngineConfig(tp=tp) @@ -43,6 +44,7 @@ def run_pipeline_chat_test(config, cases_info, model_case, type): # run testcases gen_config = GenerationConfig(temperature=0.01) + gen_config = GenerationConfig() for case in cases_info.keys(): if (case == 'memory_test' or case == 'emoji_case') and 'chat' not in model_case.lower(): @@ -50,7 +52,8 @@ def run_pipeline_chat_test(config, cases_info, model_case, type): case_info = cases_info.get(case) pipeline_chat_log = os.path.join( - log_path, 'pipeline_chat_' + model_case + '_' + case + '.log') + log_path, + 'pipeline_chat_' + model_case.split('/')[1] + '_' + case + '.log') file = open(pipeline_chat_log, 'w') @@ -94,7 +97,8 @@ def assert_pipeline_chat_log(config, cases_info, model_case): result = False with allure.step('case - ' + case): pipeline_chat_log = os.path.join( - log_path, 'pipeline_chat_' + model_case + '_' + case + '.log') + log_path, 'pipeline_chat_' + model_case.split('/')[1] + '_' + + case + '.log') with open(pipeline_chat_log, 'r') as f: lines = f.readlines() diff --git a/autotest/utils/quantization_utils.py b/autotest/utils/quantization_utils.py index 759f6a5169..d4b4272713 100644 --- a/autotest/utils/quantization_utils.py +++ b/autotest/utils/quantization_utils.py @@ -13,9 +13,10 @@ def quantization(config, origin_model_path = config.get('model_path') + '/' + origin_model_name quantization_model_path = model_path + '/' + quantization_model_name quantization_log = os.path.join( - log_path, - '_'.join(['quantization', quantization_type, quantization_model_name - ]) + '.log') + log_path, '_'.join([ + 'quantization', quantization_type, + quantization_model_name.split('/')[1] + ]) + '.log') if quantization_type == 'w4a16': quantization_cmd = ' '.join([ diff --git a/autotest/utils/run_client_chat.py b/autotest/utils/run_client_chat.py index cfdd2bfa54..334154a95a 100644 --- a/autotest/utils/run_client_chat.py +++ b/autotest/utils/run_client_chat.py @@ -10,7 +10,7 @@ def command_line_test(config, case_info, model_case, type, - extra, + extra: str = None, cuda_prefix: str = None): dst_path = config.get('dst_path') @@ -24,12 +24,16 @@ def command_line_test(config, config, model_case, cuda_prefix=cuda_prefix) - if 'kvint8' in model_case and 'w4' not in model_case: - cmd += ' --model-format hf --quant-policy 4' - if 'kvint8' in model_case and 'w4' in model_case: + if 'kvint8' in model_case: cmd += ' --quant-policy 4' - if 'w4' in model_case: + if 'w4' in model_case or '4bits' in model_case: + cmd += ' --model-format awq' + else: + cmd += ' --model-format hf' + elif 'w4' in model_case or '4bits' in model_case: cmd += ' --model-format awq' + if 'chat' not in model_case.lower(): + cmd += ' --cap completion' return command_test(config, [cmd], model_case, case, case_info, type == 'turbomind') @@ -48,11 +52,13 @@ def hf_command_line_test(config, need_tp=True, cuda_prefix=cuda_prefix) - if 'kvint8' in model_case and 'w4' not in model_case: - cmd += ' --model-format hf --quant-policy 4' - if 'kvint8' in model_case and 'w4' in model_case: + if 'kvint8' in model_case: cmd += ' --quant-policy 4' - if 'w4' in model_case: + if 'w4' in model_case or '4bits' in model_case: + cmd += ' --model-format awq' + else: + cmd += ' --model-format hf' + elif 'w4' in model_case or '4bits' in model_case: cmd += ' --model-format awq' return command_test(config, [cmd], model_case, '_'.join(['hf', type, case]), case_info, True) @@ -66,8 +72,12 @@ def command_test(config, cmd, model, case, case_info, need_extract_output): log_path = config.get('log_path') model_name = get_model_name(model) - chat_log = os.path.join(log_path, - 'chat_' + model + '_' + case + '.log') + if '/' in model: + chat_log = os.path.join( + log_path, 'chat_' + model.split('/')[1] + '_' + case + '.log') + else: + chat_log = os.path.join(log_path, + 'chat_' + model + '_' + case + '.log') file = open(chat_log, 'w') @@ -78,7 +88,7 @@ def command_test(config, cmd, model, case, case_info, need_extract_output): file.writelines('reproduce command chat: ' + ' '.join(cmd) + '\n') spliter = '\n\n' - if model == 'CodeLlama-7b-Instruct-hf': + if 'CodeLlama-7b-Instruct-hf' in model: spliter = '\n!!\n' # join prompt together prompt = '' @@ -136,15 +146,13 @@ def command_test(config, cmd, model, case, case_info, need_extract_output): # 从输出中解析模型输出的对话内容 def parse_dialogue(inputs: str, model: str): dialogues = inputs.strip() - if model == 'CodeLlama-7b-Instruct-hf': + if 'CodeLlama-7b-Instruct-hf' in model: sep = 'enter !! to end the input >>>' else: sep = 'double enter to end input >>>' dialogues = dialogues.strip() dialogues = dialogues.split(sep) dialogues = [d.strip() for d in dialogues] - if 'Llama' in model: - return dialogues return dialogues[1:-1] # 去除首尾无用字符 diff --git a/autotest/utils/run_restful_chat.py b/autotest/utils/run_restful_chat.py index 6236ebe50c..8d82a4c9c4 100644 --- a/autotest/utils/run_restful_chat.py +++ b/autotest/utils/run_restful_chat.py @@ -7,10 +7,11 @@ from lmdeploy.serve.openai.api_client import APIClient -def open_chat_test(config, case_info, model, url): +def open_chat_test(config, case_info, model, url, worker_id: str = 'default'): log_path = config.get('log_path') - restful_log = os.path.join(log_path, 'restful_' + model + '.log') + restful_log = os.path.join(log_path, + 'restful_' + model + '_' + worker_id + '.log') file = open(restful_log, 'w') @@ -49,10 +50,15 @@ def open_chat_test(config, case_info, model, url): return result, restful_log, msg -def interactive_test(config, case_info, model, url): +def interactive_test(config, + case_info, + model, + url, + worker_id: str = 'default'): log_path = config.get('log_path') - interactive_log = os.path.join(log_path, 'interactive_' + model + '.log') + interactive_log = os.path.join( + log_path, 'interactive_' + model + '_' + worker_id + '.log') file = open(interactive_log, 'w')