Add warning for server hangs
rainyfly committed Dec 25, 2023
1 parent 67ca253 commit 83cf2ac
Showing 4 changed files with 44 additions and 8 deletions.
8 changes: 8 additions & 0 deletions llm/fastdeploy_llm/config.py
@@ -100,6 +100,7 @@ def __init__(self, model_dir, decode_strategy="sampling", mp_num=None):
self.model_prompt_dir_path = config.get("prompt_dir_path",
"./prompt_embedding")
self.max_prefix_len = config.get("max_prefix_len", 128)
        self.inference_response_timeout = 20  # seconds the engine may take to emit each token before the hang warning fires

def is_arch(self, arch):
return arch in self.architecture
@@ -158,3 +159,10 @@ def load_environment_variables(self):
"detect environment DISABLE_DYNAMIC_BATCHING={}, will reset `disable_dynamic_batching` to {}!".
format(self.disable_dynamic_batching,
self.disable_dynamic_batching))

        if os.getenv("INFERENCE_RESPONSE_TIMEOUT", None):
            self.inference_response_timeout = int(
                os.getenv("INFERENCE_RESPONSE_TIMEOUT"))
            logger.warning(
                "detected environment INFERENCE_RESPONSE_TIMEOUT={}, resetting `inference_response_timeout` to {}!".
                format(self.inference_response_timeout,
                       self.inference_response_timeout))
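For context, a minimal usage sketch (not part of the commit): overriding the new 20-second default through the environment variable read above. The diff does not show where load_environment_variables() is invoked, so it is called explicitly here, and the model directory is a placeholder.

import os

from fastdeploy_llm.config import Config

# Hypothetical override: give the engine 60 seconds per token before the
# hang warning fires, instead of the default 20.
os.environ["INFERENCE_RESPONSE_TIMEOUT"] = "60"

config = Config(model_dir="./my_model")  # placeholder model directory
config.load_environment_variables()
print(config.inference_response_timeout)  # expected: 60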
32 changes: 25 additions & 7 deletions llm/fastdeploy_llm/engine.py
Expand Up @@ -63,6 +63,11 @@ def parse_args():
type=int,
default=64,
help='num_attention_heads')
parser.add_argument(
'--num_key_value_heads',
type=int,
default=4,
help='num_key_value_heads')
parser.add_argument(
'--hidden_size', type=int, default=8192, help='hidden_size')
parser.add_argument(
@@ -140,15 +145,28 @@ def init_dist_env(world_size, seed=20):

cache_kvs = []
for _ in range(args.num_layers):
    if 'llama' in args.architecture:
        # llama in PaddleNLP changed the cache kv shape after
        # https://github.com/PaddlePaddle/PaddleNLP/pull/7516/files
        cache_kvs.append(
            paddle.cast(
                paddle.to_tensor(
                    np.zeros(
                        (2, args.batch_size, args.num_key_value_heads,
                         args.max_seq_len + args.max_dec_len,
                         args.hidden_size // args.num_attention_heads),
                        dtype='float32')),
                args.dtype))
    else:
        cache_kvs.append(
            paddle.cast(
                paddle.to_tensor(
                    np.zeros(
                        (2, args.batch_size,
                         args.num_attention_heads // nranks,
                         args.max_seq_len + args.max_dec_len,
                         args.hidden_size // args.num_attention_heads),
                        dtype='float32')),
                args.dtype))

pre_ids = paddle.to_tensor(np.full((args.batch_size, 2048), -1, dtype='int64'))
tgt_generation_mask = paddle.zeros(
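As a rough illustration (not from the commit) of what the two branches allocate: the defaults for num_attention_heads (64), num_key_value_heads (4), and hidden_size (8192) come from parse_args above, while the batch size, sequence lengths, layer count, and 2-byte element size are illustrative assumptions.

def cache_kv_bytes(batch_size, kv_heads, max_seq_len, max_dec_len,
                   hidden_size, num_attention_heads, num_layers,
                   bytes_per_elem=2):
    # Each layer holds one tensor of shape
    # [2 (K and V), batch, kv_heads, max_seq_len + max_dec_len, head_dim].
    head_dim = hidden_size // num_attention_heads
    per_layer = 2 * batch_size * kv_heads * \
        (max_seq_len + max_dec_len) * head_dim
    return per_layer * num_layers * bytes_per_elem

# GQA-style llama branch (num_key_value_heads=4): 5.0 GiB
print(cache_kv_bytes(8, 4, 3072, 1024, 8192, 64, 80) / 2**30)
# Dense branch sharded over nranks=8 (64 // 8 = 8 heads per rank): 10.0 GiB
print(cache_kv_bytes(8, 8, 3072, 1024, 8192, 64, 80) / 2**30)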
11 changes: 10 additions & 1 deletion llm/fastdeploy_llm/model.py
@@ -25,7 +25,8 @@
from fastdeploy_llm.utils.utils import deserialize_from_file, get_files, remove_files
from fastdeploy_llm.config import Config
from fastdeploy_llm.task import Task, TaskStatus
from fastdeploy_llm.utils.logging_util import logger
from fastdeploy_llm.utils.logging_util import logger, warning_logger
from fastdeploy_llm.utils.logging_util import error_format, ErrorCode, ErrorType

from concurrent.futures import ThreadPoolExecutor

@@ -229,6 +230,7 @@ def async_predict(self, batch_tasks, stop_num=None):

    def _update_task_results(self, tasks):
        step_index = 1
        last_response_time = time.time()
        while True:
            filepath = f"./real_time_save.temp_ids_rank_0_step_{step_index}"
            if os.path.exists(filepath):
@@ -240,6 +242,7 @@ def _update_task_results(self, tasks):
                except:
                    fin.close()
                token_ids = deserialize_from_file(fin)
                last_response_time = time.time()
                fin.close()
                step_index += 1
                for b, token_id in enumerate(token_ids):
@@ -283,6 +286,12 @@ def _update_task_results(self, tasks):
            else:
                if not self._is_engine_busy():
                    break
                if time.time() - last_response_time > self.config.inference_response_timeout:
                    error_type = ErrorType.Server
                    error_code = ErrorCode.S0003
                    error_info = "Inference engine timed out while emitting the next token, likely due to an unexpected exception."
                    error_msg = error_format.format(error_type.name,
                                                    error_code.name,
                                                    error_info)
                    warning_logger.error(error_msg)
                ret = self.engine_proc.poll()
                if ret is not None:
                    logger.error(
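In outline, the hang detection added above is a watchdog: stamp the time of every successful token read, and when the silence since the last read exceeds inference_response_timeout, log an S0003 warning instead of blocking quietly forever. A standalone sketch of the pattern; poll_output, engine_busy, and TIMEOUT_S are illustrative stand-ins, not fastdeploy_llm APIs.

import time

TIMEOUT_S = 20  # mirrors the inference_response_timeout default


def watch(poll_output, engine_busy):
    """Collect engine output, warning if the engine goes silent too long."""
    collected = []
    last_response_time = time.time()
    while True:
        tokens = poll_output()  # None while nothing new has been produced
        if tokens is not None:
            last_response_time = time.time()  # progress: reset the clock
            collected.append(tokens)
        elif not engine_busy():
            break  # engine finished cleanly
        elif time.time() - last_response_time > TIMEOUT_S:
            # Silent for TIMEOUT_S seconds: warn rather than hang quietly.
            print("WARNING: inference engine output token timeout")
            last_response_time = time.time()  # avoid re-warning every loop
        time.sleep(0.01)
    return collected

The sketch resets the timestamp after warning so the log line is not repeated on every iteration; the hunk above leaves the timestamp untouched, so its warning appears to recur until the engine recovers or exits.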
1 change: 1 addition & 0 deletions llm/fastdeploy_llm/utils/logging_util.py
@@ -54,6 +54,7 @@ class ErrorCode(Enum):
    S0000 = 2  # Server overloaded
    S0001 = 3  # Server failed to start properly
    S0002 = 4  # Server exited
    S0003 = 5  # Server exception: inference engine timed out emitting results

class ErrorType(Enum):
    Query = 0  # Query error
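Finally, a short sketch of how the new error code is consumed, mirroring the model.py hunk above (the exact template string behind error_format is not shown in this diff).

from fastdeploy_llm.utils.logging_util import (ErrorCode, ErrorType,
                                               error_format, warning_logger)

error_info = "Inference engine output token timeout"
error_msg = error_format.format(ErrorType.Server.name, ErrorCode.S0003.name,
                                error_info)
warning_logger.error(error_msg)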
