From 051307ae406c27b61fc6f1b7fb1386b536cdbf88 Mon Sep 17 00:00:00 2001
From: Chen Xin
Date: Mon, 8 Jan 2024 11:53:22 +0800
Subject: [PATCH] Check-in turbomind engine config (#909)

* add EngineConfig for turbomind
* add EngineGenerationConfig for turbomind
* add deprecated params warning
* update TurbomindModelConfig
* fix comments
* update prepare_inputs
* update TurbomindModelConfig
* update EngineConfig
* use default bad/stop words
* use default bad/stop words
* fix comments
* typo
* fix bad words
* add engine_config to turbomind.chat
* update EngineConfig
* rename generation_config -> gen_config
* update config
---
 lmdeploy/turbomind/__init__.py                |   3 +-
 lmdeploy/turbomind/chat.py                    |   4 +
 .../turbomind/deploy/target_model/base.py     |  52 ++-
 lmdeploy/turbomind/engine_config.py           |  47 +++
 lmdeploy/turbomind/turbomind.py               | 358 +++++++++---------
 .../triton_backend/llama/LlamaTritonModel.cc  |  22 +-
 6 files changed, 285 insertions(+), 201 deletions(-)
 create mode 100644 lmdeploy/turbomind/engine_config.py

diff --git a/lmdeploy/turbomind/__init__.py b/lmdeploy/turbomind/__init__.py
index b2df77014c..655469a17c 100644
--- a/lmdeploy/turbomind/__init__.py
+++ b/lmdeploy/turbomind/__init__.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .engine_config import EngineConfig
 from .turbomind import TurboMind
 
-__all__ = ['TurboMind']
+__all__ = ['TurboMind', 'EngineConfig']
diff --git a/lmdeploy/turbomind/chat.py b/lmdeploy/turbomind/chat.py
index 2576b93a90..7d9149245a 100644
--- a/lmdeploy/turbomind/chat.py
+++ b/lmdeploy/turbomind/chat.py
@@ -5,6 +5,8 @@
 
 from lmdeploy.turbomind.utils import get_gen_param
 
+from .engine_config import EngineConfig
+
 os.environ['TM_LOG_LEVEL'] = 'ERROR'
 
 
@@ -36,6 +38,7 @@ def main(model_path,
          tp: int = 1,
          stream_output: bool = True,
          request_output_len: int = 512,
+         engine_config: EngineConfig = None,
          **kwargs):
     """An example to perform model inference through the command line
     interface.
@@ -51,6 +54,7 @@ def main(model_path,
     """
     from lmdeploy import turbomind as tm
     tm_model = tm.TurboMind.from_pretrained(model_path,
+                                            engine_config=engine_config,
                                             model_name=model_name,
                                             tp=tp,
                                             capability=cap,
diff --git a/lmdeploy/turbomind/deploy/target_model/base.py b/lmdeploy/turbomind/deploy/target_model/base.py
index d2063d0048..726e4a2abf 100644
--- a/lmdeploy/turbomind/deploy/target_model/base.py
+++ b/lmdeploy/turbomind/deploy/target_model/base.py
@@ -1,15 +1,20 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import configparser
+import copy
 import inspect
+import io
+import json
 import os.path as osp
 from abc import ABC, abstractmethod
-from dataclasses import dataclass
+from configparser import ConfigParser
 
 import torch
 import tqdm
 from mmengine import Registry
+from pydantic.dataclasses import dataclass
 
 from lmdeploy.model import MODELS
+from lmdeploy.turbomind import EngineConfig
 
 from ..source_model.base import BaseInputModel, BaseReader
 
@@ -30,18 +35,18 @@ def tprint(*args, **kwargs):
 @dataclass
 class TurbomindModelConfig:
     """Config for turbomind model."""
-    model_name: str
-    tensor_para_size: int
-    head_num: int
-    kv_head_num: int
-    vocab_size: int
-    num_layer: int
-    inter_size: int
-    norm_eps: float
-    attn_bias: int
-    start_id: int
-    end_id: int
-    session_len: int
+    model_name: str = None
+    tensor_para_size: int = None
+    head_num: int = None
+    kv_head_num: int = None
+    vocab_size: int = None
+    num_layer: int = None
+    inter_size: int = None
+    norm_eps: float = None
+    attn_bias: int = None
+    start_id: int = None
+    end_id: int = None
+    session_len: int = None
     weight_type: str = 'fp16'
     rotary_embedding: int = 128
     rope_theta: float = 10000.0
@@ -77,6 +82,27 @@ def from_dict(cls, env, allow_none=False):
         default.update(used)
         return cls(**default)
 
+    @classmethod
+    def from_engine_config(cls, config: EngineConfig):
+        env = copy.deepcopy(config.__dict__)
+        env['tensor_para_size'] = env['tp']
+        ret = TurbomindModelConfig.from_dict(env, allow_none=True)
+        ret.rotary_embedding = ret.size_per_head
+        return ret
+
+    def toini(self):
+        config = copy.deepcopy(self.__dict__)
+        parser = ConfigParser()
+        parser['llama'] = config
+        with io.StringIO() as ss:
+            parser.write(ss)
+            ss.seek(0)
+            ini = ss.read()
+            return ini
+
+    def __str__(self):
+        return json.dumps(self.__dict__, indent=2)
+
     @property
     def valid(self):
         """Check if cfg is valid."""
diff --git a/lmdeploy/turbomind/engine_config.py b/lmdeploy/turbomind/engine_config.py
new file mode 100644
index 0000000000..752d002703
--- /dev/null
+++ b/lmdeploy/turbomind/engine_config.py
@@ -0,0 +1,47 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from pydantic.dataclasses import dataclass
+
+
+@dataclass
+class EngineConfig:
+    """TurboMind Engine config.
+
+    Args:
+        model_name (str): the name of the deployed model
+        model_format (str): the layout of the deployed model. It can be one of the following values: [hf, llama, awq]; `hf` means `hf_llama`, `llama` means `meta_llama`, and `awq` means a model quantized by AWQ
+        group_size (int): the group size used when quantizing weights to 4bit, default to 128
+        tp (int): the number of GPU cards used in tensor parallelism, default to 1
+        session_len (int): the max session length of a sequence, default to None
+        max_batch_size (int): the max batch size during inference, default to 128
+        max_context_token_num (int): the max number of tokens to be processed in each forward pass, default to 1
+        cache_max_entry_count (float): the percentage of gpu memory occupied by the k/v cache, default to 0.5
+        cache_block_seq_len (int): the length of a sequence in a k/v block, default to 128
+        cache_chunk_size (int): the number of blocks the TurboMind engine tries to realloc from gpu memory each time, default to -1
+        num_tokens_per_iter (int): number of tokens to be processed per iteration, default to 0
+        max_prefill_iters (int): max prefill iters for a single request, default to 1
+        use_context_fmha (int): whether or not to use fmha in context decoding, default to 1
+        quant_policy (int): default to 0. When k/v is quantized into 8 bit, set it to 4
+        rope_scaling_factor (float): scaling factor used for dynamic ntk, default to 0. TurboMind follows the implementation of transformers' LlamaAttention
+        use_dynamic_ntk (bool): whether or not to use dynamic ntk, default to False
+        use_logn_attn (bool): whether or not to use logn attention, default to False
+        kv_bits (int): the number of bits of k/v after quantization, default to 8
+    """  # noqa: E501
+
+    model_name: str = None
+    model_format: str = None
+    tp: int = 1
+    session_len: int = None
+    max_batch_size: int = 128
+    group_size: int = 128
+    kv_bits: int = 8
+    max_context_token_num: int = 1
+    cache_max_entry_count: float = 0.5
+    cache_block_seq_len: int = 128
+    cache_chunk_size: int = -1
+    num_tokens_per_iter: int = 0
+    max_prefill_iters: int = 1
+    use_context_fmha: int = 1
+    quant_policy: int = 0
+    rope_scaling_factor: float = 0.0
+    use_dynamic_ntk: bool = False
+    use_logn_attn: bool = False
diff --git a/lmdeploy/turbomind/turbomind.py b/lmdeploy/turbomind/turbomind.py
index 353c7b741a..74686efe7d 100644
--- a/lmdeploy/turbomind/turbomind.py
+++ b/lmdeploy/turbomind/turbomind.py
@@ -1,8 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import asyncio
 import copy
-import io
-import json
 import logging
 import os.path as osp
 import sys
@@ -18,6 +16,7 @@
 from torch.nn.utils.rnn import pad_sequence
 
 import lmdeploy
+from lmdeploy.messages import EngineGenerationConfig
 from lmdeploy.model import MODELS, BaseModel, best_match_model
 from lmdeploy.tokenizer import Tokenizer
 from lmdeploy.utils import get_logger
@@ -26,8 +25,9 @@
                     update_config_weight_type, update_output_format)
 from .deploy.source_model.base import INPUT_MODELS
 from .deploy.target_model.base import OUTPUT_MODELS, TurbomindModelConfig
+from .engine_config import EngineConfig
 from .utils import (ModelSource, check_tm_model_input, create_hf_download_args,
-                    get_hf_config_content, get_model_source)
+                    get_model_source)
 
 # TODO: find another way import _turbomind
 lmdeploy_dir = osp.split(lmdeploy.__file__)[0]
@@ -57,6 +57,14 @@ def _stop_words(stop_words: List[str], tokenizer: Tokenizer):
     return stop_words
 
 
+def _construct_stop_or_bad_words(words: List[int] = None):
+    if words is None or len(words) == 0:
+        return None
+    offsets = range(1, len(words) + 1)
+    combined = np.array([[words, offsets]]).astype(np.int32)
+    return combined
+
+
 def _np_dict_to_tm_dict(np_dict: dict):
     """map numpy.ndarray to turbomind's tensor."""
     ret = _tm.TensorMap()
@@ -77,6 +85,28 @@ def _tm_dict_to_torch_dict(tm_dict: _tm.TensorMap):
     return ret
 
 
+def _update_engine_config(config: EngineConfig, **kwargs):
+    if config is None:
+        config = EngineConfig()
+    for k, v in kwargs.items():
+        if v and hasattr(config, k):
+            setattr(config, k, v)
+            get_logger('turbomind').warning(
+                f'kwargs {k} is deprecated to initialize model, '
+                'use EngineConfig instead.')
+    return config
+
+
+def _update_tm_config(dst: TurbomindModelConfig, src: EngineConfig):
+    dst_dict = copy.deepcopy(dst.__dict__)
+    src_dict = copy.deepcopy(src.__dict__)
+    src_dict['tensor_para_size'] = src_dict['tp']
+    for k, v in src_dict.items():
+        if v is not None and k in dst_dict:
+            dst_dict[k] = v
+    return TurbomindModelConfig.from_dict(dst_dict)
+
+
 @contextmanager
 def cuda_ctx(device_id):
     old_device = torch.cuda.current_device()
@@ -102,34 +132,39 @@ class TurboMind:
     def __init__(self,
                  model_path: str,
+                 engine_config: EngineConfig = None,
                  model_source: ModelSource = ModelSource.WORKSPACE,
                  model_name: Optional[str] = None,
                  model_format: Optional[str] = None,
group_size: Optional[int] = None, tp: Optional[int] = None, **kwargs): - if tp is not None: - assert ((tp & (tp - 1) == 0) and tp != 0), 'tp should be 2^n' - self.gpu_count = tp if tp is not None else 1 + + engine_config = _update_engine_config(engine_config, + model_name=model_name, + model_format=model_format, + group_size=group_size, + tp=tp, + **kwargs) + tp = engine_config.tp + assert ((tp & (tp - 1) == 0) and tp != 0), 'tp should be 2^n' + self.gpu_count = tp if model_source == ModelSource.WORKSPACE: tokenizer_model_path = osp.join(model_path, 'triton_models', 'tokenizer') self.tokenizer = Tokenizer(tokenizer_model_path) - self.model_comm = self._from_workspace(model_path) + self.model_comm = self._from_workspace(model_path=model_path, + engine_config=engine_config) else: self.tokenizer = Tokenizer(model_path) self.model_comm = self._from_hf(model_source=model_source, model_path=model_path, - model_name=model_name, - model_format=model_format, - group_size=group_size, - tp=tp, - **kwargs) + engine_config=engine_config) - self.eos_id = self.tokenizer.eos_token_id self.model: BaseModel = MODELS.get(self.model_name)(**kwargs) - self.session_len = self.model.session_len + self.session_len = self.config.session_len + self.eos_id = self.tokenizer.eos_token_id self.stop_words = _stop_words(self.model.stop_words, self.tokenizer) def _create_weight(self, model_comm): @@ -194,88 +229,59 @@ def _get_params(device_id, que): tm_params[k] = [] tm_params[k].append(v) - def _from_hf(self, - model_source: ModelSource, - model_path: str, - model_name: Optional[str] = None, - model_format: Optional[str] = None, - group_size: Optional[int] = None, - tp: Optional[int] = None, - **kwargs): + def _from_hf(self, model_source: ModelSource, model_path: str, + engine_config: EngineConfig): """Load model which is in hf format.""" - # get model_name, group_size if is lmdeploy managed. - if model_source == ModelSource.HF_LMDEPLOY: - config = get_hf_config_content(model_path, local_files_only=True) - tm_config = config['turbomind'] - tm_config.update(kwargs) - var_shoud_be_none = dict(model_name=model_name, - model_format=model_format, - group_size=group_size) - for key, value in var_shoud_be_none.items(): - assert value is None, f'{key} should be None when model is '\ - f'from {model_source}' - model_name = tm_config['model_name'] - group_size = tm_config['group_size'] - if tm_config['weight_type'] == 'int4': - model_format = 'awq' - else: - assert model_name is not None, 'please supply model_name when ' \ - f'model is form {model_source}' - if osp.exists(osp.join(model_path, 'outputs_stats.pth')): - model_format = 'awq' if model_format is None else model_format - group_size = 128 if group_size is None else group_size - tm_config = kwargs - - assert model_name in MODELS.module_dict.keys(), \ - f"'{model_name}' is not supported. " \ + assert model_source == ModelSource.HF_MODEL, \ + f'{model_source} is not supported' + assert engine_config.model_name in MODELS.module_dict.keys(), \ + f"'{engine_config.model_name}' is not supported. 
" \ f'The supported models are: {MODELS.module_dict.keys()}' - assert model_format in supported_formats, 'the model format ' \ - f'should be in {supported_formats}' + assert engine_config.model_format in supported_formats, \ + f'The model format should be in {supported_formats}' + + # update model_format if not supplied and outputs_stats.pth exists + if osp.exists(osp.join(model_path, 'outputs_stats.pth')) and \ + engine_config.model_format is None: + engine_config.model_format = 'awq' data_type = 'fp16' output_format = 'fp16' - inferred_model_format = get_model_format(model_name, model_format) - cfg = TurbomindModelConfig.from_dict(tm_config, allow_none=True) - - # overwrite with input params - cfg.model_name = model_name - cfg.tensor_para_size = 1 if tp is None else tp - cfg.rotary_embedding = cfg.size_per_head - cfg.group_size = group_size + inferred_model_format = get_model_format(engine_config.model_name, + engine_config.model_format) + cfg = TurbomindModelConfig.from_engine_config(engine_config) if inferred_model_format.find('awq') != -1: cfg.weight_type = 'int4' output_format = 'w4' data_type = 'int4' - assert group_size > 0, f'group_size: {group_size} should > 0' + assert cfg.group_size > 0, \ + f'group_size: {cfg.group_size} should > 0' else: - output_format = update_output_format(model_name, + output_format = update_output_format(engine_config.model_name, inferred_model_format, model_path, output_format) data_type = output_format update_config_weight_type(output_format, cfg) - self.config = cfg - self.model_name = model_name - self.data_type = data_type - input_model = INPUT_MODELS.get(inferred_model_format)( model_path=model_path, tokenizer_path=model_path, ckpt_path=None) output_model = OUTPUT_MODELS.get(output_format)( input_model=input_model, cfg=cfg, to_file=False, out_dir='') - config = copy.deepcopy(output_model.cfg.__dict__) - logger.warning(f'model_config:\n{json.dumps(config, indent=2)}') - parser = ConfigParser() - parser['llama'] = config - with io.StringIO() as ss: - parser.write(ss) - ss.seek(0) - config = ss.read() + cfg = output_model.cfg + if engine_config.session_len is not None: + cfg.session_len = engine_config.session_len + + self.config = cfg + self.model_name = engine_config.model_name + self.data_type = data_type + + logger.warning(f'model_config:\n\n{cfg.toini()}') model_comm = _tm.AbstractTransformerModel.create_llama_model( model_dir='', - config=config, + config=cfg.toini(), tensor_para_size=self.gpu_count, data_type=data_type) @@ -289,35 +295,50 @@ def _from_hf(self, output_model.export() # load kv qparams - self._load_kv_qparams(model_path, tm_params, **kwargs) + self._load_kv_qparams(model_path, + tm_params, + kv_sym=False, + kv_bits=engine_config.kv_bits) assert len(tm_params) == 0, f'missing {tm_params.keys()}' return model_comm - def _from_workspace(self, model_path: str): + def _from_workspace(self, model_path: str, engine_config: EngineConfig): """Load model which is converted by `lmdeploy convert`""" ini_path = osp.join(model_path, 'triton_models', 'weights', 'config.ini') + # load cfg with open(ini_path, 'r') as f: parser = ConfigParser() parser.read_file(f) - section_name = 'llama' - tp_cfg = parser.getint(section_name, 'tensor_para_size') - - if tp_cfg != 1 and tp_cfg != self.gpu_count: - get_logger('turbomind').info( - f'found tp={tp_cfg} in config.ini.') - self.gpu_count = tp_cfg - self.model_name = parser.get(section_name, 'model_name') - self.data_type = parser.get(section_name, 'weight_type') - cfg = parser._sections[section_name] - 
cfg = TurbomindModelConfig.from_dict(cfg) - self.config = cfg + section_name = 'llama' + _cfg = parser._sections[section_name] + cfg = TurbomindModelConfig.from_dict(_cfg) + + # check whether input tp is valid + if cfg.tensor_para_size != 1 and \ + engine_config.tp != cfg.tensor_para_size: + get_logger('turbomind').info( + f'found tp={cfg.tensor_para_size} in config.ini.') + self.gpu_count = cfg.tensor_para_size + engine_config.tp = cfg.tensor_para_size + + # update cfg + cfg = _update_tm_config(cfg, engine_config) + if engine_config.session_len is not None: + cfg.session_len = engine_config.session_len + + # update cls + self.config = cfg + self.model_name = cfg.model_name + self.data_type = cfg.weight_type # create model + logger.warning(f'model_config:\n\n{cfg.toini()}') weight_dir = osp.join(model_path, 'triton_models', 'weights') model_comm = _tm.AbstractTransformerModel.create_llama_model( - weight_dir, + model_dir=weight_dir, + config=cfg.toini(), tensor_para_size=self.gpu_count, data_type=self.data_type) @@ -328,6 +349,7 @@ def _from_workspace(self, model_path: str): @classmethod def from_pretrained(cls, pretrained_model_name_or_path: str, + engine_config: EngineConfig = None, model_name: Optional[str] = None, model_format: Optional[str] = None, group_size: Optional[int] = None, @@ -364,7 +386,8 @@ def from_pretrained(cls, logger.warning(f'Please input a model_name for {model_source}') else: model_name = potential_names[0] - logger.warning(f'model_name: {model_name}') + logger.warning( + f'Best matched chat template name: {model_name}') if model_source == ModelSource.WORKSPACE: local_path = pretrained_model_name_or_path else: @@ -379,8 +402,9 @@ def from_pretrained(cls, local_path = pretrained_model_name_or_path logger.warning(f'model_source: {model_source}') - return cls(model_source=model_source, - model_path=local_path, + return cls(model_path=local_path, + engine_config=engine_config, + model_source=model_source, model_name=model_name, model_format=model_format, group_size=group_size, @@ -414,8 +438,6 @@ def __init__(self, tm_model: TurboMind, cuda_stream_id: int = 0): self.gpu_count = tm_model.gpu_count self.stop_words = tm_model.stop_words - self.stop_tokens = [] if self.stop_words is None else \ - self.stop_words.flatten().tolist() self.eos_id = tm_model.eos_id self.session_len = tm_model.session_len @@ -464,23 +486,39 @@ def _func(device_id, enque_output): t.start() self.threads[device_id] = t + def _update_generation_config(self, config: EngineGenerationConfig, + **kwargs: dict): + if config is None: + config = EngineGenerationConfig() + # backward compatibility + # if doesn't supply stop words, use default + if config.stop_words is None and self.stop_words is not None: + config.stop_words = self.stop_words[0][0].tolist() + + deprecated_kwargs = [] + for k, v in kwargs.items(): + if k in config.__dict__: + config.__dict__[k] = v + deprecated_kwargs.append(k) + if kwargs.get('request_output_len'): + config.max_new_tokens = kwargs['request_output_len'] + deprecated_kwargs.append('request_output_len') + for k in deprecated_kwargs: + get_logger('turbomind').warning( + f'kwargs {k} is deprecated for inference, ' + 'use GenerationConfig instead.') + return config + def prepare_inputs(self, session_id, input_ids, + gen_config: EngineGenerationConfig, input_embeddings=None, input_embedding_ranges=None, - request_output_len: int = 512, sequence_start: bool = True, sequence_end: bool = False, step=0, - stop=False, - top_p=0.8, - top_k=40, - temperature=0.8, - repetition_penalty=1.0, 
- ignore_eos=False, - random_seed=None, - stream_output=False): + stop=False): """Convert inputs format.""" if len(input_ids) == 0: input_ids = [[]] @@ -512,19 +550,16 @@ def _broadcast_np(data, dtype, shape=(batch_size, )): input_ids=input_ids, input_lengths=input_lengths, request_output_len=np.full(input_lengths.shape, - request_output_len, + gen_config.max_new_tokens, dtype=np.uint32), - runtime_top_k=_broadcast_np(top_k, np.uint32), - runtime_top_p=_broadcast_np(top_p, np.float32), - temperature=_broadcast_np(temperature, np.float32), - repetition_penalty=_broadcast_np(repetition_penalty, np.float32), + runtime_top_k=_broadcast_np(gen_config.top_k, np.uint32), + runtime_top_p=_broadcast_np(gen_config.top_p, np.float32), + temperature=_broadcast_np(gen_config.temperature, np.float32), + repetition_penalty=_broadcast_np(gen_config.repetition_penalty, + np.float32), step=step, # session input - session_len=self.session_len * - np.ones([ - batch_size, - ], dtype=np.uint32), START=_broadcast_np((1 if sequence_start else 0), np.int32), END=_broadcast_np((1 if sequence_end else 0), np.int32), CORRID=np.array(session_id, dtype=np.uint64), @@ -568,20 +603,25 @@ def _broadcast_np(data, dtype, shape=(batch_size, )): inputs['input_embeddings'] = input_embeddings inputs['input_embedding_ranges'] = input_embedding_ranges - if ignore_eos: + bad_words = [] + if gen_config.bad_words is not None: + bad_words.extend(gen_config.bad_words) + if gen_config.ignore_eos: stop_words = None - bad_words = torch.tensor([[[self.eos_id], [1]]], dtype=torch.int32) + bad_words.append(self.eos_id) else: - stop_words = self.stop_words - bad_words = None + stop_words = gen_config.stop_words + stop_words = _construct_stop_or_bad_words(stop_words) + bad_words = _construct_stop_or_bad_words(bad_words) if stop_words is not None: inputs['stop_words_list'] = stop_words if bad_words is not None: inputs['bad_words_list'] = bad_words - if random_seed is not None: - inputs['random_seed'] = _broadcast_np(random_seed, np.uint64) + if gen_config.random_seed is not None: + inputs['random_seed'] = _broadcast_np(gen_config.random_seed, + np.uint64) return inputs, input_lengths async def async_stream_infer(self, @@ -589,18 +629,13 @@ async def async_stream_infer(self, input_ids, input_embeddings=None, input_embedding_ranges=None, - request_output_len: int = 512, sequence_start: bool = True, sequence_end: bool = False, step=0, stop=False, - top_p=0.8, - top_k=40, - temperature=0.8, - repetition_penalty=1.0, - ignore_eos=False, - random_seed=None, - stream_output=False): + gen_config: EngineGenerationConfig = None, + stream_output=False, + **kwargs): """Perform model inference. Args: @@ -609,42 +644,28 @@ async def async_stream_infer(self, input_embeddings (List[numpy.ndarray]): embeddings features input_embedding_ranges (List[Tuple[int,int]]): the begin/end offsets of input_embeddings to input_ids - request_output_len (int): the max number of to-be-generated tokens sequence_start (bool): indicator for starting a sequence sequence_end (bool): indicator for ending a sequence step (int): the offset of the k/v cache stop (bool): indicator for cancelling the session - top_p (float): If set to float < 1, only the smallest set of most - probable tokens with probabilities that add up to top_p or higher - are kept for generation. 
- top_k (int): The number of the highest probability vocabulary - tokens to keep for top-k-filtering - temperature (float): to modulate the next token probability - repetition_penalty (float): The parameter for repetition penalty. - 1.0 means no penalty - ignore_eos (bool): indicator for ignoring eos - random_seed (int): seed used by sampling + gen_config (EngineGenerationConfig): generation config stream_output (bool): indicator for stream output + kwargs (dict): kwargs for backward compatibility """ if stream_output and not stop: self.model_insts[0].register_callback(self._forward_callback) + + gen_config = self._update_generation_config(gen_config, **kwargs) inputs, input_lengths = self.prepare_inputs( session_id=session_id, input_ids=input_ids, input_embeddings=input_embeddings, input_embedding_ranges=input_embedding_ranges, - request_output_len=request_output_len, sequence_start=sequence_start, sequence_end=sequence_end, step=step, stop=stop, - top_p=top_p, - top_k=top_k, - temperature=temperature, - repetition_penalty=repetition_penalty, - ignore_eos=ignore_eos, - random_seed=random_seed, - stream_output=stream_output) + gen_config=gen_config) tm_inputs = _np_dict_to_tm_dict(inputs) # start forward thread @@ -678,10 +699,12 @@ async def async_stream_infer(self, outputs = [] for output, len_ in zip(output_ids, sequence_length): output, len_ = output, len_.item() - if len(output) > 0 and output[-1].item( - ) == self.eos_id and not ignore_eos: + if len(output) > 0 and output[-1].item() == self.eos_id \ + and not gen_config.ignore_eos: outputs.append((output[:-1], len_ - 1)) - elif len(output) > 0 and output[-1].item() in self.stop_tokens: + elif len(output) > 0 and \ + gen_config.stop_words is not None and \ + output[-1].item() in gen_config.stop_words: outputs.append((output[:-1], len_)) else: outputs.append((output, len_)) @@ -702,18 +725,13 @@ def stream_infer(self, input_ids, input_embeddings=None, input_embedding_ranges=None, - request_output_len: int = 512, sequence_start: bool = True, sequence_end: bool = False, step=0, stop=False, - top_p=0.8, - top_k=40, - temperature=0.8, - repetition_penalty=1.0, - ignore_eos=False, - random_seed=None, - stream_output=False): + gen_config: EngineGenerationConfig = None, + stream_output=False, + **kwargs): """Perform model inference. Args: @@ -722,42 +740,28 @@ def stream_infer(self, input_embeddings (List[numpy.ndarray]): embeddings features input_embedding_ranges (List[Tuple[int,int]]): the begin/end offsets of input_embeddings to input_ids - request_output_len (int): the max number of to-be-generated tokens sequence_start (bool): indicator for starting a sequence sequence_end (bool): indicator for ending a sequence step (int): the offset of the k/v cache stop (bool): indicator for cancelling the session - top_p (float): If set to float < 1, only the smallest set of most - probable tokens with probabilities that add up to top_p or higher - are kept for generation. - top_k (int): The number of the highest probability vocabulary - tokens to keep for top-k-filtering - temperature (float): to modulate the next token probability - repetition_penalty (float): The parameter for repetition penalty. 
- 1.0 means no penalty - ignore_eos (bool): indicator for ignoring eos - random_seed (int): seed used by sampling + gen_config (EngineGenerationConfig): generation config stream_output (bool): indicator for stream output + kwargs (dict): kwargs for backward compatibility """ if stream_output and not stop: self.model_insts[0].register_callback(self._forward_callback) + + gen_config = self._update_generation_config(gen_config, **kwargs) inputs, input_lengths = self.prepare_inputs( session_id=session_id, input_ids=input_ids, input_embeddings=input_embeddings, input_embedding_ranges=input_embedding_ranges, - request_output_len=request_output_len, sequence_start=sequence_start, sequence_end=sequence_end, step=step, stop=stop, - top_p=top_p, - top_k=top_k, - temperature=temperature, - repetition_penalty=repetition_penalty, - ignore_eos=ignore_eos, - random_seed=random_seed, - stream_output=stream_output) + gen_config=gen_config) tm_inputs = _np_dict_to_tm_dict(inputs) # start forward thread @@ -786,10 +790,12 @@ def stream_infer(self, outputs = [] for output, len_ in zip(output_ids, sequence_length): output, len_ = output, len_.item() - if len(output) > 0 and output[-1].item( - ) == self.eos_id and not ignore_eos: + if len(output) > 0 and output[-1].item() == self.eos_id \ + and not gen_config.ignore_eos: outputs.append((output[:-1], len_ - 1)) - elif len(output) > 0 and output[-1].item() in self.stop_tokens: + elif len(output) > 0 and \ + gen_config.stop_words is not None and \ + output[-1].item() in gen_config.stop_words: outputs.append((output[:-1], len_)) else: outputs.append((output, len_)) diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc index 6127fdf2db..77f6b19833 100644 --- a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc +++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc @@ -144,7 +144,17 @@ LlamaTritonModel::LlamaTritonModel(size_t tensor_para_size, enable_custom_all_reduce_(enable_custom_all_reduce) { INIReader reader; - FT_CHECK_WITH_INFO((config.empty() ^ model_dir.empty()), "invalid init options"); + FT_CHECK_WITH_INFO(!(config.empty() && model_dir.empty()), "invalid init options"); + + if (!model_dir.empty()) { + model_dir_ = model_dir; + const std::string inifile{model_dir + "/config.ini"}; + reader = INIReader(inifile); + if (reader.ParseError() < 0) { + TM_LOG_ERROR("[ERROR] Can't load %s", inifile.c_str()); + ft::FT_CHECK(false); + } + } if (!config.empty()) { std::FILE* tmpf = std::tmpfile(); @@ -157,16 +167,6 @@ LlamaTritonModel::LlamaTritonModel(size_t tensor_para_size, } } - if (!model_dir.empty()) { - model_dir_ = model_dir; - const std::string inifile{model_dir + "/config.ini"}; - reader = INIReader(inifile); - if (reader.ParseError() < 0) { - TM_LOG_ERROR("[ERROR] Can't load %s", inifile.c_str()); - ft::FT_CHECK(false); - } - } - model_name_ = reader.Get("llama", "model_name"); head_num_ = reader.GetInteger("llama", "head_num"); kv_head_num_ = reader.GetInteger("llama", "kv_head_num", 0);
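
Usage note (not part of the patch): the sketch below illustrates how the EngineConfig and
EngineGenerationConfig objects introduced here might be used together. The model path, model
name, prompt, and sampling values are placeholders, and create_instance()/Tokenizer.encode()
come from the pre-existing lmdeploy API rather than from this diff.

    from lmdeploy.messages import EngineGenerationConfig
    from lmdeploy.turbomind import EngineConfig, TurboMind

    # Engine-level options now travel in EngineConfig instead of loose kwargs
    # such as model_format/group_size/tp (which now only trigger a deprecation warning).
    engine_config = EngineConfig(model_name='internlm-chat-7b', tp=1, session_len=2048)
    tm_model = TurboMind.from_pretrained('internlm/internlm-chat-7b',
                                         engine_config=engine_config)
    generator = tm_model.create_instance()

    # Per-request sampling options travel in EngineGenerationConfig,
    # replacing the deprecated top_k/top_p/request_output_len kwargs.
    gen_config = EngineGenerationConfig(max_new_tokens=64, top_k=40, top_p=0.8)
    input_ids = tm_model.tokenizer.encode('Hello')
    for outputs in generator.stream_infer(session_id=1,
                                          input_ids=[input_ids],
                                          gen_config=gen_config,
                                          sequence_start=True,
                                          sequence_end=True):
        pass  # outputs: a list of (token_ids, length) pairs, one per batch entry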