
Commit

update
juncaipeng committed Sep 25, 2024
1 parent f186bc1 commit 8254cc5
Showing 5 changed files with 77 additions and 101 deletions.
3 changes: 1 addition & 2 deletions llm/benchmark/benchmark.py
@@ -254,8 +254,7 @@ def parse_args():

result_path = f"output/{test_tag}.jsonl"
if os.path.exists(result_path):
logger.error(f"result file ({result_path}) already exists, exit")
exit()
logger.error(f"result file ({result_path}) already exists, overwrite it")
if not os.path.exists("output/"):
os.makedirs("output/")
logger.info(f"result_path: {result_path}")
46 changes: 20 additions & 26 deletions llm/server/server/engine/config.py
@@ -17,7 +17,9 @@
import sys
from datetime import datetime

import paddle
from paddlenlp.generation import GenerationConfig

from server.utils import model_server_logger


@@ -37,17 +39,13 @@ def read_from_env(self):
get the configuration from environment
"""
env = os.environ
self.model_dir = env.get(
"MODEL_DIR", "/opt/output/Serving/models")

self.model_dir = env.get("MODEL_DIR", "/opt/output/Serving/models")
if not self.model_dir:
raise Exception("The parameter MODEL_DIR is None.")
self.mp_num = int(env.get("MP_NUM", 8))
self.config_json_file = env.get("CONFIG_JSON_FILE", "config.json")
self.model_config_path = os.path.join(self.model_dir, self.config_json_file)
if env.get("FD_MODEL_CONFIG_PATH", None):
self.model_config_path = env.get("FD_MODEL_CONFIG_PATH")

# distributed config
self.model_config_path = os.path.join(self.model_dir,
env.get("CONFIG_JSON_FILE", "config.json"))
self.distributed_config_path = os.path.join(self.model_dir, "rank_mapping.csv")
if os.getenv("DISTRIBUTED_CONFIG", None):
self.distributed_config_path = os.getenv("DISTRIBUTED_CONFIG")
@@ -67,28 +65,25 @@ def read_from_env(self):
raise Exception(f"MAX_PREFILL_BATCH ({self.max_prefill_batch}) must be greater than 0")
self.disable_streaming = int(os.getenv("DISABLE_STREAMING", 0))

# server ports
self.grpc_port = int(os.getenv("GRPC_PORT", 8000))
self.http_port = int(os.getenv("HTTP_PORT", 8001))
self.metrics_port = int(os.getenv("METRICS_PORT", 8002))
self.infer_queue_port = int(os.getenv("INFER_QUEUE_PORT", 8005))
# if PUSH_MODE_HTTP_PORT is not configured, only GRPC service is enabled
self.push_mode_http_port = int(os.getenv("PUSH_MODE_HTTP_PORT", -1))

# max cached task num
self.max_cached_task_num = int(os.getenv("MAX_CACHED_TASK_NUM", "128"))
# if PUSH_MODE_HTTP_PORT is not configured, only GRPC service is enabled
self.push_mode_http_port = int(os.getenv("PUSH_MODE_HTTP_PORT", "-1"))
if self.push_mode_http_port > 0:
grpc_port = os.getenv("GRPC_PORT", None)
if grpc_port is None:
raise Exception("GRPC_PORT cannot be None, while PUSH_MODE_HTTP_PORT>0")
self.grpc_port = int(grpc_port)

# http worker num
self.push_mode_http_workers = int(os.getenv("PUSH_MODE_HTTP_WORKERS", "1"))
if self.push_mode_http_workers < 1:
raise Exception(f"PUSH_MODE_HTTP_WORKERS ({self.push_mode_http_workers}) must be positive")

# Paddle commit id
import paddle
self.paddle_commit_id = paddle.version.commit

# time interval for detecting whether the engine loop is normal during probing
self.check_health_interval = int(os.getenv("CHECK_HEALTH_INTERVAL", 10))

# model config
self.dtype = env.get("DTYPE", "bfloat16")
self.block_size = int(env.get("BLOCK_SIZE", 64))
@@ -105,21 +100,20 @@ def read_from_env(self):
self.bad_tokens = str(env.get("BAD_TOKENS", "-1"))
self.first_token_id = int(os.getenv("FIRST_TOKEN_ID", 1))

# infer queue port
self.infer_port = int(os.getenv("INFER_QUEUE_PORT", 56666))

# whether to use custom health checker
self.use_custom_health_checker = int(os.getenv("USE_CUSTOM_HEALTH_CHECKER", 1))

# Check the validity of requests
self.seq_len_limit = int(env.get("MAX_SEQ_LEN", 8192))
self.dec_len_limit = int(env.get("MAX_DEC_LEN", 1024))

# whether to use custom health checker
self.use_custom_health_checker = int(os.getenv("USE_CUSTOM_HEALTH_CHECKER", 1))
# time interval for detecting whether the engine loop is normal during probing
self.check_health_interval = int(os.getenv("CHECK_HEALTH_INTERVAL", 10))

# warmup
self.use_warmup = int(os.getenv("USE_WARMUP", 0)) == 1

# uuid
self.shm_uuid = os.getenv("SHM_UUID", '')
self.shm_uuid = os.getenv("SHM_UUID", "")

# use huggingface tokenizer
self.use_hf_tokenizer = int(os.getenv("USE_HF_TOKENIZER", 0)) == 1
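The config.py changes above keep every server setting in read_from_env(), driven entirely by environment variables (ports, batch limits, health-check options). Below is a minimal usage sketch, illustrative only and not part of this commit: it assumes get_global_config(), which infer.py imports further down, returns a Config whose read_from_env() has already run, and the variable names are taken from the hunk above.

    # sketch: configure the server through environment variables, then read them back
    import os

    os.environ["MODEL_DIR"] = "/opt/output/Serving/models"   # model location
    os.environ["MP_NUM"] = "8"                                # tensor parallel degree
    os.environ["GRPC_PORT"] = "8000"                          # gRPC service port
    os.environ["PUSH_MODE_HTTP_PORT"] = "8001"                # enables the push-mode HTTP server
    os.environ["INFER_QUEUE_PORT"] = "8005"                   # task queue port shared with engine.py

    from server.engine.config import get_global_config

    cfg = get_global_config()
    print(cfg.model_dir, cfg.grpc_port, cfg.push_mode_http_port, cfg.infer_queue_port)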
4 changes: 2 additions & 2 deletions llm/server/server/engine/engine.py
@@ -43,7 +43,7 @@ def __init__(self, cfg):
self._init_engine_flags()

self.tqm_proc = self._start_task_queue_manager()
self.task_queue_manager = TaskQueueManager(mp_num=self.cfg.mp_num, port=self.cfg.infer_port)
self.task_queue_manager = TaskQueueManager(mp_num=self.cfg.mp_num, port=self.cfg.infer_queue_port)

start_time = time.time()
self.infer_proc = self._start_infer_process()
@@ -271,7 +271,7 @@ def _start_task_queue_manager(self):
Returns:
p: process handle
"""
p = multiprocessing.Process(target=launch_task_queue_manager, args=(self.cfg.infer_port, self.cfg.mp_num))
p = multiprocessing.Process(target=launch_task_queue_manager, args=(self.cfg.infer_queue_port, self.cfg.mp_num))
p.start()
if p.is_alive():
model_server_logger.info("start tasks queue service successfully")
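The engine.py hunk above finishes the switch from cfg.infer_port to cfg.infer_queue_port, so the queue-manager process and its clients are created from the same setting. A rough sketch of that wiring follows; only the constructor and function signatures visible in this diff are used, and the import path of launch_task_queue_manager is an assumption, not confirmed by the commit.

    # sketch: server and client sides of the task queue share cfg.infer_queue_port
    import multiprocessing

    from server.engine.config import get_global_config
    from server.engine.task_queue_manager import (TaskQueueManager,
                                                  launch_task_queue_manager)  # assumed location

    cfg = get_global_config()

    # queue-manager process listens on the configured port
    tqm_proc = multiprocessing.Process(target=launch_task_queue_manager,
                                       args=(cfg.infer_queue_port, cfg.mp_num))
    tqm_proc.start()

    # clients (the engine and each infer rank) attach to the same port
    queue_client = TaskQueueManager(mp_num=cfg.mp_num, port=cfg.infer_queue_port)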
21 changes: 11 additions & 10 deletions llm/server/server/engine/infer.py
@@ -29,7 +29,7 @@
from paddlenlp_ops import step_paddle

from server.data.processor import DataProcessor
from server.engine.config import Config
from server.engine.config import get_global_config
from server.utils import get_logger
from server.engine.task_queue_manager import TaskQueueManager

@@ -45,7 +45,7 @@ def __init__(self, args):
# 2**63 - 1
self.MAX_INFER_SEED = 9223372036854775806

self.config = Config()
self.config = get_global_config()
self.model_cfg = self.config.get_model_config()
self.format_print_configuration()

@@ -63,7 +63,7 @@ def __init__(self, args):
self.cache_kvs = {}
self.init_inputs()

self.infer_queue = TaskQueueManager(rank=self.rank, mp_num=self.nranks, port=self.config.infer_port)
self.infer_queue = TaskQueueManager(rank=self.rank, mp_num=self.nranks, port=self.config.infer_queue_port)

model_rank_path = os.path.join(self.args.model_dir, f"rank_{self.rank}")
if not os.path.exists(model_rank_path):
@@ -354,6 +354,14 @@ def run(self):
"""
run infer
"""
use_custom_health_checker = self.config.use_custom_health_checker
if use_custom_health_checker:
shm_engine_ready_check_flag_array, engine_ready_check_flag_array = self.initialize_engine_ready_check_flag()
engine_ready_check_flag_array[0] = 1
shm_engine_healthy_recorded_time_array, engine_healthy_recorded_time_array = self.initialize_engine_healthy_recorded_time_flag()
engine_healthy_recorded_time_array[0] = time.time()
infer_live_flag_shm = self.initialize_engine_live_flag()

flag_array = np.zeros([1], dtype=np.int32)
shm_flag_broadcast = shared_memory.SharedMemory(
name=self.config.get_unique_name("shm_pd_infer_flag_broadcast"))
@@ -374,13 +382,6 @@ def run(self):
dtype=flag_array.dtype,
buffer=shm_flag_has_block_step.buf)

use_custom_health_checker = self.config.use_custom_health_checker
if use_custom_health_checker:
shm_engine_ready_check_flag_array, engine_ready_check_flag_array = self.initialize_engine_ready_check_flag()
engine_ready_check_flag_array[0] = 1
shm_engine_healthy_recorded_time_array, engine_healthy_recorded_time_array = self.initialize_engine_healthy_recorded_time_flag()
engine_healthy_recorded_time_array[0] = time.time()
infer_live_flag_shm = self.initialize_engine_live_flag()
infer_seed_increment = paddle.full(shape=[self.args.max_batch_size, 1],
fill_value=4,
dtype="int64")
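The infer.py hunk above moves the health-checker bookkeeping to the top of run(), before the other flags are attached; all of these flags use the same pattern of viewing a named shared-memory block through a one-element numpy array. A self-contained sketch of that pattern (illustrative names, not the server's actual shm names):

    # sketch: one process creates a named shared-memory flag, another attaches and polls it
    import numpy as np
    from multiprocessing import shared_memory

    template = np.zeros([1], dtype=np.int32)

    # creator side: allocate the named block and zero the flag
    shm = shared_memory.SharedMemory(name="demo_engine_ready_flag", create=True,
                                     size=template.nbytes)
    flag = np.ndarray(template.shape, dtype=template.dtype, buffer=shm.buf)
    flag[0] = 0

    # reader side (normally a different process): attach by name and read the flag
    peer = shared_memory.SharedMemory(name="demo_engine_ready_flag")
    peer_view = np.ndarray(template.shape, dtype=template.dtype, buffer=peer.buf)

    flag[0] = 1                # writer marks the engine as ready
    print(int(peer_view[0]))   # reader observes the update through shared memory

    peer.close()
    shm.close()
    shm.unlink()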
104 changes: 43 additions & 61 deletions llm/server/server/triton_server.py
@@ -55,8 +55,7 @@ def initialize(self, args):
Triton initialization
"""
model_config = json.loads(args["model_config"])
using_decoupled = pb_utils.using_decoupled_model_transaction_policy(
model_config)
using_decoupled = pb_utils.using_decoupled_model_transaction_policy(model_config)
if not using_decoupled:
raise pb_utils.TritonModelException(
"""the model `{}` can generate any number of responses per request,
@@ -71,56 +70,44 @@ def initialize(self, args):
GAUGE,
)
self.metrics = {
"batch_size":
self.metric_family.Metric(labels={"batch_size": "batch_size"}),
"block_num":
self.metric_family.Metric(labels={"block_num": "block_num"}),
"max_batch_size":
self.metric_family.Metric(
labels={"max_batch_size": "max_batch_size"}),
"max_block_num":
self.metric_family.Metric(
labels={"max_block_num": "max_block_num"}),
"available_resource":
self.metric_family.Metric(
labels={"available_resource": "available_resource"}),
"batch_size": self.metric_family.Metric(labels={"batch_size": "batch_size"}),
"block_num": self.metric_family.Metric(labels={"block_num": "block_num"}),
"max_batch_size": self.metric_family.Metric(labels={"max_batch_size": "max_batch_size"}),
"max_block_num": self.metric_family.Metric(labels={"max_block_num": "max_block_num"}),
"available_resource": self.metric_family.Metric(labels={"available_resource": "available_resource"}),
}

# if USE_CUSTOM_HEALTH_CHECKER=1, use the custom health checker, which requires --allow-http=false
# otherwise use tritonserver's health checker, which requires --http-port=${HTTP_PORT}
use_custom_health_checker = int(os.getenv("USE_CUSTOM_HEALTH_CHECKER", 1))
if use_custom_health_checker:
http_port = os.getenv("HTTP_PORT")
if http_port is None:
raise Exception("HTTP_PORT must be set")
from server.health_checker import start_health_checker
multiprocessing.Process(target=start_health_checker, args=(int(http_port), )).start()
time.sleep(1)

self.cfg = get_global_config()
self.cfg.print(file="log/fastdeploy_init.info")
self.req_senders = dict()
self.cached_task_deque = deque()
self.is_stopping = False

# if USE_CUSTOM_HEALTH_CHECKER=1, use the custom health checker, which requires --allow-http=false
# otherwise use tritonserver's health checker, which requires --http-port=${HTTP_PORT}
if self.cfg.use_custom_health_checker:
from server.health_checker import start_health_checker
multiprocessing.Process(target=start_health_checker, args=(self.cfg.http_port, )).start()
time.sleep(1)

self.engine = engine.Engine(self.cfg)
model_server_logger.info("Create engine success")
model_server_logger.info("create engine success")

self.data_processor = DataProcessor()
model_server_logger.info("create data processor success")

insert_task_thread = threading.Thread(target=self._insert_task, args=())
insert_task_thread.daemon = True
insert_task_thread.start()
self.http_proc = None
self._launch_http_server()
model_server_logger.info("launch push server success")

schedule_task_thread = threading.Thread(target=self._schedule_task, args=())
schedule_task_thread.daemon = True
schedule_task_thread.start()
send_output_thread = threading.Thread(target=self._send_output, args=())
send_output_thread.daemon = True
send_output_thread.start()

self.http_process = None
self._launch_http_server()

model_server_logger.info("Init triton server success")

model_server_logger.info("init triton server success")

def _launch_http_server(self):
"""
@@ -132,19 +119,18 @@ def _launch_http_server(self):
http_py_path = os.path.join(current_dir_path, "http_server", http_py_file)
http_cmd = f"python3 {http_py_path} --port={self.cfg.push_mode_http_port} " \
f"--workers={self.cfg.push_mode_http_workers} >log/launch_http.log 2>&1"
model_server_logger.info(f"Launch HTTP server for push mode, command:{http_cmd}")
model_server_logger.info(f"launch HTTP server for push mode, command:{http_cmd}")

self.http_process = subprocess.Popen(http_cmd, shell=True, preexec_fn=os.setsid)
self.http_proc = subprocess.Popen(http_cmd, shell=True, preexec_fn=os.setsid)
time.sleep(3)
exit_code = self.http_process.poll()
exit_code = self.http_proc.poll()
if exit_code is None:
http_url = f"http://127.0.0.1:{self.cfg.push_mode_http_port}/v1/chat/completions"
model_server_logger.info(f"Launch HTTP server for push mode success, http_url:{http_url}")
model_server_logger.info(f"launch HTTP server for push mode success, http_url:{http_url}")
else:
error_msg = "\n Launch HTTP service for push mode failed in 3 seconds. " \
"Please check log/launch_http.log file \n"
model_server_logger.error(error_msg)
model_server_logger.info("init push server success")

def execute(self, requests):
"""
@@ -236,35 +222,32 @@ def execute(self, requests):

self._update_metrics()

def _insert_task(self):
def _schedule_task(self):
"""
Insert task to engine thread, monitor cached_task_deque.
if the engine has resource, insert task to engine
"""
try:
while True:
if self.engine.available_batch() == 0:
time.sleep(0.001)
continue
if len(self.cached_task_deque) == 0:
time.sleep(0.001)
continue
if not self.engine.is_queue_empty():
while True:
try:
if self.engine.available_batch() == 0 \
or len(self.cached_task_deque) == 0 \
or (not self.engine.is_queue_empty()):
time.sleep(0.001)
continue

i_bs = 0
for _ in range(self.cfg.max_prefill_batch):
if len(self.cached_task_deque) == 0:
break
if self.engine.available_batch() == 0:
if len(self.cached_task_deque) == 0 \
or self.engine.available_batch() == 0:
break

while i_bs < self.cfg.max_batch_size:
if self.engine.task_is_finished(i_bs):
break
i_bs += 1
if i_bs >= self.cfg.max_batch_size:
break

input_token_num = len(self.cached_task_deque[-1]["input_ids"])
if not self.engine.is_resource_sufficient(input_token_num):
break
@@ -277,10 +260,9 @@ def _insert_task(self):
_send_result({"error_msg": err_msg},
self.req_senders[task["req_id"]], 1)
del self.req_senders[task["req_id"]]
model_server_logger.info("finish insert_task_push_mode thread")
except Exception as e:
model_server_logger.error("insert_task_push_mode thread exit "
f"unexpectedly, {e}. {str(traceback.format_exc())}")
except Exception as e:
model_server_logger.error(f"schedule task has error: {e}. {str(traceback.format_exc())}")
model_server_logger.info("schedule task thread exit")

def _send_output(self):
"""
@@ -298,13 +280,13 @@ def _send_output(self):
if return_all_tokens and "topk_tokens" in result:
del result["topk_tokens"]
result = self.data_processor.process_response(result)
model_server_logger.debug(f"Send result to client under push mode: {result}")
model_server_logger.debug(f"send result to client under push mode: {result}")
_send_result([result], self.req_senders[req_id], is_end)
if is_end == 1:
del self.req_senders[req_id]
self._update_metrics()
except Exception as e:
model_server_logger.error("Unexcepted error happend: {}, {}".format(e, str(traceback.format_exc())))
model_server_logger.error("unexcepted error happend: {}, {}".format(e, str(traceback.format_exc())))


def _update_metrics(self):
Expand Down Expand Up @@ -336,8 +318,8 @@ def finalize(self):
time.sleep(5)

del self.engine
if self.http_process:
self.http_process.kill()
if self.http_proc:
self.http_proc.kill()
model_server_logger.info("Triton service is terminated!")


