Check-in pytorch engine config (#908)
* engine config

* rename

* fix

* fix

* update chat
grimoire authored Jan 8, 2024
1 parent 7d115a4 commit c658f9a
Showing 7 changed files with 175 additions and 102 deletions.
35 changes: 17 additions & 18 deletions benchmark/profile_torch_generation.py
@@ -16,12 +16,11 @@
nvmlInit, nvmlShutdown, nvmlSystemGetDriverVersion)
from tqdm import tqdm

- from lmdeploy.pytorch.messages import SamplingParam


def infer(model, session_id: int, input_ids: List, output_seqlen: int,
top_k: int, top_p: float, temperature: float, test_round: int,
que: Queue):
+ from lmdeploy.messages import EngineGenerationConfig

if session_id == 1:
pbar = tqdm(total=test_round)
@@ -45,14 +44,14 @@ def infer(model, session_id: int, input_ids: List, output_seqlen: int,
the 5 tokens, i.e. `token_latency_stats[0]`, and `token_latency_stats[1:4]` is set 0`
""" # noqa: E501
# TODO: use same inference interface
- sampling_param = SamplingParam(top_k=top_k,
- top_p=top_p,
- temperature=temperature,
- ignore_eos=True)
+ gen_config = EngineGenerationConfig(max_new_tokens=output_seqlen,
+ top_k=top_k,
+ top_p=top_p,
+ temperature=temperature,
+ ignore_eos=True)
for outputs in chatbot.stream_infer(session_id,
input_ids=input_ids,
- request_output_len=output_seqlen,
- sampling_param=sampling_param):
+ gen_config=gen_config):
if len(outputs) > 1:
_, n_token = outputs[-2:]
else:
@@ -81,18 +80,19 @@ def warmup(model, concurrency: int, input_ids: List[int], output_seqlen: int,
print('start to warmup ...')

def _infer(model, session_id):
+ from lmdeploy.messages import EngineGenerationConfig
chatbot = model.create_instance()
for _ in range(warmup_round):
# TODO: use same inference interface
- sampling_param = SamplingParam(top_k=1,
- top_p=1.0,
- temperature=0.8,
- repetition_penalty=1.0,
- ignore_eos=True)
+ gen_config = EngineGenerationConfig(max_new_tokens=output_seqlen,
+ top_k=1,
+ top_p=1.0,
+ temperature=0.8,
+ repetition_penalty=1.0,
+ ignore_eos=True)
generator = chatbot.stream_infer(session_id,
input_ids=input_ids,
- request_output_len=output_seqlen,
- sampling_param=sampling_param)
+ gen_config=gen_config)
for _ in generator:
continue
# for pytorch engine to restart a session
@@ -123,10 +123,9 @@ def profile_throughput(model_path: str, concurrency: int, input_seqlen: int,
f'n_completion_token: {output_seqlen}, '
f'test_round: {test_round}, warmup_round: {warmup_round}')

- from lmdeploy.pytorch.engine import Engine
+ from lmdeploy.pytorch.engine import Engine, EngineConfig

# tokenizer = Tokenizer(model_path)
- tm_model = Engine(model_path, tp=tp, model_name='llama')
+ tm_model = Engine(model_path, EngineConfig(model_name='llama', tp=tp))

# make up a dummy `input_ids` with the length of `input_seqlen` exactly
assert input_seqlen > 0, 'input_seqlen should > 0'
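For reference, the pattern these benchmark changes move to (one EngineConfig for the engine, one EngineGenerationConfig per request) looks roughly like the minimal sketch below, assembled from the calls shown in this diff; the model path, session id, token ids, and sampling values are placeholders, not part of the commit:

    from lmdeploy.messages import EngineGenerationConfig
    from lmdeploy.pytorch.engine import Engine, EngineConfig

    # engine-level options now travel in an EngineConfig
    engine = Engine('/path/to/hf/model', EngineConfig(model_name='llama', tp=1))
    generator = engine.create_instance()

    # per-request sampling options now travel in an EngineGenerationConfig
    gen_config = EngineGenerationConfig(max_new_tokens=128,
                                        top_k=40,
                                        top_p=0.8,
                                        temperature=0.8,
                                        ignore_eos=True)
    for outputs in generator.stream_infer(session_id=1,
                                          input_ids=[1, 2, 3],
                                          gen_config=gen_config):
        pass  # outputs carries the generation status and token ids (see the unpacking in chat.py below)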
23 changes: 12 additions & 11 deletions benchmark/profile_torch_throughput.py
@@ -12,8 +12,9 @@
import numpy as np
from tqdm import tqdm

+ from lmdeploy.messages import EngineGenerationConfig
from lmdeploy.pytorch.engine import Engine as LMEngine
- from lmdeploy.pytorch.messages import SamplingParam
+ from lmdeploy.pytorch.engine import EngineConfig
from lmdeploy.tokenizer import Tokenizer


@@ -63,7 +64,8 @@ class Engine:
def __init__(self, model_path: str, tp: int, csv: str, **kwargs):
# avoid turbomind checking chat template name by setting
# `model_name='llama'`
- tm_model = LMEngine(model_path, tp=tp, model_name='llama')
+ tm_model = LMEngine(model_path, EngineConfig(tp=tp,
+ model_name='llama'))
self.tm_model = tm_model
self.tokenizer = tm_model.tokenizer
self.csv = csv
@@ -84,15 +86,14 @@ def _inference(self, req_queue: Queue, res_queue: Queue, session_id: int,

input_ids = self.tokenizer(prompt).input_ids
# TODO: share same stream infer
- sampling_param = SamplingParam(top_k=1,
- top_p=1.0,
- temperature=1.0,
- ignore_eos=True)
- for outputs in model_inst.stream_infer(
- session_id,
- input_ids=input_ids,
- request_output_len=output_seqlen,
- sampling_param=sampling_param):
+ gen_config = EngineGenerationConfig(max_new_tokens=output_seqlen,
+ top_k=1,
+ top_p=1.0,
+ temperature=1.0,
+ ignore_eos=True)
+ for outputs in model_inst.stream_infer(session_id,
+ input_ids=input_ids,
+ gen_config=gen_config):
if len(outputs) > 1:
res, n_token = outputs[-2:]
else:
3 changes: 3 additions & 0 deletions lmdeploy/pytorch/__init__.py
@@ -1 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
+ from .config import EngineConfig

+ __all__ = ['EngineConfig']
108 changes: 70 additions & 38 deletions lmdeploy/pytorch/chat.py
@@ -4,11 +4,11 @@
import random
from typing import List

- from lmdeploy.model import MODELS
+ from lmdeploy.messages import EngineGenerationConfig
+ from lmdeploy.model import MODELS, best_match_model
+ from lmdeploy.pytorch import EngineConfig
from lmdeploy.tokenizer import Tokenizer

- from .messages import SamplingParam

os.environ['TM_LOG_LEVEL'] = 'ERROR'


@@ -48,37 +48,36 @@ def _stop_words(stop_words: List[str], tokenizer: Tokenizer):
return stop_words


- def main(
- model_path,
- model_name: str,  # can not get model_name from hf model
- session_id: int = 1,
- top_k=40,
- top_p=0.8,
- temperature=0.8,
- repetition_penalty: float = 1.0,
- tp: int = 1,
- stream_output=True,
- trust_remote_code=True):
+ def run_chat(model_path,
+ engine_config: EngineConfig,
+ gen_config: EngineGenerationConfig = None,
+ session_id: int = 1,
+ trust_remote_code=True):
"""An example to perform model inference through the command line
interface.
Args:
- model_path (str): the huggingface model path
- session_id (int): the identical id of a session
- repetition_penalty (float): parameter to penalize repetition
- tp (int): GPU number used in tensor parallelism
- stream_output (bool): indicator for streaming output or not
+ model_path (str): the huggingface model path.
+ engine_config (EngineConfig): Config of engine.
+ gen_config (EngineGenerationConfig): Config of generation.
+ session_id (int): the identical id of a session.
+ trust_remote_code (bool): trust remote code.
"""
- from . import engine as tm
- tm_model = tm.Engine(model_path,
- tp=tp,
- trust_remote_code=trust_remote_code)
+ from lmdeploy.pytorch.engine import Engine
+ tm_model = Engine(model_path,
+ engine_config=engine_config,
+ trust_remote_code=trust_remote_code)
tokenizer = tm_model.tokenizer
generator = tm_model.create_instance()

nth_round = 1
step = 0
seed = random.getrandbits(64)
+ model_name = engine_config.model_name
+ if model_name is None:
+ model_name = best_match_model(model_path)[0]
+ assert model_name is not None, 'Can not find match model template'
+ print(f'match template: <{model_name}>')
model = MODELS.get(model_name)()
stop_words = _stop_words(model.stop_words, tokenizer)

@@ -94,27 +93,21 @@ def main(
else:
prompt = model.get_prompt(prompt, nth_round == 1)
input_ids = tokenizer.encode(prompt, nth_round == 1)
- if step >= tm_model.session_len:
+ session_len = model.session_len
+ if session_len is None:
+ session_len = tm_model.session_len
+ if step >= session_len:
print('WARNING: exceed session max length.'
' Please end the session.')
continue

print(f'{prompt} ', end='', flush=True)
response_size = 0
- sampling_param = SamplingParam(
- top_k=top_k,
- top_p=top_p,
- temperature=temperature,
- repetition_penalty=repetition_penalty,
- ignore_eos=False,
- random_seed=seed,
- stop_words=stop_words)
- for outputs in generator.stream_infer(
- session_id=session_id,
- input_ids=input_ids,
- request_output_len=512,
- step=step,
- sampling_param=sampling_param):
+ gen_config.random_seed = seed
+ gen_config.stop_words = stop_words
+ for outputs in generator.stream_infer(session_id=session_id,
+ input_ids=input_ids,
+ gen_config=gen_config):
status, res, tokens = outputs
# decode res
response = tokenizer.decode(res, offset=response_size)
@@ -134,6 +127,45 @@ def main(
nth_round += 1


+ def main(model_path,
+ model_name: str = None,
+ session_id: int = 1,
+ top_k=40,
+ top_p=0.8,
+ temperature=0.8,
+ repetition_penalty: float = 1.0,
+ tp: int = 1,
+ stream_output: bool = True,
+ trust_remote_code=True):
+ """An example to perform model inference through the command line
+ interface.
+ Args:
+ model_path (str): the huggingface model path
+ model_name (str): name of the model.
+ session_id (int): the identical id of a session
+ top_k (int): sampling top k.
+ top_p (int): sampling top p.
+ temperature (float): sampling temperature.
+ repetition_penalty (float): parameter to penalize repetition
+ tp (int): GPU number used in tensor parallelism
+ stream_output (bool): indicator for streaming output or not
+ trust_remote_code (bool): Trust remote code.
+ """
+ engine_config = EngineConfig(model_name=model_name, tp=tp)
+ gen_config = EngineGenerationConfig(max_new_tokens=512,
+ top_k=top_k,
+ top_p=top_p,
+ temperature=temperature,
+ repetition_penalty=repetition_penalty,
+ ignore_eos=False)
+ return run_chat(model_path,
+ engine_config,
+ gen_config,
+ session_id=session_id,
+ trust_remote_code=trust_remote_code)


if __name__ == '__main__':
import fire

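Taken together, run_chat is now the programmatic entry point and main() stays as the thin CLI wrapper that builds the two config objects from flat arguments. A hedged sketch of calling run_chat directly, mirroring what main() constructs above (the model path is a placeholder; the module path lmdeploy.pytorch.chat follows from the file location shown):

    from lmdeploy.messages import EngineGenerationConfig
    from lmdeploy.pytorch import EngineConfig
    from lmdeploy.pytorch.chat import run_chat

    engine_config = EngineConfig(model_name=None, tp=1)  # model name is auto-matched when None
    gen_config = EngineGenerationConfig(max_new_tokens=512,
                                        top_k=40,
                                        top_p=0.8,
                                        temperature=0.8,
                                        repetition_penalty=1.0,
                                        ignore_eos=False)
    run_chat('/path/to/hf/model', engine_config, gen_config)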
34 changes: 32 additions & 2 deletions lmdeploy/pytorch/config.py
@@ -2,14 +2,44 @@
from dataclasses import dataclass, field


+ @dataclass
+ class EngineConfig:
+ """PyTorch Engine Config.
+ Args:
+ model_name (str): name of the given model.
+ tp (int): Tensor Parallelism. default 1.
+ session_len (int): Max session length. Default None.
+ max_batch_size: (int): Max batch size. Default 128.
+ eviction_type (str): What action to perform when kv cache
+ is full, ['recompute', 'copy'], Default 'recompute'.
+ prefill_interval (int): Interval to perform prefill,
+ Default 16.
+ block_size (int): paging cache block size, default 64.
+ num_cpu_blocks (int): Num cpu blocks. If num is 0, cache
+ would be allocate according to current environment.
+ num_gpu_blocks (int): Num gpu blocks. If num is 0, cache
+ would be allocate according to current environment.
+ """
+ model_name: str = ''
+ tp: int = 1
+ session_len: int = None
+ max_batch_size: int = 128
+ eviction_type: str = 'recompute'
+ prefill_interval: int = 16
+ block_size: int = 64
+ num_cpu_blocks: int = 0
+ num_gpu_blocks: int = 0


@dataclass
class SchedulerConfig:
"""Config of scheduler."""

max_batches: int
max_session_len: int
- max_request_output_len: int
- eviction_type: str = 'copy'
+ max_request_output_len: int = 512
+ eviction_type: str = 'recompute'
prefill_interval: int = 16


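As a quick illustration of the dataclass above, an EngineConfig can be built with all defaults or with a few fields overridden; the override values below are illustrative only, not recommendations from the commit:

    from lmdeploy.pytorch import EngineConfig

    default_config = EngineConfig()  # tp=1, max_batch_size=128, block_size=64, ...
    tuned_config = EngineConfig(model_name='llama',
                                tp=2,
                                session_len=4096,
                                eviction_type='copy',  # alternative to the default 'recompute'
                                block_size=128)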
4 changes: 2 additions & 2 deletions lmdeploy/pytorch/engine/__init__.py
@@ -1,4 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
- from .engine import Engine
+ from .engine import Engine, EngineConfig

- __all__ = ['Engine']
+ __all__ = ['Engine', 'EngineConfig']
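With both __init__.py changes in place, EngineConfig is exposed at two levels, matching the imports used elsewhere in this commit:

    # either import path now works
    from lmdeploy.pytorch import EngineConfig                  # package-level re-export
    from lmdeploy.pytorch.engine import Engine, EngineConfig   # engine subpackage export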