Check-in turbomind engine config (#909)
* add EngineConfig for turbomind

* add EngineGenerationConfig for turbomind

* add deprecated params warning

* update TurbomindModelConfig

* fix comments

* update prepare_inputs

* update TurbomindModelConfig

* update EngineConfig

* use default bad/stop words

* use default bad/stop words

* fix comments

* typo

* fix bad words

* add engine_config to turbomind.chat

* update EngineConfig

* rename generation_config -> gen_config

* update config
irexyc authored Jan 8, 2024
1 parent c658f9a commit 051307a
Showing 6 changed files with 285 additions and 201 deletions.
3 changes: 2 additions & 1 deletion lmdeploy/turbomind/__init__.py
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .engine_config import EngineConfig
from .turbomind import TurboMind

__all__ = ['TurboMind']
__all__ = ['TurboMind', 'EngineConfig']
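For context, a minimal usage sketch of the re-exported class; the model name and values below are illustrative assumptions, not part of the commit.

# Minimal sketch: EngineConfig is now importable from the package root.
from lmdeploy.turbomind import EngineConfig

# Field names follow the new engine_config.py below; values are illustrative.
engine_config = EngineConfig(model_name='llama2', tp=1, session_len=4096)
print(engine_config)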
4 changes: 4 additions & 0 deletions lmdeploy/turbomind/chat.py
@@ -5,6 +5,8 @@

from lmdeploy.turbomind.utils import get_gen_param

from .engine_config import EngineConfig

os.environ['TM_LOG_LEVEL'] = 'ERROR'


@@ -36,6 +38,7 @@ def main(model_path,
tp: int = 1,
stream_output: bool = True,
request_output_len: int = 512,
engine_config: EngineConfig = None,
**kwargs):
"""An example to perform model inference through the command line
interface.
@@ -51,6 +54,7 @@
"""
from lmdeploy import turbomind as tm
tm_model = tm.TurboMind.from_pretrained(model_path,
engine_config=engine_config,
model_name=model_name,
tp=tp,
capability=cap,
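The following is a rough sketch of the call pattern chat.py now follows, with a placeholder model path; it is not part of the diff.

# Rough sketch of the updated call pattern (placeholder model path).
from lmdeploy import turbomind as tm
from lmdeploy.turbomind import EngineConfig

engine_config = EngineConfig(tp=1, max_batch_size=128)
tm_model = tm.TurboMind.from_pretrained('/path/to/model',
                                        engine_config=engine_config,
                                        tp=1)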
52 changes: 39 additions & 13 deletions lmdeploy/turbomind/deploy/target_model/base.py
@@ -1,15 +1,20 @@
# Copyright (c) OpenMMLab. All rights reserved.
import configparser
import copy
import inspect
import io
import json
import os.path as osp
from abc import ABC, abstractmethod
from dataclasses import dataclass
from configparser import ConfigParser

import torch
import tqdm
from mmengine import Registry
from pydantic.dataclasses import dataclass

from lmdeploy.model import MODELS
from lmdeploy.turbomind import EngineConfig

from ..source_model.base import BaseInputModel, BaseReader

@@ -30,18 +35,18 @@ def tprint(*args, **kwargs):
@dataclass
class TurbomindModelConfig:
"""Config for turbomind model."""
model_name: str
tensor_para_size: int
head_num: int
kv_head_num: int
vocab_size: int
num_layer: int
inter_size: int
norm_eps: float
attn_bias: int
start_id: int
end_id: int
session_len: int
model_name: str = None
tensor_para_size: int = None
head_num: int = None
kv_head_num: int = None
vocab_size: int = None
num_layer: int = None
inter_size: int = None
norm_eps: float = None
attn_bias: int = None
start_id: int = None
end_id: int = None
session_len: int = None
weight_type: str = 'fp16'
rotary_embedding: int = 128
rope_theta: float = 10000.0
@@ -77,6 +82,27 @@ def from_dict(cls, env, allow_none=False):
default.update(used)
return cls(**default)

@classmethod
def from_engine_config(cls, config: EngineConfig):
env = copy.deepcopy(config.__dict__)
env['tensor_para_size'] = env['tp']
ret = TurbomindModelConfig.from_dict(env, allow_none=True)
ret.rotary_embedding = ret.size_per_head
return ret

def toini(self):
config = copy.deepcopy(self.__dict__)
parser = ConfigParser()
parser['llama'] = config
with io.StringIO() as ss:
parser.write(ss)
ss.seek(0)
ini = ss.read()
return ini

def __str__(self):
return json.dumps(self.__dict__, indent=2)

@property
def valid(self):
"""Check if cfg is valid."""
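The two new helpers can be combined as in the sketch below: from_engine_config copies the matching fields (mapping tp to tensor_para_size) and toini serializes the result under a [llama] section. The import paths follow the file locations in this commit; the values are illustrative.

# Sketch of the new conversion helpers (illustrative values).
from lmdeploy.turbomind import EngineConfig
from lmdeploy.turbomind.deploy.target_model.base import TurbomindModelConfig

engine_config = EngineConfig(model_name='llama', tp=2, session_len=2048)

# tp is mapped to tensor_para_size; unspecified fields keep their defaults.
tm_config = TurbomindModelConfig.from_engine_config(engine_config)
assert tm_config.tensor_para_size == 2

# toini() renders the config as an INI string under a [llama] section;
# __str__ dumps the same fields as indented JSON.
print(tm_config.toini())
print(tm_config)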
47 changes: 47 additions & 0 deletions lmdeploy/turbomind/engine_config.py
@@ -0,0 +1,47 @@
# Copyright (c) OpenMMLab. All rights reserved.
from pydantic.dataclasses import dataclass


@dataclass
class EngineConfig:
"""TurboMind Engine config.
Args:
model_name (str): the name of the deployed model
model_format (str): the layout of the deployed model. It can be one of the following values [hf, llama, awq], `hf` meaning `hf_llama`, `llama` meaning `meta_llama`, `awq` meaning the quantized model by AWQ.
group_size (int): the group size used when quantizing weights to 4bit, default to 128
tp (int): the number of GPU cards used in tensor parallelism, default to 1
session_len (int): the max session length of a sequence, default to None
max_batch_size (int): the max batch size during inference, default to 128
max_context_token_num (int): the max number of tokens to be processed in each forward pass, default to 1
cache_max_entry_count (float): the percentage of gpu memory occupied by the k/v cache, default to 0.5
cache_block_seq_len (int): the length of a sequence in a k/v block, default to 128
cache_chunk_size (int): the number of blocks each time TurboMind engine tries to realloc from gpu memory, default to -1. When it is -1,
num_tokens_per_iter (int): number of tokens to be processed per iteration, default to 0
max_prefill_iters (int): max prefill iters for a single request, default to 1
use_context_fmha (int): whether or not to use fmha in context decoding, default to 1
quant_policy (int): default to 0. When k/v is quantized into 8 bit, set it to 4
rope_scaling_factor (float): scaling factor used for dynamic ntk, default to 0. TurboMind follows the implementation of transformer LlamaAttention
use_dynamic_ntk (bool): whether or not to use dynamic ntk, default to False
use_logn_attn (bool): whether or not to use logn attn, default to False
kv_bits (int): the number of bits of k/v after quantization, default to 8
""" # noqa: E501

model_name: str = None
model_format: str = None
tp: int = 1
session_len: int = None
max_batch_size: int = 128
group_size: int = 128
kv_bits: int = 8
max_context_token_num: int = 1
cache_max_entry_count: float = 0.5
cache_block_seq_len: int = 128
cache_chunk_size: int = -1
num_tokens_per_iter: int = 0
max_prefill_iters: int = 1
use_context_fmha: int = 1
quant_policy: int = 0
rope_scaling_factor: float = 0.0
use_dynamic_ntk: bool = False
use_logn_attn: bool = False
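As a usage illustration (not part of the commit), the docstring above suggests a configuration along these lines for an AWQ-quantized model with an 8-bit k/v cache:

# Illustrative configuration based on the docstring above.
from lmdeploy.turbomind import EngineConfig

config = EngineConfig(
    model_format='awq',         # layout of an AWQ-quantized model
    group_size=128,             # group size used for 4-bit weight quantization
    quant_policy=4,             # per the docstring, 4 means 8-bit quantized k/v
    kv_bits=8,                  # number of bits of k/v after quantization
    cache_max_entry_count=0.5,  # fraction of GPU memory occupied by the k/v cache
)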
