From 8254cc5233ed87a35c1468c013e202d856ab16bf Mon Sep 17 00:00:00 2001
From: juncaipeng <13006307475@163.com>
Date: Wed, 25 Sep 2024 10:30:41 +0800
Subject: [PATCH] update

---
 llm/benchmark/benchmark.py         |   3 +-
 llm/server/server/engine/config.py |  46 ++++++-------
 llm/server/server/engine/engine.py |   4 +-
 llm/server/server/engine/infer.py  |  21 +++---
 llm/server/server/triton_server.py | 104 ++++++++++++-----------------
 5 files changed, 77 insertions(+), 101 deletions(-)

diff --git a/llm/benchmark/benchmark.py b/llm/benchmark/benchmark.py
index 549f688af6..587201f88f 100644
--- a/llm/benchmark/benchmark.py
+++ b/llm/benchmark/benchmark.py
@@ -254,8 +254,7 @@ def parse_args():
     result_path = f"output/{test_tag}.jsonl"
     if os.path.exists(result_path):
-        logger.error(f"result file ({result_path}) already exists, exit")
-        exit()
+        logger.error(f"result file ({result_path}) already exists, it will be overwritten")
     if not os.path.exists("output/"):
         os.makedirs("output/")
     logger.info(f"result_path: {result_path}")
diff --git a/llm/server/server/engine/config.py b/llm/server/server/engine/config.py
index 5c16abd542..f7ca557972 100644
--- a/llm/server/server/engine/config.py
+++ b/llm/server/server/engine/config.py
@@ -17,7 +17,9 @@
 import sys
 from datetime import datetime
 
+import paddle
 from paddlenlp.generation import GenerationConfig
+
 from server.utils import model_server_logger
 
 
@@ -37,17 +39,13 @@ def read_from_env(self):
         get the configuration from environment
         """
         env = os.environ
-        self.model_dir = env.get(
-            "MODEL_DIR", "/opt/output/Serving/models")
+
+        self.model_dir = env.get("MODEL_DIR", "/opt/output/Serving/models")
         if not self.model_dir:
             raise Exception("The parameter MODEL_DIR is None.")
         self.mp_num = int(env.get("MP_NUM", 8))
-        self.config_json_file = env.get("CONFIG_JSON_FILE", "config.json")
-        self.model_config_path = os.path.join(self.model_dir, self.config_json_file)
-        if env.get("FD_MODEL_CONFIG_PATH", None):
-            self.model_config_path = env.get("FD_MODEL_CONFIG_PATH")
-
-        # distributed config
+        self.model_config_path = os.path.join(self.model_dir,
+                                              env.get("CONFIG_JSON_FILE", "config.json"))
         self.distributed_config_path = os.path.join(self.model_dir, "rank_mapping.csv")
         if os.getenv("DISTRIBUTED_CONFIG", None):
             self.distributed_config_path = os.getenv("DISTRIBUTED_CONFIG")
@@ -67,15 +65,16 @@ def read_from_env(self):
             raise Exception(f"MAX_PREFILL_BATCH ({self.max_prefill_batch}) must be greater than 0")
         self.disable_streaming = int(os.getenv("DISABLE_STREAMING", 0))
 
+        # server ports
+        self.grpc_port = int(os.getenv("GRPC_PORT", 8000))
+        self.http_port = int(os.getenv("HTTP_PORT", 8001))
+        self.metrics_port = int(os.getenv("METRICS_PORT", 8002))
+        self.infer_queue_port = int(os.getenv("INFER_QUEUE_PORT", 8005))
+        # if PUSH_MODE_HTTP_PORT is not configured, only GRPC service is enabled
+        self.push_mode_http_port = int(os.getenv("PUSH_MODE_HTTP_PORT", -1))
+
         # max cached task num
         self.max_cached_task_num = int(os.getenv("MAX_CACHED_TASK_NUM", "128"))
-        # if PUSH_MODE_HTTP_PORT is not configured, only GRPC service is enabled
-        self.push_mode_http_port = int(os.getenv("PUSH_MODE_HTTP_PORT", "-1"))
-        if self.push_mode_http_port > 0:
-            grpc_port = os.getenv("GRPC_PORT", None)
-            if grpc_port is None:
-                raise Exception("GRPC_PORT cannot be None, while PUSH_MODE_HTTP_PORT>0")
-            self.grpc_port = int(grpc_port)
 
         # http worker num
         self.push_mode_http_workers = int(os.getenv("PUSH_MODE_HTTP_WORKERS", "1"))
@@ -83,12 +82,8 @@ def read_from_env(self):
             raise Exception(f"PUSH_MODE_HTTP_WORKERS ({self.push_mode_http_workers}) must be positive")
 
         # Padlle commit id
-        import paddle
         self.paddle_commit_id = paddle.version.commit
 
-        # time interval for detecting whether the engine loop is normal during probing
-        self.check_health_interval = int(os.getenv("CHECK_HEALTH_INTERVAL", 10))
-
         # model config
         self.dtype = env.get("DTYPE", "bfloat16")
         self.block_size = int(env.get("BLOCK_SIZE", 64))
@@ -105,21 +100,20 @@ def read_from_env(self):
         self.bad_tokens = str(env.get("BAD_TOKENS", "-1"))
         self.first_token_id = int(os.getenv("FIRST_TOKEN_ID", 1))
 
-        # infer queue port
-        self.infer_port = int(os.getenv("INFER_QUEUE_PORT", 56666))
-
-        # whether to use custom health checker
-        self.use_custom_health_checker = int(os.getenv("USE_CUSTOM_HEALTH_CHECKER", 1))
-
         # Check the legality of requests
         self.seq_len_limit = int(env.get("MAX_SEQ_LEN", 8192))
         self.dec_len_limit = int(env.get("MAX_DEC_LEN", 1024))
 
+        # whether to use custom health checker
+        self.use_custom_health_checker = int(os.getenv("USE_CUSTOM_HEALTH_CHECKER", 1))
+        # time interval for detecting whether the engine loop is normal during probing
+        self.check_health_interval = int(os.getenv("CHECK_HEALTH_INTERVAL", 10))
+
         # warmup
         self.use_warmup = int(os.getenv("USE_WARMUP", 0)) == 1
 
         # uuid
-        self.shm_uuid = os.getenv("SHM_UUID", '')
+        self.shm_uuid = os.getenv("SHM_UUID", "")
 
         # use huggingface tokenizer
         self.use_hf_tokenizer = int(os.getenv("USE_HF_TOKENIZER", 0)) == 1
diff --git a/llm/server/server/engine/engine.py b/llm/server/server/engine/engine.py
index a8eae90032..eea5560f47 100644
--- a/llm/server/server/engine/engine.py
+++ b/llm/server/server/engine/engine.py
@@ -43,7 +43,7 @@ def __init__(self, cfg):
         self._init_engine_flags()
 
         self.tqm_proc = self._start_task_queue_manager()
-        self.task_queue_manager = TaskQueueManager(mp_num=self.cfg.mp_num, port=self.cfg.infer_port)
+        self.task_queue_manager = TaskQueueManager(mp_num=self.cfg.mp_num, port=self.cfg.infer_queue_port)
 
         start_time = time.time()
         self.infer_proc = self._start_infer_process()
@@ -271,7 +271,7 @@ def _start_task_queue_manager(self):
         Returns:
             p: process handle
         """
-        p = multiprocessing.Process(target=launch_task_queue_manager, args=(self.cfg.infer_port, self.cfg.mp_num))
+        p = multiprocessing.Process(target=launch_task_queue_manager, args=(self.cfg.infer_queue_port, self.cfg.mp_num))
         p.start()
         if p.is_alive():
             model_server_logger.info("start tasks queue service successfully")
diff --git a/llm/server/server/engine/infer.py b/llm/server/server/engine/infer.py
index f94fac6e56..90b902cde1 100644
--- a/llm/server/server/engine/infer.py
+++ b/llm/server/server/engine/infer.py
@@ -29,7 +29,7 @@
 from paddlenlp_ops import step_paddle
 
 from server.data.processor import DataProcessor
-from server.engine.config import Config
+from server.engine.config import get_global_config
 from server.utils import get_logger
 from server.engine.task_queue_manager import TaskQueueManager
@@ -45,7 +45,7 @@ def __init__(self, args):
 
         # 2**63 - 1
         self.MAX_INFER_SEED = 9223372036854775806
-        self.config = Config()
+        self.config = get_global_config()
         self.model_cfg = self.config.get_model_config()
         self.format_print_configuration()
@@ -63,7 +63,7 @@ def __init__(self, args):
         self.cache_kvs = {}
         self.init_inputs()
 
-        self.infer_queue = TaskQueueManager(rank=self.rank, mp_num=self.nranks, port=self.config.infer_port)
+        self.infer_queue = TaskQueueManager(rank=self.rank, mp_num=self.nranks, port=self.config.infer_queue_port)
 
         model_rank_path = os.path.join(self.args.model_dir, f"rank_{self.rank}")
         if not os.path.exists(model_rank_path):
@@ -354,6 +354,14 @@ def run(self):
         """
         run infer
         """
+        use_custom_health_checker = self.config.use_custom_health_checker
+        if use_custom_health_checker:
+            shm_engine_ready_check_flag_array, engine_ready_check_flag_array = self.initialize_engine_ready_check_flag()
+            engine_ready_check_flag_array[0] = 1
+            shm_engine_healthy_recorded_time_array, engine_healthy_recorded_time_array = self.initialize_engine_healthy_recorded_time_flag()
+            engine_healthy_recorded_time_array[0] = time.time()
+            infer_live_flag_shm = self.initialize_engine_live_flag()
+
         flag_array = np.zeros([1], dtype=np.int32)
         shm_flag_broadcast = shared_memory.SharedMemory(
             name=self.config.get_unique_name("shm_pd_infer_flag_broadcast"))
@@ -374,13 +382,6 @@
             dtype=flag_array.dtype,
             buffer=shm_flag_has_block_step.buf)
 
-        use_custom_health_checker = self.config.use_custom_health_checker
-        if use_custom_health_checker:
-            shm_engine_ready_check_flag_array, engine_ready_check_flag_array = self.initialize_engine_ready_check_flag()
-            engine_ready_check_flag_array[0] = 1
-            shm_engine_healthy_recorded_time_array, engine_healthy_recorded_time_array = self.initialize_engine_healthy_recorded_time_flag()
-            engine_healthy_recorded_time_array[0] = time.time()
-            infer_live_flag_shm = self.initialize_engine_live_flag()
 
         infer_seed_increment = paddle.full(shape=[self.args.max_batch_size, 1],
                                            fill_value=4, dtype="int64")
diff --git a/llm/server/server/triton_server.py b/llm/server/server/triton_server.py
index 400bd64b85..ecfb13e051 100644
--- a/llm/server/server/triton_server.py
+++ b/llm/server/server/triton_server.py
@@ -55,8 +55,7 @@ def initialize(self, args):
         Triton initialization
         """
         model_config = json.loads(args["model_config"])
-        using_decoupled = pb_utils.using_decoupled_model_transaction_policy(
-            model_config)
+        using_decoupled = pb_utils.using_decoupled_model_transaction_policy(model_config)
         if not using_decoupled:
             raise pb_utils.TritonModelException(
                 """the model `{}` can generate any number of responses per request,
@@ -71,56 +70,44 @@
             GAUGE,
         )
         self.metrics = {
-            "batch_size":
-            self.metric_family.Metric(labels={"batch_size": "batch_size"}),
-            "block_num":
-            self.metric_family.Metric(labels={"block_num": "block_num"}),
-            "max_batch_size":
-            self.metric_family.Metric(
-                labels={"max_batch_size": "max_batch_size"}),
-            "max_block_num":
-            self.metric_family.Metric(
-                labels={"max_block_num": "max_block_num"}),
-            "available_resource":
-            self.metric_family.Metric(
-                labels={"available_resource": "available_resource"}),
+            "batch_size": self.metric_family.Metric(labels={"batch_size": "batch_size"}),
+            "block_num": self.metric_family.Metric(labels={"block_num": "block_num"}),
+            "max_batch_size": self.metric_family.Metric(labels={"max_batch_size": "max_batch_size"}),
+            "max_block_num": self.metric_family.Metric(labels={"max_block_num": "max_block_num"}),
+            "available_resource": self.metric_family.Metric(labels={"available_resource": "available_resource"}),
         }
 
-        # if set USE_CUSTOM_HEALTH_CHECKER=1, use custom health checker, need set --allow-http=false
-        # else use tritonserver's health checker, need set --http-port=${HTTP_PORT}
-        use_custom_health_checker = int(os.getenv("USE_CUSTOM_HEALTH_CHECKER", 1))
-        if use_custom_health_checker:
-            http_port = os.getenv("HTTP_PORT")
-            if http_port is None:
-                raise Exception("HTTP_PORT must be set")
-            from server.health_checker import start_health_checker
-            multiprocessing.Process(target=start_health_checker, args=(int(http_port), )).start()
-            time.sleep(1)
-
         self.cfg = get_global_config()
         self.cfg.print(file="log/fastdeploy_init.info")
 
         self.req_senders = dict()
         self.cached_task_deque = deque()
         self.is_stopping = False
 
+        # if set USE_CUSTOM_HEALTH_CHECKER=1, use custom health checker, need set --allow-http=false
+        # else use tritonserver's health checker, need set --http-port=${HTTP_PORT}
+        if self.cfg.use_custom_health_checker:
+            from server.health_checker import start_health_checker
+            multiprocessing.Process(target=start_health_checker, args=(self.cfg.http_port, )).start()
+            time.sleep(1)
+
         self.engine = engine.Engine(self.cfg)
-        model_server_logger.info("Create engine success")
+        model_server_logger.info("create engine success")
 
         self.data_processor = DataProcessor()
         model_server_logger.info("create data processor success")
 
-        insert_task_thread = threading.Thread(target=self._insert_task, args=())
-        insert_task_thread.daemon = True
-        insert_task_thread.start()
+        self.http_proc = None
+        self._launch_http_server()
+        model_server_logger.info("launch push server success")
+
+        schedule_task_thread = threading.Thread(target=self._schedule_task, args=())
+        schedule_task_thread.daemon = True
+        schedule_task_thread.start()
 
         send_output_thread = threading.Thread(target=self._send_output, args=())
         send_output_thread.daemon = True
         send_output_thread.start()
 
-        self.http_process = None
-        self._launch_http_server()
-
-        model_server_logger.info("Init triton server success")
-
+        model_server_logger.info("init triton server success")
 
     def _launch_http_server(self):
         """
@@ -132,19 +119,18 @@ def _launch_http_server(self):
         http_py_path = os.path.join(current_dir_path, "http_server", http_py_file)
         http_cmd = f"python3 {http_py_path} --port={self.cfg.push_mode_http_port} " \
                    f"--workers={self.cfg.push_mode_http_workers} >log/launch_http.log 2>&1"
-        model_server_logger.info(f"Launch HTTP server for push mode, command:{http_cmd}")
+        model_server_logger.info(f"launch HTTP server for push mode, command:{http_cmd}")
 
-        self.http_process = subprocess.Popen(http_cmd, shell=True, preexec_fn=os.setsid)
+        self.http_proc = subprocess.Popen(http_cmd, shell=True, preexec_fn=os.setsid)
         time.sleep(3)
-        exit_code = self.http_process.poll()
+        exit_code = self.http_proc.poll()
         if exit_code is None:
             http_url = f"http://127.0.0.1:{self.cfg.push_mode_http_port}/v1/chat/completions"
-            model_server_logger.info(f"Launch HTTP server for push mode success, http_url:{http_url}")
+            model_server_logger.info(f"launch HTTP server for push mode success, http_url:{http_url}")
         else:
             error_msg = "\n Launch HTTP service for push mode failed in 3 seconds. " \
                         "Please check log/launch_http.log file \n"
             model_server_logger.error(error_msg)
-        model_server_logger.info("init push server success")
 
     def execute(self, requests):
         """
@@ -236,35 +222,32 @@ def execute(self, requests):
 
         self._update_metrics()
 
-    def _insert_task(self):
+    def _schedule_task(self):
         """
         Insert task to engine thread, monitor cached_task_deque.
         if the engine has resource, insert task to engine
         """
-        try:
-            while True:
-                if self.engine.available_batch() == 0:
-                    time.sleep(0.001)
-                    continue
-                if len(self.cached_task_deque) == 0:
-                    time.sleep(0.001)
-                    continue
-                if not self.engine.is_queue_empty():
+        while True:
+            try:
+                if self.engine.available_batch() == 0 \
+                    or len(self.cached_task_deque) == 0 \
+                    or (not self.engine.is_queue_empty()):
                     time.sleep(0.001)
                     continue
 
                 i_bs = 0
                 for _ in range(self.cfg.max_prefill_batch):
-                    if len(self.cached_task_deque) == 0:
-                        break
-                    if self.engine.available_batch() == 0:
+                    if len(self.cached_task_deque) == 0 \
+                        or self.engine.available_batch() == 0:
                         break
+
                     while i_bs < self.cfg.max_batch_size:
                         if self.engine.task_is_finished(i_bs):
                             break
                         i_bs += 1
                     if i_bs >= self.cfg.max_batch_size:
                         break
+
                     input_token_num = len(self.cached_task_deque[-1]["input_ids"])
                     if not self.engine.is_resource_sufficient(input_token_num):
                         break
@@ -277,10 +260,9 @@
                     _send_result({"error_msg": err_msg},
                                  self.req_senders[task["req_id"]], 1)
                     del self.req_senders[task["req_id"]]
-            model_server_logger.info("finish insert_task_push_mode thread")
-        except Exception as e:
-            model_server_logger.error("insert_task_push_mode thread exit "
-                                      f"unexpectedly, {e}. {str(traceback.format_exc())}")
+            except Exception as e:
+                model_server_logger.error(f"schedule task has error: {e}. {str(traceback.format_exc())}")
+        model_server_logger.info("schedule task thread exit")
 
     def _send_output(self):
         """
@@ -298,13 +280,13 @@
                 if return_all_tokens and "topk_tokens" in result:
                     del result["topk_tokens"]
                 result = self.data_processor.process_response(result)
-                model_server_logger.debug(f"Send result to client under push mode: {result}")
+                model_server_logger.debug(f"send result to client under push mode: {result}")
                 _send_result([result], self.req_senders[req_id], is_end)
                 if is_end == 1:
                     del self.req_senders[req_id]
                 self._update_metrics()
             except Exception as e:
-                model_server_logger.error("Unexcepted error happend: {}, {}".format(e, str(traceback.format_exc())))
+                model_server_logger.error("unexpected error happened: {}, {}".format(e, str(traceback.format_exc())))
 
 
     def _update_metrics(self):
@@ -336,8 +318,8 @@ def finalize(self):
         time.sleep(5)
         del self.engine
-        if self.http_process:
-            self.http_process.kill()
+        if self.http_proc:
+            self.http_proc.kill()
 
         model_server_logger.info("Triton service is terminated!")
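
Note on the reworked port configuration: after this patch every server port is read in one place by Config.read_from_env() (GRPC_PORT, HTTP_PORT, METRICS_PORT, INFER_QUEUE_PORT, PUSH_MODE_HTTP_PORT), and the task queue port is exposed as cfg.infer_queue_port instead of cfg.infer_port. The sketch below shows one way a launcher might export these variables before the service reads them; the helper script, its name, and the concrete port values are illustrative assumptions, not part of the patch.

    # launch_env_sketch.py -- hypothetical helper, not part of this patch.
    # It only exports the environment variables consumed by the updated
    # Config.read_from_env(); unset variables fall back to the defaults
    # introduced above (8000/8001/8002/8005, PUSH_MODE_HTTP_PORT=-1).
    import os

    ports = {
        "GRPC_PORT": "8000",             # Triton GRPC endpoint
        "HTTP_PORT": "8001",             # port used by the custom health checker
        "METRICS_PORT": "8002",          # metrics endpoint
        "INFER_QUEUE_PORT": "8005",      # TaskQueueManager port (cfg.infer_queue_port)
        "PUSH_MODE_HTTP_PORT": "9965",   # >0 enables the push-mode HTTP server (value is illustrative)
        "USE_CUSTOM_HEALTH_CHECKER": "1",
    }
    for key, value in ports.items():
        # keep any values the operator already exported
        os.environ.setdefault(key, value)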