diff --git a/llm/fastdeploy_llm/config.py b/llm/fastdeploy_llm/config.py
index 4843de927e..72c88dbf92 100644
--- a/llm/fastdeploy_llm/config.py
+++ b/llm/fastdeploy_llm/config.py
@@ -100,6 +100,7 @@ def __init__(self, model_dir, decode_strategy="sampling", mp_num=None):
         self.model_prompt_dir_path = config.get("prompt_dir_path",
                                                 "./prompt_embedding")
         self.max_prefix_len = config.get("max_prefix_len", 128)
+        self.inference_response_timeout = 20  # timeout in seconds for the inference engine to output each token
 
     def is_arch(self, arch):
         return arch in self.architecture
@@ -158,3 +159,10 @@ def load_environment_variables(self):
                 "detect environment DISABLE_DYNAMIC_BATCHING={}, will reset `disable_dynamic_batching` to {}!".
                 format(self.disable_dynamic_batching,
                        self.disable_dynamic_batching))
+
+        if os.getenv("INFERENCE_RESPONSE_TIMEOUT", None):
+            self.inference_response_timeout = int(os.getenv("INFERENCE_RESPONSE_TIMEOUT"))
+            logger.warning(
+                "detect environment INFERENCE_RESPONSE_TIMEOUT={}, will reset `inference_response_timeout` to {}!".
+                format(self.inference_response_timeout,
+                       self.inference_response_timeout))
\ No newline at end of file
diff --git a/llm/fastdeploy_llm/engine.py b/llm/fastdeploy_llm/engine.py
index 5de7e350f4..a21ccfe222 100644
--- a/llm/fastdeploy_llm/engine.py
+++ b/llm/fastdeploy_llm/engine.py
@@ -63,6 +63,11 @@ def parse_args():
         type=int,
         default=64,
         help='num_attention_heads')
+    parser.add_argument(
+        '--num_key_value_heads',
+        type=int,
+        default=4,
+        help='num_key_value_heads')
     parser.add_argument(
         '--hidden_size', type=int, default=8192, help='hidden_size')
     parser.add_argument(
@@ -140,15 +145,28 @@ def init_dist_env(world_size, seed=20):
 
     cache_kvs = []
     for _ in range(args.num_layers):
-        cache_kvs.append(
-            paddle.cast(
-                paddle.to_tensor(
-                    np.zeros(
-                        (2, args.batch_size, args.num_attention_heads // nranks,
+        if 'llama' in args.architecture:
+            ## llama in PaddleNLP changed the cache kv shape after https://github.com/PaddlePaddle/PaddleNLP/pull/7516/files
+            cache_kvs.append(
+                paddle.cast(
+                    paddle.to_tensor(
+                        np.zeros(
+                            (2, args.batch_size, args.num_key_value_heads,
                          args.max_seq_len + args.max_dec_len, args.hidden_size //
                          args.num_attention_heads),
-                        dtype='float32')),
-                args.dtype))
+                            dtype='float32')),
+                    args.dtype))
+
+        else:
+            cache_kvs.append(
+                paddle.cast(
+                    paddle.to_tensor(
+                        np.zeros(
+                            (2, args.batch_size, args.num_attention_heads // nranks,
+                             args.max_seq_len + args.max_dec_len, args.hidden_size //
+                             args.num_attention_heads),
+                            dtype='float32')),
+                    args.dtype))
 
     pre_ids = paddle.to_tensor(np.full((args.batch_size, 2048), -1, dtype='int64'))
     tgt_generation_mask = paddle.zeros(
diff --git a/llm/fastdeploy_llm/model.py b/llm/fastdeploy_llm/model.py
index 37c76a0160..f581a05f44 100644
--- a/llm/fastdeploy_llm/model.py
+++ b/llm/fastdeploy_llm/model.py
@@ -25,7 +25,8 @@
 from fastdeploy_llm.utils.utils import deserialize_from_file, get_files, remove_files
 from fastdeploy_llm.config import Config
 from fastdeploy_llm.task import Task, TaskStatus
-from fastdeploy_llm.utils.logging_util import logger
+from fastdeploy_llm.utils.logging_util import logger, warning_logger
+from fastdeploy_llm.utils.logging_util import error_format, ErrorCode, ErrorType
 from concurrent.futures import ThreadPoolExecutor
 
 
@@ -229,6 +230,7 @@
 
     def _update_task_results(self, tasks):
         step_index = 1
+        last_response_time = time.time()
         while True:
             filepath = f"./real_time_save.temp_ids_rank_0_step_{step_index}"
             if os.path.exists(filepath):
@@ -240,6 +242,7 @@ def _update_task_results(self, tasks):
             except:
                 fin.close()
             token_ids = deserialize_from_file(fin)
+            last_response_time = time.time()
             fin.close()
             step_index += 1
             for b, token_id in enumerate(token_ids):
@@ -283,6 +286,12 @@ def _update_task_results(self, tasks):
             else:
                 if not self._is_engine_busy():
                     break
+                if time.time() - last_response_time > self.config.inference_response_timeout:
+                    error_type = ErrorType.Server
+                    error_code = ErrorCode.S0003
+                    error_info = "Inference engine timed out while generating the next token due to some unexpected exception."
+                    error_msg = error_format.format(error_type.name, error_code.name, error_info)
+                    warning_logger.error(error_msg)
             ret = self.engine_proc.poll()
             if ret is not None:
                 logger.error(
diff --git a/llm/fastdeploy_llm/utils/logging_util.py b/llm/fastdeploy_llm/utils/logging_util.py
index 55a316dfbe..6803ed0c72 100644
--- a/llm/fastdeploy_llm/utils/logging_util.py
+++ b/llm/fastdeploy_llm/utils/logging_util.py
@@ -54,6 +54,7 @@ class ErrorCode(Enum):
     S0000 = 2  # 服务负载过大 (service overloaded)
     S0001 = 3  # 服务没能正常启动 (service failed to start)
     S0002 = 4  # 服务退出 (service exited)
+    S0003 = 5  # 服务异常,推理引擎吐出结果超时 (service error: inference engine timed out emitting results)
 
 class ErrorType(Enum):
     Query = 0  # Query错误 (query error)
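
Note on the config.py hunk: both `{}` placeholders in the new warning are filled with `self.inference_response_timeout`, so the message prints the post-override value twice (the pre-existing DISABLE_DYNAMIC_BATCHING warning has the same quirk). A minimal standalone sketch of the same override pattern that also reports the value being replaced; the module-level names here are illustrative, not part of the patch:

```python
import logging
import os

logger = logging.getLogger("fastdeploy")

# Standalone sketch of the env-var override pattern used in config.py.
inference_response_timeout = 20  # default, in seconds
env_value = os.getenv("INFERENCE_RESPONSE_TIMEOUT")
if env_value:
    previous = inference_response_timeout
    inference_response_timeout = int(env_value)
    # Log both the old and the new value, so the two placeholders differ.
    logger.warning(
        "detect environment INFERENCE_RESPONSE_TIMEOUT=%s, will reset "
        "`inference_response_timeout` from %s to %s!",
        env_value, previous, inference_response_timeout)
```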
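For context on the engine.py hunk: after PaddleNLP PR #7516, the llama implementation uses grouped-query attention, so the KV cache is allocated per key/value head rather than per (tensor-parallel-sharded) attention head. A quick sketch of the two layouts using the `parse_args` defaults from this diff; `max_seq_len`, `max_dec_len`, and `nranks` are assumed values for illustration, not defaults taken from the file:

```python
import numpy as np

# Defaults from the diff: 64 attention heads, 4 KV heads, hidden size 8192.
batch_size, num_attention_heads, num_key_value_heads = 1, 64, 4
hidden_size, max_seq_len, max_dec_len, nranks = 8192, 1024, 1024, 1  # assumed
head_dim = hidden_size // num_attention_heads

# Pre-#7516 layout: K and V slots per sharded attention head.
old_shape = (2, batch_size, num_attention_heads // nranks,
             max_seq_len + max_dec_len, head_dim)
# Post-#7516 llama layout: K and V slots per key/value head (GQA).
new_shape = (2, batch_size, num_key_value_heads,
             max_seq_len + max_dec_len, head_dim)

old_bytes = np.prod(old_shape) * 2  # assuming float16: 2 bytes per element
new_bytes = np.prod(new_shape) * 2
print(old_shape, new_shape)
print(f"per-layer cache: {old_bytes / 2**20:.0f} MiB -> {new_bytes / 2**20:.0f} MiB")
```

With these numbers the per-layer cache drops from 64 MiB to 4 MiB, which is why allocating per attention head for a GQA model would badly over-provision memory.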
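The model.py changes amount to a watchdog: `_update_task_results` stamps `last_response_time` whenever the engine writes new token ids, and reports an S0003 error once the gap exceeds `inference_response_timeout`. A self-contained sketch of that pattern; `poll_next_token` and the surrounding loop are hypothetical stand-ins for the real engine plumbing:

```python
import logging
import time

warning_logger = logging.getLogger("fastdeploy.warning")

def watch_engine_output(poll_next_token, timeout=20):
    """Drain tokens from an engine, warning if it stalls for `timeout` seconds.

    `poll_next_token` is a hypothetical callable that returns a token id,
    returns None while the engine is still working, and raises StopIteration
    once generation is finished.
    """
    last_response_time = time.time()
    while True:
        try:
            token = poll_next_token()
        except StopIteration:
            break
        if token is not None:
            # A token arrived: reset the watchdog clock, mirroring model.py.
            last_response_time = time.time()
            continue
        if time.time() - last_response_time > timeout:
            warning_logger.error(
                "Inference engine timed out while generating the next token.")
            break
        time.sleep(0.01)  # avoid busy-spinning while the engine works
```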