From c658f9ae7aa5afeca685d9368f3c3d08a6d3c4d4 Mon Sep 17 00:00:00 2001
From: "q.yao"
Date: Mon, 8 Jan 2024 11:52:55 +0800
Subject: [PATCH] Check-in pytorch engine config (#908)

* engine config

* rename

* fix

* fix

* update chat
---
 benchmark/profile_torch_generation.py |  35 ++++-----
 benchmark/profile_torch_throughput.py |  23 +++---
 lmdeploy/pytorch/__init__.py          |   3 +
 lmdeploy/pytorch/chat.py              | 108 +++++++++++++++++---------
 lmdeploy/pytorch/config.py            |  34 +++++++-
 lmdeploy/pytorch/engine/__init__.py   |   4 +-
 lmdeploy/pytorch/engine/engine.py     |  70 +++++++++--------
 7 files changed, 175 insertions(+), 102 deletions(-)

diff --git a/benchmark/profile_torch_generation.py b/benchmark/profile_torch_generation.py
index f83218632c..de134e73d6 100644
--- a/benchmark/profile_torch_generation.py
+++ b/benchmark/profile_torch_generation.py
@@ -16,12 +16,11 @@
                     nvmlInit, nvmlShutdown, nvmlSystemGetDriverVersion)
 from tqdm import tqdm
 
-from lmdeploy.pytorch.messages import SamplingParam
-
 
 def infer(model, session_id: int, input_ids: List, output_seqlen: int,
           top_k: int, top_p: float, temperature: float, test_round: int,
           que: Queue):
+    from lmdeploy.messages import EngineGenerationConfig
     if session_id == 1:
         pbar = tqdm(total=test_round)
 
@@ -45,14 +44,14 @@ def infer(model, session_id: int, input_ids: List, output_seqlen: int,
    the 5 tokens, i.e. `token_latency_stats[0]`, and `token_latency_stats[1:4]` is set 0`
    """  # noqa: E501
     # TODO: use same inference interface
-    sampling_param = SamplingParam(top_k=top_k,
-                                   top_p=top_p,
-                                   temperature=temperature,
-                                   ignore_eos=True)
+    gen_config = EngineGenerationConfig(max_new_tokens=output_seqlen,
+                                        top_k=top_k,
+                                        top_p=top_p,
+                                        temperature=temperature,
+                                        ignore_eos=True)
     for outputs in chatbot.stream_infer(session_id,
                                         input_ids=input_ids,
-                                        request_output_len=output_seqlen,
-                                        sampling_param=sampling_param):
+                                        gen_config=gen_config):
         if len(outputs) > 1:
             _, n_token = outputs[-2:]
         else:
@@ -81,18 +80,19 @@ def warmup(model, concurrency: int, input_ids: List[int], output_seqlen: int,
     print('start to warmup ...')
 
     def _infer(model, session_id):
+        from lmdeploy.messages import EngineGenerationConfig
         chatbot = model.create_instance()
         for _ in range(warmup_round):
             # TODO: use same inference interface
-            sampling_param = SamplingParam(top_k=1,
-                                           top_p=1.0,
-                                           temperature=0.8,
-                                           repetition_penalty=1.0,
-                                           ignore_eos=True)
+            gen_config = EngineGenerationConfig(max_new_tokens=output_seqlen,
+                                                top_k=1,
+                                                top_p=1.0,
+                                                temperature=0.8,
+                                                repetition_penalty=1.0,
+                                                ignore_eos=True)
             generator = chatbot.stream_infer(session_id,
                                              input_ids=input_ids,
-                                             request_output_len=output_seqlen,
-                                             sampling_param=sampling_param)
+                                             gen_config=gen_config)
             for _ in generator:
                 continue
             # for pytorch engine to restart a session
@@ -123,10 +123,9 @@ def profile_throughput(model_path: str, concurrency: int, input_seqlen: int,
           f'n_completion_token: {output_seqlen}, '
          f'test_round: {test_round}, warmup_round: {warmup_round}')
 
-    from lmdeploy.pytorch.engine import Engine
+    from lmdeploy.pytorch.engine import Engine, EngineConfig
 
-    # tokenizer = Tokenizer(model_path)
-    tm_model = Engine(model_path, tp=tp, model_name='llama')
+    tm_model = Engine(model_path, EngineConfig(model_name='llama', tp=tp))
 
     # make up a dummy `input_ids` with the length of `input_seqlen` exactly
     assert input_seqlen > 0, 'input_seqlen should > 0'
diff --git a/benchmark/profile_torch_throughput.py b/benchmark/profile_torch_throughput.py
index 46dd59bed5..e5b8142a14 100644
--- a/benchmark/profile_torch_throughput.py
+++ b/benchmark/profile_torch_throughput.py
@@ -12,8 +12,9 @@
 import numpy as np
 from tqdm import tqdm
 
+from lmdeploy.messages import EngineGenerationConfig
 from lmdeploy.pytorch.engine import Engine as LMEngine
-from lmdeploy.pytorch.messages import SamplingParam
+from lmdeploy.pytorch.engine import EngineConfig
 from lmdeploy.tokenizer import Tokenizer
 
 
@@ -63,7 +64,8 @@ class Engine:
     def __init__(self, model_path: str, tp: int, csv: str, **kwargs):
         # avoid turbomind checking chat template name by setting
         # `model_name='llama'`
-        tm_model = LMEngine(model_path, tp=tp, model_name='llama')
+        tm_model = LMEngine(model_path, EngineConfig(tp=tp,
+                                                     model_name='llama'))
         self.tm_model = tm_model
         self.tokenizer = tm_model.tokenizer
         self.csv = csv
@@ -84,15 +86,14 @@ def _inference(self, req_queue: Queue, res_queue: Queue, session_id: int,
             input_ids = self.tokenizer(prompt).input_ids
 
             # TODO: share same stream infer
-            sampling_param = SamplingParam(top_k=1,
-                                           top_p=1.0,
-                                           temperature=1.0,
-                                           ignore_eos=True)
-            for outputs in model_inst.stream_infer(
-                    session_id,
-                    input_ids=input_ids,
-                    request_output_len=output_seqlen,
-                    sampling_param=sampling_param):
+            gen_config = EngineGenerationConfig(max_new_tokens=output_seqlen,
+                                                top_k=1,
+                                                top_p=1.0,
+                                                temperature=1.0,
+                                                ignore_eos=True)
+            for outputs in model_inst.stream_infer(session_id,
+                                                   input_ids=input_ids,
+                                                   gen_config=gen_config):
                 if len(outputs) > 1:
                     res, n_token = outputs[-2:]
                 else:
diff --git a/lmdeploy/pytorch/__init__.py b/lmdeploy/pytorch/__init__.py
index ef101fec61..145ef7ca58 100644
--- a/lmdeploy/pytorch/__init__.py
+++ b/lmdeploy/pytorch/__init__.py
@@ -1 +1,4 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .config import EngineConfig
+
+__all__ = ['EngineConfig']
diff --git a/lmdeploy/pytorch/chat.py b/lmdeploy/pytorch/chat.py
index 5fd7e5ba92..c18413b5ab 100644
--- a/lmdeploy/pytorch/chat.py
+++ b/lmdeploy/pytorch/chat.py
@@ -4,11 +4,11 @@
 import random
 from typing import List
 
-from lmdeploy.model import MODELS
+from lmdeploy.messages import EngineGenerationConfig
+from lmdeploy.model import MODELS, best_match_model
+from lmdeploy.pytorch import EngineConfig
 from lmdeploy.tokenizer import Tokenizer
 
-from .messages import SamplingParam
-
 os.environ['TM_LOG_LEVEL'] = 'ERROR'
@@ -48,37 +48,36 @@ def _stop_words(stop_words: List[str], tokenizer: Tokenizer):
     return stop_words
 
 
-def main(
-        model_path,
-        model_name: str,  # can not get model_name from hf model
-        session_id: int = 1,
-        top_k=40,
-        top_p=0.8,
-        temperature=0.8,
-        repetition_penalty: float = 1.0,
-        tp: int = 1,
-        stream_output=True,
-        trust_remote_code=True):
+def run_chat(model_path,
+             engine_config: EngineConfig,
+             gen_config: EngineGenerationConfig = None,
+             session_id: int = 1,
+             trust_remote_code=True):
     """An example to perform model inference through the command line
     interface.
 
     Args:
-        model_path (str): the huggingface model path
-        session_id (int): the identical id of a session
-        repetition_penalty (float): parameter to penalize repetition
-        tp (int): GPU number used in tensor parallelism
-        stream_output (bool): indicator for streaming output or not
+        model_path (str): the huggingface model path.
+        engine_config (EngineConfig): Config of engine.
+        gen_config (EngineGenerationConfig): Config of generation.
+        session_id (int): the unique id of a session.
+        trust_remote_code (bool): trust remote code.
     """
-    from . import engine as tm
-    tm_model = tm.Engine(model_path,
-                         tp=tp,
-                         trust_remote_code=trust_remote_code)
+    from lmdeploy.pytorch.engine import Engine
+    tm_model = Engine(model_path,
+                      engine_config=engine_config,
+                      trust_remote_code=trust_remote_code)
     tokenizer = tm_model.tokenizer
     generator = tm_model.create_instance()
     nth_round = 1
     step = 0
     seed = random.getrandbits(64)
+    model_name = engine_config.model_name
+    if model_name is None:
+        model_name = best_match_model(model_path)[0]
+        assert model_name is not None, 'Can not find match model template'
+        print(f'match template: <{model_name}>')
 
     model = MODELS.get(model_name)()
     stop_words = _stop_words(model.stop_words, tokenizer)
@@ -94,27 +93,21 @@ def main(
         else:
             prompt = model.get_prompt(prompt, nth_round == 1)
             input_ids = tokenizer.encode(prompt, nth_round == 1)
-            if step >= tm_model.session_len:
+            session_len = model.session_len
+            if session_len is None:
+                session_len = tm_model.session_len
+            if step >= session_len:
                 print('WARNING: exceed session max length.'
                       ' Please end the session.')
                 continue
             print(f'{prompt} ', end='', flush=True)
             response_size = 0
-            sampling_param = SamplingParam(
-                top_k=top_k,
-                top_p=top_p,
-                temperature=temperature,
-                repetition_penalty=repetition_penalty,
-                ignore_eos=False,
-                random_seed=seed,
-                stop_words=stop_words)
-            for outputs in generator.stream_infer(
-                    session_id=session_id,
-                    input_ids=input_ids,
-                    request_output_len=512,
-                    step=step,
-                    sampling_param=sampling_param):
+            gen_config.random_seed = seed
+            gen_config.stop_words = stop_words
+            for outputs in generator.stream_infer(session_id=session_id,
+                                                  input_ids=input_ids,
+                                                  gen_config=gen_config):
                 status, res, tokens = outputs
                 # decode res
                 response = tokenizer.decode(res, offset=response_size)
@@ -134,6 +127,45 @@ def main(
         nth_round += 1
 
 
+def main(model_path,
+         model_name: str = None,
+         session_id: int = 1,
+         top_k=40,
+         top_p=0.8,
+         temperature=0.8,
+         repetition_penalty: float = 1.0,
+         tp: int = 1,
+         stream_output: bool = True,
+         trust_remote_code=True):
+    """An example to perform model inference through the command line
+    interface.
+
+    Args:
+        model_path (str): the huggingface model path
+        model_name (str): name of the model.
+        session_id (int): the unique id of a session
+        top_k (int): sampling top k.
+        top_p (float): sampling top p.
+        temperature (float): sampling temperature.
+        repetition_penalty (float): parameter to penalize repetition
+        tp (int): GPU number used in tensor parallelism
+        stream_output (bool): indicator for streaming output or not
+        trust_remote_code (bool): Trust remote code.
+    """
+    engine_config = EngineConfig(model_name=model_name, tp=tp)
+    gen_config = EngineGenerationConfig(max_new_tokens=512,
+                                        top_k=top_k,
+                                        top_p=top_p,
+                                        temperature=temperature,
+                                        repetition_penalty=repetition_penalty,
+                                        ignore_eos=False)
+    return run_chat(model_path,
+                    engine_config,
+                    gen_config,
+                    session_id=session_id,
+                    trust_remote_code=trust_remote_code)
+
+
 if __name__ == '__main__':
     import fire
 
diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py
index 09c53462fb..5130aa3137 100644
--- a/lmdeploy/pytorch/config.py
+++ b/lmdeploy/pytorch/config.py
@@ -2,14 +2,44 @@
 from dataclasses import dataclass, field
 
 
+@dataclass
+class EngineConfig:
+    """PyTorch Engine Config.
+
+    Args:
+        model_name (str): name of the given model.
+        tp (int): Tensor Parallelism. default 1.
+        session_len (int): Max session length. Default None.
+        max_batch_size (int): Max batch size. Default 128.
+        eviction_type (str): What action to perform when kv cache
+            is full, ['recompute', 'copy'], Default 'recompute'.
+        prefill_interval (int): Interval to perform prefill,
+            Default 16.
+        block_size (int): paging cache block size, default 64.
+        num_cpu_blocks (int): Num cpu blocks. If num is 0, cache
+            would be allocated according to current environment.
+        num_gpu_blocks (int): Num gpu blocks. If num is 0, cache
+            would be allocated according to current environment.
+    """
+    model_name: str = ''
+    tp: int = 1
+    session_len: int = None
+    max_batch_size: int = 128
+    eviction_type: str = 'recompute'
+    prefill_interval: int = 16
+    block_size: int = 64
+    num_cpu_blocks: int = 0
+    num_gpu_blocks: int = 0
+
+
 @dataclass
 class SchedulerConfig:
     """Config of scheduler."""
     max_batches: int
     max_session_len: int
-    max_request_output_len: int
-    eviction_type: str = 'copy'
+    max_request_output_len: int = 512
+    eviction_type: str = 'recompute'
     prefill_interval: int = 16
 
diff --git a/lmdeploy/pytorch/engine/__init__.py b/lmdeploy/pytorch/engine/__init__.py
index 82de01283b..f6d89abc33 100644
--- a/lmdeploy/pytorch/engine/__init__.py
+++ b/lmdeploy/pytorch/engine/__init__.py
@@ -1,4 +1,4 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .engine import Engine
+from .engine import Engine, EngineConfig
 
-__all__ = ['Engine']
+__all__ = ['Engine', 'EngineConfig']
diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py
index c9d00a1fe0..82e0811664 100644
--- a/lmdeploy/pytorch/engine/engine.py
+++ b/lmdeploy/pytorch/engine/engine.py
@@ -8,10 +8,11 @@
 import torch
 from transformers import AutoConfig
 
+from lmdeploy.messages import EngineGenerationConfig
 from lmdeploy.tokenizer import Tokenizer
 from lmdeploy.utils import get_logger
 
-from ..config import CacheConfig, ModelConfig, SchedulerConfig
+from ..config import CacheConfig, EngineConfig, ModelConfig, SchedulerConfig
 from ..messages import (MessageStatus, SamplingParam, SchedulerSequence,
                         SchedulerSession)
 from ..paging import Scheduler
@@ -140,32 +141,30 @@ class Engine:
 
     Args:
         model_path (str): The hugging face model path.
-        scheduler_config (SchedulerConfig): The config of the scheduler.
-        cache_config (CacheConfig): The config of the cache info.
-        tp (int): Number of tensor parallel.
+        engine_config (EngineConfig): The config of the Engine.
""" def __init__(self, model_path: str, - scheduler_config: SchedulerConfig = None, - cache_config: CacheConfig = None, - tp: int = 1, - model_name: str = None, + engine_config: EngineConfig, trust_remote_code=True) -> None: + self.engine_config = engine_config + model_name = engine_config.model_name + tp = engine_config.tp self.tp = tp self.gpu_count = tp self.model_name = model_name - scheduler_config = scheduler_config or SchedulerConfig( - max_batches=128, - max_session_len=4096, - max_request_output_len=512, + scheduler_config = SchedulerConfig( + max_batches=engine_config.max_batch_size, + max_session_len=engine_config.session_len, eviction_type='recompute') # block_size = 1 to enable unified paging - cache_config = cache_config or CacheConfig( - block_size=64, num_cpu_blocks=0, num_gpu_blocks=0) + cache_config = CacheConfig(block_size=engine_config.block_size, + num_cpu_blocks=engine_config.num_cpu_blocks, + num_gpu_blocks=engine_config.num_gpu_blocks) hf_config = AutoConfig.from_pretrained( model_path, trust_remote_code=trust_remote_code) @@ -416,6 +415,8 @@ def _check_request_len(msg): return msg.remain_output_len <= 0 def _check_session_len(msg, max_session_len): + if max_session_len is None: + return False session_len = msg.logical_blocks.num_tokens() return session_len >= max_session_len @@ -749,10 +750,8 @@ def _try_add_session(self, session_id: int): def stream_infer(self, session_id: int, - input_ids: List[int] = None, - request_output_len: int = None, - step: int = 0, - sampling_param: SamplingParam = SamplingParam(), + input_ids: List[int], + gen_config: EngineGenerationConfig = None, **kwargs): """Send stream inference request. @@ -761,13 +760,26 @@ def stream_infer(self, input_ids (List[int]): The input token ids. request_output_len (int): The max output length of this request. step (int): No use for now. - sampling_param (SamplingParam): The sampling param of the output. + gen_config (EngineGenerationConfig): The sampling parameters. Yields: int: Error flags. 0 if success. List[int]: The streaming output tokens. int: The number of the output tokens. """ + + # TODO: support input embedding, step + gen_config = gen_config or EngineGenerationConfig() + request_output_len = gen_config.max_new_tokens + sampling_param = SamplingParam( + top_p=gen_config.top_p, + top_k=gen_config.top_k, + temperature=gen_config.temperature, + repetition_penalty=gen_config.repetition_penalty, + ignore_eos=gen_config.ignore_eos, + random_seed=gen_config.random_seed, + stop_words=gen_config.stop_words, + bad_words=gen_config.bad_words) self._try_add_session(session_id) msg = dict( token_ids=input_ids, @@ -797,14 +809,11 @@ def stream_infer(self, yield (1, [], 0) break - def infer( - self, - session_id: int, - prompt_token_ids: List[int] = None, - request_output_len: int = None, - step: int = 0, - sampling_param: SamplingParam = SamplingParam(), - ): + def infer(self, + session_id: int, + prompt_token_ids: List[int] = None, + gen_config: EngineGenerationConfig = None, + **kwargs): """Send inference request. Args: @@ -812,7 +821,7 @@ def infer( prompt_token_ids (List[int]): The input token ids. request_output_len (int): The max output length of this request. step (int): No use for now. - sampling_param (SamplingParam): The sampling param of the output. + gen_config (EngineGenerationConfig): The sampling parameters. Returns: int: Error flags. 0 if success. 
@@ -822,9 +831,8 @@ def infer(
         token_ids = []
         for outputs in self.stream_infer(session_id,
                                          prompt_token_ids,
-                                         request_output_len=request_output_len,
-                                         step=step,
-                                         sampling_param=sampling_param):
+                                         gen_config=gen_config,
+                                         **kwargs):
             status, tmp_ids, _ = outputs
             if status != 0:
                 return (status, token_ids, len(token_ids))
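
For reference, a minimal usage sketch of the API reworked by this patch, mirroring the call patterns in benchmark/profile_torch_generation.py and lmdeploy/pytorch/chat.py above; the model path, prompt, and session id are illustrative placeholders, not part of the patch:

    from lmdeploy.messages import EngineGenerationConfig
    from lmdeploy.pytorch.engine import Engine, EngineConfig

    # Build the engine from an EngineConfig instead of loose kwargs.
    engine = Engine('/path/to/hf/model',  # placeholder path
                    EngineConfig(model_name='llama', tp=1))
    tokenizer = engine.tokenizer
    generator = engine.create_instance()

    # Sampling options now travel in EngineGenerationConfig.
    gen_config = EngineGenerationConfig(max_new_tokens=128,
                                        top_k=40,
                                        top_p=0.8,
                                        temperature=0.8)
    input_ids = tokenizer.encode('Hello, world!')
    for status, output_ids, num_tokens in generator.stream_infer(
            session_id=1, input_ids=input_ids, gen_config=gen_config):
        # Each step yields (status, accumulated output ids, token count).
        print(tokenizer.decode(output_ids))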