diff --git "a/2024-ascend-innovation-contest/topic3-inference/second-phase/\350\265\267\351\243\216\344\272\206/README.md" "b/2024-ascend-innovation-contest/topic3-inference/second-phase/\350\265\267\351\243\216\344\272\206/README.md" new file mode 100644 index 00000000..687c6ada --- /dev/null +++ "b/2024-ascend-innovation-contest/topic3-inference/second-phase/\350\265\267\351\243\216\344\272\206/README.md" @@ -0,0 +1,638 @@ + + +# **作品介绍:** + +团队名:起风了 + +## 一、优化策略 + +### 1. 超参数调优 + +调整decode_batch_size: [128],修改llama_7b_kbk_pa_dyn.yaml文件,时间来到680s左右 + +``` +model_config: + model_name: 'llama_7b' + # max_generate_length: 600 ##快几秒 + max_generate_length: 4096 + end_token: 2 + seq_length: [4096] + vocab_size: 32000 + prefill_batch_size: [1] + # decode_batch_size: [1] + decode_batch_size: [128] + zactivate_len: [512, 1024, 2048, 4096] + model_type: 'dyn' + seq_type: 'static' + batch_waiting_time: 0.0 + decode_batch_waiting_time: 0.0 + batching_strategy: 'continuous' + current_index: False + page_attention: True + model_dtype: "DataType.FLOAT32" + pad_token_id: 0 + backend: 'kbk' # 'ge' + model_cfg_path: '/home/ma-user/work/mindformers/configs/llama2/predict_llama2_7b.yaml' + +serving_config: + agent_ports: [16002] + start_device_id: 0 + server_ip: '127.0.0.1' + server_port: 8835 + +pa_config: + num_blocks: 1024 + block_size: 16 + decode_seq_length: 4096 + +tokenizer: + type: LlamaTokenizer + vocab_file: '/home/ma-user/work/checkpoint_download/llama2/tokenizer.model' + +basic_inputs: + type: LlamaBasicInputs + +extra_inputs: + type: LlamaExtraInputs + +warmup_inputs: + type: LlamaWarmupInputs +``` + + + +### 2. 调整prefill和decoding任务调度 + +发现初始的版本decode阶段还是存在大量padding, 没有利用好显卡的并行计算优势,AIcore利用率大致为60%左右 + +为了增加decode的并行度,减少无效padding, 将prefill请求优先级调整到最高,decoding任务优先级最低,这样可以使得decode阶段的并行请求更多,尽可能打满整个batch,AIcore利用率提升至为80%左右,时间来到610s左右 + +具体修改如下: + +1)修改agent_multi_post_method.py文件 + +``` +def start_agent_socket_server(i, cfg: ServingConfig, startup_queue): + logging.basicConfig(level=logging.ERROR, + filename=f"./output/agent_{i}.log", + filemode='w', + format= + '%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s') + """启动agent进程, 由_agent_process进行调用, 创建agent进程""" + if IMPORT_LITE_FAILED: + logging.warning("import mindspore_lite failed, using kbk backend.") + work_agent = WorkAgent(i, cfg) # 创建一个WorkAgent实例,传入当前agent的索引和配置。 + + agent_ports = cfg.serving_config.agent_ports + agent_ip = cfg.serving_config.agent_ip + agent_address = (agent_ip, agent_ports[i]) + # 设置当前agent的地址(IP和端口)。 + print(agent_address) + + parent_process = psutil.Process(os.getppid()) + server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + server.bind(agent_address) + server.listen(50) # 开始监听连接,允许最多50个待处理连接 + + startup_queue.put(i) + + # 绑定method + # print("start agent socket server in rank{}".format(i), flush=True) + # logging.info("Agent socket server started on {}".format(agent_address)) + + task_queue = queue.PriorityQueue() + + def handle_client(conn): + while True: + if not parent_process.is_running(): + logging.warning( + f"detect parent pid={parent_process.pid} has exited, child begin to exit") + conn.close() + return + + try: + data = conn.recv(4096) + if not data: + break + data = data.decode() + # logging.debug(f"Data received: {data}") + + if data.startswith('#') or data.startswith('*') or data.startswith('e') or data.startswith('r'): + priority = 0 # 高优先级 + else: + priority = 1 # 低优先级 + + task_queue.put((priority, data, conn)) + # 
logging.info(f"Task added to queue with priority {priority}: {data}") + + except ConnectionResetError: + break + except RuntimeError as e: + logging.error(f"Runtime error: {e}") + conn.sendall("2".encode()) + break + + def process_tasks(): + while True: + priority, data, conn = task_queue.get() + # logging.info(f"Processing task with priority {priority}: {data}") + + if data.startswith('#'): + if work_agent.status & AgentStatus.unconnected == AgentStatus.unconnected: + data = data[1:] + work_agent.shm_names = data.split(",") + work_agent.status = AgentStatus.connected + # logging.info("Connected successfully") + conn.sendall("success".encode()) + else: + # logging.info("Connection failed") + conn.sendall("failed".encode()) + + elif data.startswith('*'): + # 全量推理 + work_agent.is_prefill = True + data = data[1:] + shape_strs = data.split(",") + input_shapes = [] + for shape_str in shape_strs: + shape = list(map(int, shape_str.split(" "))) + input_shapes.append(shape) + _, _ = work_agent.predict(shape_list=input_shapes) + if i == 0: + conn.sendall("1".encode()) + + elif data.startswith('a'): + # 增量推理 + decode_data = data.split('_') + # 增加PA的判断 + current_batch_dyn = int(decode_data[-4]) if cfg.model_config.page_attention else int( + decode_data[-2]) + batch_valid_flag = [] + batch_valid = decode_data[-3] if cfg.model_config.page_attention else decode_data[-1] + for ele in batch_valid.split(" "): + batch_valid_flag.append(int(ele)) + # 增加 block_tables和slot_mapping 的 shape + input_shapes = [] + if cfg.model_config.page_attention: + for shape_str in [decode_data[-2], decode_data[-1]]: + shape = list(map(int, shape_str.split(" "))) + input_shapes.append(shape) + work_agent.is_prefill = False + _, _ = work_agent.predict(current_batch=current_batch_dyn, batch_valid_flag=batch_valid_flag, + shape_list=input_shapes) + if i == 0: + conn.sendall("1".encode()) + elif data.startswith('e'): + if work_agent.status & AgentStatus.busy == AgentStatus.busy: + # logging.info("Agent is busy") + conn.sendall("busy".encode()) + # else: + work_agent.status = AgentStatus.unconnected + # logging.info("Agent is free") + conn.sendall("free".encode()) + + elif data.startswith('r'): + work_agent.status = AgentStatus.unconnected + # logging.info("Reset successful") + conn.sendall("success".encode()) + + threading.Thread(target=process_tasks, daemon=True).start() + + while True: + if not parent_process.is_running(): + logging.warning(f"detect parent pid={parent_process.pid} has exited, child begin to exit") + server.close() + return + conn, client_addr = server.accept() + # logging.info(f"Connection accepted from {client_addr}") + threading.Thread(target=handle_client, args=(conn,), daemon=True).start() + +``` + +2)仅依靠修改优先级会导致首token得到的是无效token,同时最后会少输出一个token, 不过总体token数是正常的,如图所示,human就是读取的无效token产生的 + +![image-20241031164326496](C:\Users\ly\AppData\Roaming\Typora\typora-user-images\image-20241031164326496.png) + +分析原因是**共享内存读取与写入的时间不同步**(答辩时再具体解析),为此还要修改如下代码: + +model_init_multimodel.py文件中, call函数需要增加两行代码(位于514,515行),即首token需要返回空list(这两行在验证精度时需要注释掉,精度是没问题的,主要是因为调整prefill和decode优先级后,精度验证时保存的文件中的顺序也不一样了,所以精度验证时要关闭优先级调整策略): + +``` +def call(self, shms: List, input_ids, current_index, + valid_length, init_reset, is_first_iteration, valid_batch_flag, extra_inputs=None, + current_batch_size=None, **kwargs): + + ............ + ............ 
+ for item in self.agent_stubs: + item.sendall(shapes_str) + recv_data = self.agent_stubs[0].recv(1, socket.MSG_WAITALL).decode() + # if not recv_data=="1": + # recv_data = self.agent_stubs[0].recv(1, socket.MSG_WAITALL).decode() + result = [] + if recv_data == "2": + for _ in decode_index_list: + # result.append(int(Baseconfig.end_token)) + result.append((int(-1),0)) + print("--------------------predict failed, abandon current prompt, please try again----------------") + logging.error("predict failed, abandon current prompt, please try again") + return result, 1 + + ####测试精度时需要注释下面两行 + if is_first_iteration: + return result, 1 + ############ + + for decode_index in decode_index_list: + tmp = np.ndarray((decode_index + 1,), dtype=np.int32, buffer=shms[5].buf) + tmp_logprob = np.ndarray((decode_index + 1,), dtype=np.float64, buffer=shms[6].buf) + result.append((int(tmp[decode_index:decode_index + 1]), float(tmp_logprob[decode_index:decode_index + 1]))) + + logging.info("--------------------callV3 result value is {} ".format(result)) + logging.info("model.call time is {} ".format((time.time() - time_start) * 1000)) + return result, 1 + +``` + +3)同时master.py需要进行如下调整,修改_postprocess()函数 + +``` + def _postprocess(self, + outputs: List[tuple], + entry_metadata_list: List[EntryMetaData], + index_list: List[int] = None, + skip_inference=False) -> List[ResponseOutput]: + + end_token = self.config.model_config.end_token # for debug + + ################ 首token是无效token,更新状态后返回,不要解码 + if len(outputs)==0 or outputs==[]: + self.scheduler.upate_entries_after_one_step_after_prefill(end_token, index_list) + return None + ################# + output_tokens = [] + output_logprob = [] + ### 整个batch,一个迭代的输出, 一个迭代输出一个token + for output_tup in outputs: + output_tokens.append(output_tup[0]) + output_logprob.append(output_tup[1]) + ......... + +``` + +4)schedule.py文件中,创建upate_entries_after_one_step_after_prefill函数,增加如下代码: + +``` +def upate_entries_after_one_step_after_prefill(self, eos_id: int, index_list: List[int] = None): + """update status after ever iteration""" + # optimize prefill multi-batch later + if index_list is not None: + # idx: index_list and outputs data index, index: batch list index. + for idx, index in enumerate(index_list): + self.running_request_list[index].is_prompt = False + # invalid prompt + if self.running_request_list[index].get_entry_data().get_status() == EntryStatus.PADDING_INVAILED: + continue + + if self.running_request_list[index].get_entry_data().get_status() == EntryStatus.INPUT_OUTOFRANGE: + update_token = INPUT_OUT_OF_TOKEN[0] + elif self.running_request_list[index].get_entry_data().get_status() == EntryStatus.EMPTY_PROMPT_TOKEN: + update_token = INPUT_EMPTY_TOKEN[0] + else: + continue + + self.running_request_list[index].get_entry_data().updata_output_tokens(update_token) + # valid prompt 区分PA处理 + if self.config.model_config.page_attention: + self._finished_pa_request(index, update_token, eos_id) + else: + self._finished_request(index, update_token, eos_id) +``` + +5)修改完以上代码,输出内容正常了 + +![image-20241031170331451](C:\Users\ly\AppData\Roaming\Typora\typora-user-images\image-20241031170331451.png) + + + +### 3. 
后处理和模型推理异步进行 + +初始版本predict(推理过程)和post_process(解码过程)是串行的,事实上模型推理完不需要等待后处理完成 + +因此将predict和post_process解耦,变成异步进行的方式,时间来到 606s(提升4s左右) + +修改master.py文件, + +``` +import asyncio +from mindspore_serving.master.request_resister_engine import RequestEngine +class AsyncMaster(Master): + def __init__( + self, + config: ServingConfig, + request_engine: RequestEngine + ): + super().__init__(config) + self.detokenizer_que = asyncio.Queue() ############ add + self.request_engine = request_engine ############ add + + ############### add + def send_post_process(self,output,entry_metadata_list,index_list,skip_inference=False): + self.detokenizer_que.put_nowait( + BatchTokenIdOut(output,entry_metadata_list,index_list,skip_inference) + ) + ############### + ............. + + async def _run_workers_async(self, current_batch_size, entry_metadata_list): + # e_t_e_time = time.time() + + prompt_token_empty_list = self._check_prompt_token_empty(entry_metadata_list, + self.config.model_config.pad_token_id) + # logging.debug("prompt token empty list index_list {}".format(prompt_token_empty_list)) + if len(prompt_token_empty_list) > 0: + ############# add + self.send_post_process([INPUT_EMPTY_TOKEN], entry_metadata_list=entry_metadata_list, + index_list=prompt_token_empty_list, + skip_inference=True) + # return self._postprocess([INPUT_EMPTY_TOKEN], entry_metadata_list=entry_metadata_list, + # index_list=prompt_token_empty_list, + # skip_inference=True) + + # check prefill out of range data + out_of_range_index_list = self._check_prompt_out_of_range_index_list(entry_metadata_list) + # logging.debug("out of range prompt index_list {}".format(out_of_range_index_list)) + if len(out_of_range_index_list) > 0: + ######################### add + self.send_post_process([INPUT_OUT_OF_TOKEN], entry_metadata_list=entry_metadata_list, + index_list=out_of_range_index_list, + skip_inference=True) + # return self._postprocess([INPUT_OUT_OF_TOKEN], entry_metadata_list=entry_metadata_list, + # index_list=out_of_range_index_list, + # skip_inference=True) + + # filter prompt data batch list + ## entry_metadata_list就是schedule返回的decode_batchsize个请求 + ### 这里根据 prefill_batchsize从中取出prefill_batchsize个请求 + input_entry_metadata_list, index_list = self._get_prompt_batch_list(entry_metadata_list) + # logging.debug("_get_prompt_batch_list prompt index_list {}, input_entry_metadata_list {}" + # .format(index_list, input_entry_metadata_list)) + # prefill predict + if len(input_entry_metadata_list) > 0: ### 只要running_request_list存在刚进来的prompt, 就在先要进行prefill + # logging.debug('prefill len of input entry_metadata_list is {}'.format(len(input_entry_metadata_list))) + # predict + output = self.worker.predict(current_batch_size, entry_metadata_list=input_entry_metadata_list) + else: # decode predict + input_entry_metadata_list = entry_metadata_list + index_list = None + # logging.debug('decode len of input entry_metadata_list is {}'.format(len(input_entry_metadata_list))) + output = self.worker.predict(current_batch_size, entry_metadata_list=input_entry_metadata_list) + + # post_process_time = time.time() + ################################# add + self.send_post_process(output, entry_metadata_list=entry_metadata_list, index_list=index_list) + # result = self._postprocess(output, entry_metadata_list=entry_metadata_list, index_list=index_list) + # logging.info('post_process_time time is {}'.format((time.time() - post_process_time) * 1000)) + # logging.info('e-to-e time is {}'.format((time.time() - e_t_e_time) * 1000)) + # return result + + ................ 
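+ # design note: _run_workers_async above no longer calls _postprocess directly; it only
+ # enqueues a BatchTokenIdOut onto self.detokenizer_que and returns, so the next predict
+ # step can start immediately. handle_detokenization_loop below is the consumer side of
+ # that queue: it runs _postprocess asynchronously and pushes the resulting
+ # ResponseOutputs back through the RequestEngine.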
+ ########## add + async def handle_detokenization_loop(self): + while True: + try: + recv_obj = await self.detokenizer_que.get() # BatchTokenIdOut + post_process_time = time.time() + request_outputs = self._postprocess(recv_obj.output, entry_metadata_list=recv_obj.entry_metadata_list, index_list=recv_obj.index_list,skip_inference=recv_obj.skip_inference) + logging.info('post_process_time time is {}'.format((time.time() - post_process_time) * 1000)) + + # Put the outputs into the corresponding streams. + if request_outputs is not None: + for request_output in request_outputs: + self.request_engine.process_request_output(request_output) + self.detokenizer_que.task_done() + except Exception as e: + print(e) +``` + +同时修改llm_server_post.py文件 + +``` +class LLMServer: + def __init__(self, config: ServingConfig): + self.request_engine = RequestEngine() + self.background_loop = None + self.master = AsyncMaster(config, self.request_engine) # liuyang + # self.master = AsyncMaster(config) + self.status = 0 + self.config = config + + @property + def is_running(self) -> bool: + return self.background_loop is not None + + async def run_loop(self): + while self.status: + await self.step() + await asyncio.sleep(0) + + def start_background_loop(self) -> None: + # todo + self.status = 1 + """Start the background loop.""" + if self.is_running: + raise RuntimeError("Background loop is already running.") + self.background_loop = asyncio.get_event_loop().create_task(self.run_loop()) + asyncio.get_event_loop().create_task(self.master.handle_detokenization_loop()) # add +``` + + + +### 4. 根据input长度排序推理 + +为了更好的模型并行化,performance_serving发送1500条请求之前,先根据input的长度对1500条请求进行排序,按照长度从短到长发送请求,这一步时间优化约2-3s,时间来到604s + +修改test_serving_performance.py + +``` +if __name__ == '__main__': + parser = argparse.ArgumentParser(description="test serving performance") + parser.add_argument("-X", "--qps", help='x req/s', required=True, type=float) + parser.add_argument("-P", "--port", help='port, default is 8000', required=True) + parser.add_argument("-O", "--out_dir", help='dir for saving results', required=True) + parser.add_argument("-T", "--test_time", help='test all time, default 1h', required=False, type=int, default=3600) + args = parser.parse_args() + with open("./alpaca_5010.json") as f: + alpaca_data = json.loads(f.read()) + INPUTS_DATA = [] + OUTPUTS_DATA = [] + count = 0 + input_length = [] + for data in alpaca_data: + count+=1 + if count>1500: + break + input_ = data["instruction"] + ":" + data["input"] if data["input"] else data["instruction"] + INPUTS_DATA.append(input_) + OUTPUTS_DATA.append(data["output"]) + input_length.append(len(input_)) + indexes = np.argsort(input_length) + INPUTS_DATA = [INPUTS_DATA[i] for i in indexes] + OUTPUTS_DATA = [OUTPUTS_DATA[i] for i in indexes] + test_main(args.port, INPUTS_DATA, OUTPUTS_DATA, args.qps, args.out_dir, args.test_time) +``` + + + +### 5. 算子替换 + +将agent_multi_post_method.py中数据预处理的np.concatenate改为np.pad, 这一步提升大概6s左右,时间来到598s + +``` +# decode 时,先将 shape 与 prefill 改为一致 +if input_ids.shape[1] == 1: + # input_ids = np.concatenate((input_ids, np.zeros((input_ids.shape[0], seq_length - 1))), axis=1) + input_ids = np.pad(input_ids,((0,0),(0,seq_length - 1)),'constant',constant_values = (0,0)) # add +``` + + + +## 二、超参数配置: + +### 1. llm_serving + +修改llama_7b_kbk_pa_dyn.yaml文件中的decode_batch_size + +(1)**测试推理时延时,修改如下参数:** + + decode_batch_size: [128] + +(2)**验证精度时,只能设置为1:** + + decode_batch_size: [1] + +### 2. 
performance_serving + +修改performance_serving中的 test.sh + +(1)测试推理时延时,设置为 -x 10 -T 150: + +``` +python test_serving_performance.py -X 10 -P 8835 -O "./" -T 150 +``` + +备注:测试耗时时需要**将test_serving_performance_sort.py文件中的内容复制到test_serving_performance.py** + +(2)验证精度时,设置为: + +``` +python test_serving_performance.py -X 0.1 -P 8835 -O "./" -T 5000 +``` + +备注:验证精度时需要**将test_serving_performance_raw.py文件中的内容复制到test_serving_performance.py** + + + +## 三、 **推理结果:** + +耗时:推理1500条数据,总耗时598 s + +备注: + +1)推理前必须先采取多条请求预热; + +例如运行3次以上: + +curl 127.0.0.1:8835/models/llama2/generate \ + +-X POST \ + +-d '{"inputs":" I love Beijing, because","parameters":{"max_new_tokens":16, "do_sample":"False", "return_full_text":"True"}, "stream":"True"}' \ + +-H 'Content-Type: application/json' + +2)每次推理会有1-2s的时间波动,经过多次测量,测量时间均在598-600s之间 + +![image-20241029210707329](C:\Users\ly\AppData\Roaming\Typora\typora-user-images\image-20241029210707329.png) + + + +## 四、 精度验证: + +500条数据精度比对均通过: + +备注: + +1) 验证精度时需要将llmserving中的 agent_multi_post_method_save_logits.py中的内容复制并替换agent_multi_post_method.py(记得保存)中的内容 + +2)注释掉model_init_multimodel.py文件中call函数增加的两行代码(514行,515行,原因上面也说了,验证精度时不能调整prefill和decode的优先级,会导致保存的顺序不一致,无法对比,但是调度是不会影响精度的,推理1500条的日志内容也完全正确) + +![image-20241031171324500](C:\Users\ly\AppData\Roaming\Typora\typora-user-images\image-20241031171324500.png) + +3)由于采用了perdict和postprocess异步方式推理,目前还不太稳定,测试精度时可能偶尔存在丢包的情况,这个时候请重新运行一下,以500条数据成功推理的结果为准 + +4) performerce_serving 测试精度时,需要将test_serving_performance_raw.py文件中的内容复制到test_serving_performance.py 文件中,(test_serving_performance_raw.py文件中没有对input长度进行排序, 因为排序会导致文件保存顺序不一致,因此验证精度时不能排序,但是推理顺序理论上并不会影响精度) + +![image-20241030110534409](C:\Users\ly\AppData\Roaming\Typora\typora-user-images\image-20241030110534409.png) + + + +打点分析: + +``` +[mindformers call-chain] +-> GenerationMixin.generate // 循环推理,测定 25ms/token + -> GenerationMixin.infer // 推理一个token单位 + -> GenerationMixin.forward (*) // 推测 ~22.5ms + -> LlamaForCausalLM.construct // if prefill + -> GenerationMixin._incremental_infer // if decode + -> LlamaForCausalLM.construct + -> GenerationMixin.postprocess // 推测 ~2.5ms + +[llm-serving call-chain] +-> LLMServer.generate_answer // 推理一个请求单位 + -> LLMServer.register_request + -> LLMServer.start_background_loop + -> LLMServer.run_loop + -> LLMServer.step + -> AsyncMaster.step_async // 决定 batching 策略 + -> AsyncMaster._run_workers_async // 测定 40ms/token -> 优化后 27ms/token + -> Worker.predict // 测定 39.55ms + - Worker._predict + -> DisModel.call // 测定 39.47ms + - shared_mem::write + - tcp::sendall + -> start_agent_socket_server // 测定 63.4ms(prefill) / 37.7ms(decode) + - tcp::recv + - WorkAgent.predictc + - shared_mem::read + - agent pre-process // 测定 12.0ms + - WorkAgent.predict_for_kbk // 测定 22.5ms + -> GenerationMixin.forward // 流程同上 mindformers (*) + - WorkAgent.do_post_sampling // 测定 3.04ms + - shared_mem::write + - tcp::sendall + - tcp::recv + - shared_mem::read + - AsyncMaster._postprocess // 测定 0.83ms +``` + + + +## 五、 运行环境说明: + +本作品直接使用比赛说明中配置的环境,不需要安装其他环境 + + + +## 六、 代码以及npy文件路径: + +**压缩包文件路径(可直接下载)** + +所有文件已经打包成一个文件 file_20241031.zip,最新提交obs路径如下: + +https://aics2024.obs.cn-southwest-2.myhuaweicloud.com/file_20241031.zip + +其中包括llm-serving、performance_serving代码、精度验证结果file_npy文件,mindformers(这个库没有做改动,可以直接用官方的) + diff --git "a/2024-ascend-innovation-contest/topic3-inference/second-phase/\350\265\267\351\243\216\344\272\206/agent_multi_post_method.py" "b/2024-ascend-innovation-contest/topic3-inference/second-phase/\350\265\267\351\243\216\344\272\206/agent_multi_post_method.py" new file 
mode 100644 index 00000000..b0859f62 --- /dev/null +++ "b/2024-ascend-innovation-contest/topic3-inference/second-phase/\350\265\267\351\243\216\344\272\206/agent_multi_post_method.py" @@ -0,0 +1,1024 @@ +"""agent""" +import copy +import signal +import socket +import time +import psutil +import os +import numpy as np +from concurrent.futures import ThreadPoolExecutor, as_completed, wait +from multiprocessing import Process, shared_memory +try: + import mindspore_lite as mslite + IMPORT_LITE_FAILED = False +except ImportError: + IMPORT_LITE_FAILED = True +from mindspore.common.tensor import Tensor +from mindspore_serving.serving_utils.err_code import AgentStatus +from mindspore_serving.models.post_sampling.topk import post_sampling, softmax_np +from mindspore_serving.sub_process.sub_process import listen_agents_after_startup +from mindspore_serving.config.config import ServingConfig +from mindspore_serving.models.build_inputs import build_inputs + +import mindspore +from mindformers import AutoConfig, AutoModel +from tools.post_sampling_model import temperature_TopK, ArgmaxPost +import logging +# from mindformers.tools.logger import logger +pool = ThreadPoolExecutor(max_workers=20, thread_name_prefix='test_thread') +import queue +import threading + +def load_model_for_kbk(cfg: ServingConfig, rank_id: int, device_id: int): + # 加载模型 + model_config = cfg.model_config + + # 0: mindspore.GRAPH_MODE, 1: mindspore.PYNATIVE_MODE + mindspore.set_context(mode=mindspore.GRAPH_MODE, device_id=device_id, device_target="Ascend") + # mindspore.set_context(inter_op_parallel_num=8) ### 无提升 + mindspore.set_context(enable_graph_kernel=True) + model = AutoModel.from_config(model_config.model_cfg_path) + model.set_train(False) + + return model + + +def load_model_for_ge(cfg: ServingConfig, rank_id: int, device_id: int): + # 加载模型 + model_path = cfg.model_path + model_config = cfg.model_config + context = mslite.Context() + + warmup_func = build_inputs(cfg.warmup_inputs, module_type='warmup_inputs') + context.ascend.device_id = device_id + context.ascend.rank_id = rank_id + context.ascend.provider = "ge" + context.target = ["Ascend"] + # 单模型 + if len(model_path.decode_model) == 0: + model0 = mslite.Model() + model0.build_from_file(model_path.prefill_model[0], mslite.ModelType.MINDIR, context, model_path.prefill_ini[0]) + model1 = None + return model0, model1 + # rank_table_file放在config_file中 + all_models = [mslite.Model()] # prefill + # decode + for _ in model_path.decode_ini: + all_models.append(mslite.Model()) + model_group = mslite.ModelGroup(mslite.ModelGroupFlag.SHARE_WEIGHT) + model_group.add_model(all_models) + all_models[0].build_from_file(model_path.prefill_model[rank_id], mslite.ModelType.MINDIR, context, + model_path.prefill_ini[0]) + # warm up prefill model + prefill_batch_size = model_config.prefill_batch_size[0] if len(model_config.prefill_batch_size) > 0 else 1 + # 加入PA判断 + if model_config.page_attention: + prefill_seq_length = model_config.seq_length[-1] + inc_seq_len = cfg.pa_config.decode_seq_length + prefill_inputs_list = warmup_func.get_warmup_inputs(seq_length=prefill_seq_length, + batch_size=prefill_batch_size, + full_model=True, + use_current_index=model_config.current_index, + # 这里需要考虑加入use_current_index测试 + page_attention=model_config.page_attention, + zactivate_len=model_config.zactivate_len, + decode_seq_length=inc_seq_len, + block_size=cfg.pa_config.block_size) + + else: + prefill_seq_length = model_config.seq_length[0] if len(model_config.seq_length) > 0 else 2048 + prefill_inputs_list 
= warmup_func.get_warmup_inputs(seq_length=prefill_seq_length, + batch_size=prefill_batch_size, + full_model=True, + use_current_index=model_config.current_index, + page_attention=model_config.page_attention, + zactivate_len=model_config.zactivate_len, + model_type=model_config.model_type) + prefill_lite_inputs = [mslite.Tensor(np.ascontiguousarray(item)) for item in prefill_inputs_list] + if rank_id == 0: + for item in prefill_lite_inputs: + print("prefill item ") + + all_models[0].predict(prefill_lite_inputs) + + if len(model_path.decode_ini) != len(model_config.zactivate_len): + # padding invalid act_len list + model_config.zactivate_len = [2 for _ in range(len(model_path.decode_ini))] + for i in range(len(model_path.decode_ini)): + act_len = model_config.zactivate_len[i] + + + if len(model_config.decode_batch_size) == 0: + raise ValueError("length of model_config.decode_batch_size should at least be 1, but got 0") + warm_batch_size = model_config.decode_batch_size[0] if len(model_config.decode_batch_size) > 0 else 1 + warm_seq_length = 1 + if model_config.page_attention: + + inc_seq_len = cfg.pa_config.decode_seq_length + decode_inputs_list = warmup_func.get_warmup_inputs(seq_length=warm_seq_length, + batch_size=warm_batch_size, + full_model=False, + use_current_index=model_config.current_index, + page_attention=model_config.page_attention, + zactivate_len=model_config.zactivate_len, + decode_seq_length=inc_seq_len, + block_size=cfg.pa_config.block_size) # zactivate_len这里是否要加上zactivate_len + else: + + decode_inputs_list = warmup_func.get_warmup_inputs(seq_length=warm_seq_length, + batch_size=warm_batch_size, + full_model=False, + use_current_index=model_config.current_index, + valid_length=[act_len - 1], + page_attention=model_config.page_attention, + zactivate_len=model_config.zactivate_len, + model_type=model_config.model_type) + + decode_lite_inputs = [mslite.Tensor(np.ascontiguousarray(item)) for item in decode_inputs_list] + if rank_id == 0: + for item in decode_lite_inputs: + + print(1) + + all_models[i + 1].build_from_file(model_path.decode_model[rank_id], mslite.ModelType.MINDIR, context, + model_path.decode_ini[i]) + + if rank_id == 0: + model_in = all_models[i + 1].get_inputs() + for m_in in model_in: + print(1) + all_models[i + 1].predict(decode_lite_inputs) + + return all_models[0], all_models[1:] + + +def load_post_model(model_path, config_file, rank_id, device_id): + context = mslite.Context() + + context.ascend.device_id = device_id + context.ascend.rank_id = rank_id + context.ascend.provider = "ge" + context.target = ["Ascend"] + model = mslite.Model() + if not os.path.exists(model_path): + raise ValueError(f"load post-sampling model_path {model_path} not exists.") + + if not os.path.exists(config_file): + raise ValueError(f"load post-sampling post_model_ini {config_file} not exists.") + + model.build_from_file(model_path, mslite.ModelType.MINDIR, context, config_file) + return model + + +class DecodeParams: + def __init__(self, + do_sample: bool = True, + top_k: int = 1, + top_p: float = 1.0, + temperature: float = 1.0, + repetition_penalty: float = 1.0, + decode_index: int = -1, + current_index: int = 0, + valid_length: int = 0, + init_reset: bool = False, + ge_token: int = 0 + ): + self.do_sample = do_sample + self.top_k = top_k + self.top_p = top_p + self.temperature = temperature + self.repetition_penalty = repetition_penalty + self.decode_index = decode_index + self.current_index = current_index + self.valid_length = valid_length + self.init_reset = init_reset + 
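+ # DecodeParams carries the per-request sampling settings (do_sample / top_k / top_p /
+ # temperature / repetition_penalty) plus per-slot bookkeeping carried across decode
+ # iterations (decode_index, current_index, valid_length, init_reset)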
self.ge_token = ge_token + + +""" +work_agent.proto实现, 供worker调用 +""" + + +class WorkAgent: + def __init__(self, rank_id, cfg: ServingConfig): + self.rank_id = rank_id + model_path = cfg.model_path + serving_config = cfg.serving_config + device_id = rank_id + serving_config.start_device_id + + if cfg.model_config.backend == "ge": + self.prefill, self.decode = load_model_for_ge(cfg, rank_id, device_id) + self.argmax_model = load_post_model(model_path.argmax_model, + model_path.post_model_ini, + rank_id, + device_id) + self.topk_model = load_post_model(model_path.topk_model, + model_path.post_model_ini, + rank_id, + device_id) + else: + self.mindspore_model = load_model_for_kbk(cfg, rank_id, device_id) + self.argmax_model = ArgmaxPost() + self.topk_model = temperature_TopK() + + self.shm_names = [] + self.init_reset = None + self.current_index = None + self.valid_length = None + self.tensor_shape = None + self.pre_input_ids = None + self.is_prefill = True + self.target = None + self.post_mode_list = None + self.input_length = None + self.targets = [] + self.kbk_targets = None + self.decode_params_map = {} + self.status = AgentStatus.unconnected + self.current_batch_size = None + self.config = cfg + self.basic_input_func = build_inputs(cfg.basic_inputs, module_type="basic_inputs") + self.extra_input_func = build_inputs(cfg.extra_inputs, module_type="extra_inputs") + + def _post_sampling_argmax_npu(self, outputs_np) -> np.ndarray: + """ + Args: + outputs_np: np.ndarray or ms.Tensor, (bs, 1, vocab_size) + """ + post_inputs = self.argmax_model.get_inputs() + if isinstance(outputs_np, np.ndarray): + post_inputs[0].shape = outputs_np.shape + post_inputs[0].set_data_from_numpy(outputs_np) + else: + post_inputs[0].shape = outputs_np.shape + post_inputs[0] = outputs_np + post_sampling_out = self.argmax_model.predict(post_inputs) + return post_sampling_out[0].get_data_to_numpy().astype(np.int32) + + @staticmethod + def _post_sampling_argmax_host(outputs) -> np.ndarray: + if isinstance(outputs, mslite.Tensor): + outputs = outputs.get_data_to_numpy() + outputs.reshape((outputs.shape[0], outputs.shape[-1])) + argmax_out = np.argmax(outputs, axis=-1) + return np.array([argmax_out]).astype(np.int32)[0] + + @staticmethod + def do_sample(decode_params, p_args, outs, targets, index, candidate_token_num: int = 1): + """ + Args: + decode_params: decode parameters for current client request + p_args: numpy.ndarray, index + outs: numpy.ndarray, probs + targets: batch targets after sampling + index: the batch index of current request + candidate_token_num: default top_p_num + """ + topp = decode_params.top_p + topk = decode_params.top_k + if topk > 100: + topk = 100 + outs = outs[:topk] + if topp < 1.0: + outs_ = np.cumsum(softmax_np(outs), axis=-1) + top_p_num = sum(outs_ < topp) + if top_p_num == 0: + top_p_num = candidate_token_num + outs = outs[:top_p_num] + p_args = p_args[:top_p_num] + + p = softmax_np(outs) + target_index = np.random.choice(len(p), p=p) + targets[index] = p_args[target_index] + + def _post_sampling_topk_npu(self, outputs_np, decode_index, prefill=True) -> np.ndarray: + """ + Args: + outputs_np: np.ndarray or ms.Tensor, (bs, 1, vocab_size) + """ + decode_params = self.decode_params_map[int(decode_index[0])] + self.targets.clear() + tempreture_ = np.array([decode_params.temperature], dtype=np.float32) + + post_inputs = self.topk_model.get_inputs() + + if isinstance(outputs_np, np.ndarray): + post_inputs[0].shape = outputs_np.shape + post_inputs[0].set_data_from_numpy(outputs_np) + + else: + 
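+ # outputs_np is already an mslite.Tensor here, so it is bound directly instead of
+ # being copied in via set_data_from_numpy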
post_inputs[0].shape = outputs_np.shape + post_inputs[0] = outputs_np + + post_inputs[1].shape = tempreture_.shape + post_inputs[1].set_data_from_numpy(tempreture_) + + post_sampling_out = self.topk_model.predict(post_inputs) + outs = post_sampling_out[0].get_data_to_numpy().astype(np.float16) + p_args = post_sampling_out[1].get_data_to_numpy() + thread_num = self.current_batch_size + targets = np.zeros((thread_num,), np.int32) + all_task = [pool.submit(self.do_sample, self.decode_params_map[decode_index[i]], p_args[i], outs[i], targets, i) + for i in range(thread_num)] + wait(all_task) + return targets + + def _post_sampling_topk_kbk(self, outputs_np, decode_index) -> np.ndarray: + """ + Args: + outputs_np: np.ndarray or ms.Tensor, (bs, 1, vocab_size) + """ + decode_params = self.decode_params_map[int(decode_index[0])] + self.targets.clear() + tempreture_ = np.array([decode_params.temperature], dtype=np.float32) + tempreture_t = Tensor(tempreture_, mindspore.float32) + post_sampling_out = self.topk_model(outputs_np, tempreture_t) + outs = post_sampling_out[0].asnumpy().astype(np.float16) + p_args = post_sampling_out[1].asnumpy() + thread_num = self.current_batch_size + targets = np.zeros((thread_num,), np.int32) + all_task = [pool.submit(self.do_sample, self.decode_params_map[decode_index[i]], p_args[i], outs[i], targets, i) + for i in range(thread_num)] + wait(all_task) + return targets + + def _get_seq_length(self, input_ids, is_prefill): + max_length = 0 + if not is_prefill: + if self.config.model_config.page_attention: + return self.config.pa_config.decode_seq_length + for item in input_ids: + if isinstance(item, list): + max_length = max(max_length, len(item)) + else: + max_length = max(max_length, 1) + if self.config.model_config.seq_type == 'dyn': + seq_length = max_length + elif len(self.config.model_config.seq_length) > 1: + seq_length = self._get_seq_length_dynmic_dinning(self.config.model_config.seq_length, max_length) + else: + if len(self.config.model_config.seq_length) == 0 and self.config.model_config.seq_type != 'dyn': + seq_length = 2048 + else: + seq_length = self.config.model_config.seq_length[0] + return seq_length + + @staticmethod + def _get_seq_length_dynmic_dinning(seq_list, seq_length): + for data in seq_list: + if seq_length < data: + return data + return seq_list[-1] + + @staticmethod + def _padding(origin_inputs, seq_length, default_padding_values): + pad_ids = list() + for item in origin_inputs: + pad_length = seq_length - len(item) + if pad_length < 0: + print(1) + pad_item = np.pad(item, (0, pad_length), 'constant', constant_values=default_padding_values) + pad_ids.append(pad_item) + return np.array(pad_ids) + + def _post_sampling_topk_host(self, outputs, decode_index, prefill): + """ + topk top-p in cpu, time-cost function, + """ + if isinstance(outputs, mslite.Tensor): + outputs = outputs.get_data_to_numpy() + outputs = np.reshape(outputs, (outputs.shape[0], outputs.shape[-1])) + thread_num = self.current_batch_size + targets = np.zeros((thread_num,), np.int32) + all_task = [pool.submit(post_sampling, np.array(item), self.decode_params_map[decode_index[i]], targets, i) + for i, item in enumerate(outputs)] + wait(all_task) + return targets + + def multi_thread_post_sampling(self, outputs_np, outputs_shm, decode_index_np, bs=1): + + self.targets.clear() + all_task = [pool.submit(self.do_post_sampling, outputs_np[i], outputs_shm, + decode_index_np[i], i) for i in range(bs)] + + for x in as_completed(all_task): + res = x.result() + self.targets.append(res) + 
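+ # note: as_completed() yields futures in completion order, so self.targets is filled in
+ # completion order rather than batch-submission order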
return self.targets + + def get_consistent_batch(self, decode_index): + not_do_sample_list = [] + do_sample_list = [] + for index in decode_index: + do_sample_index = self.decode_params_map[index].do_sample + if do_sample_index is True: + do_sample_list.append(index) + else: + not_do_sample_list.append(index) + if len(do_sample_list) >= 1 and len(not_do_sample_list) >= 1: + for item in not_do_sample_list: + self.decode_params_map[item].top_k = 1 + do_sample = True + else: + do_sample = self.decode_params_map[decode_index[0]].do_sample + return do_sample + + def do_post_sampling(self, outputs_np, outputs_shm, output_logprob_shm, decode_index, prefill=True): + # 确保 outputs_np 是 numpy 数组 + # logger.info("333333333333333333333333333333333333") + # start_time = time.time() + # if isinstance(outputs_np, Tensor): + # logger.info("ttttttttttttttttttttttttttttttt") + # outputs_np = outputs_np.asnumpy() + ################ + # logger.info("do_post_sampling outputs_np shape is {}, value is{}".format(outputs_np.shape, outputs_np)) + do_sample = self.get_consistent_batch(decode_index) + if self.config.model_config.backend == "ge": + if self.config.serving_config.enable_host_post_sampling: + if not do_sample: + target = self._post_sampling_argmax_host(outputs_np) + target.reshape((self.current_batch_size,)) + target = np.squeeze(target, axis=1) + else: + target = self._post_sampling_topk_host(outputs_np, decode_index, prefill) + else: + if not do_sample: + target = self._post_sampling_argmax_npu(outputs_np) + else: + target = self._post_sampling_topk_npu(outputs_np, decode_index, prefill) + output_info = outputs_np.get_data_to_numpy() + else: + if not do_sample: + self.targets.clear() + target = self.argmax_model(outputs_np) + # target = self.argmax_model.construct(outputs_np) + else: + # print("nnnnnnnnnnnnnnnnnnnnnnnnnnnnnn") + target = self._post_sampling_topk_kbk(outputs_np, decode_index) + ### raw + if isinstance(target, Tensor): + target = target.asnumpy() + output_info = outputs_np.asnumpy() + ### add + # if isinstance(target, np.ndarray): + # target = target + # output_info = outputs_np + # print("argmax time:") + # print(time.time()-start_time) + # print("target.dtype:") + # print(target.dtype) + ### 打印一下target和outputs_np的数据类型,对比一些原来的方案,数据类型是否有出入 + if self.rank_id == 0: + if prefill: + for index in decode_index: + # tmp = np.ndarray((index + self.current_batch_size,), dtype=np.int32, buffer=outputs_shm.buf) + tmp = np.ndarray((index + self.current_batch_size,), dtype=target.dtype, buffer=outputs_shm.buf) + tmp[index: index + self.current_batch_size] = target[:] + + logprob_list = [] + for idx, tag in enumerate(target): + logprob_list.append(output_info[idx][int(tag)]) + tmp_logprob = np.ndarray((index + self.current_batch_size,), dtype=np.float64, + buffer=output_logprob_shm.buf) + tmp_logprob[index: index + self.current_batch_size] = logprob_list[:] + self.targets[index: index + self.current_batch_size] = target[:] + else: + # tmp = np.ndarray((self.current_batch_size,), dtype=np.int32, buffer=outputs_shm.buf) + tmp = np.ndarray((self.current_batch_size,), dtype=target.dtype, buffer=outputs_shm.buf) + tmp[:] = target[:] + + logprob_list = [] + for idx, tag in enumerate(target): + if len(output_info.shape) == 2: + logprob_list.append(output_info[idx][int(tag)]) + else: + logprob_list.append(output_info[idx][0][int(tag)]) + tmp_logprob = np.ndarray((self.current_batch_size,), dtype=np.float64, buffer=output_logprob_shm.buf) + tmp_logprob[:] = logprob_list[:] + self.targets[:] = target[:] + + def 
model_choice_seq(self, act_len, decode_model_map): + if len(decode_model_map) == 1: + return decode_model_map[0] + act_len_list = self.config.model_config.zactivate_len + if len(act_len_list) != len(decode_model_map): + print(1) + model_index = act_len_list.index(act_len) + + return decode_model_map[model_index] + + def predict(self, shape_list=None, current_batch=None, batch_valid_flag=None): + self.status = AgentStatus.busy + tmp_shms = [] + start_time = time.time() + existing_shm0 = shared_memory.SharedMemory(name=self.shm_names[0]) + tmp_shms.append(existing_shm0) + + output_shm = shared_memory.SharedMemory(name=self.shm_names[5]) + tmp_shms.append(output_shm) + + output_logprob_shm = shared_memory.SharedMemory(name=self.shm_names[6]) + tmp_shms.append(output_logprob_shm) + + gen_params_id = 4 + gen_params_shm = shared_memory.SharedMemory(name=self.shm_names[gen_params_id]) + tmp_shms.append(gen_params_shm) + if self.is_prefill: + first_group = np.ndarray((shape_list[0]), dtype=np.int32, buffer=existing_shm0.buf) + current_index_ = first_group[:, shape_list[0][1] - 3: shape_list[0][1] - 2] + current_index = np.squeeze(current_index_, axis=-1) + + valid_length_ = first_group[:, shape_list[0][1] - 1: shape_list[0][1]] + if self.config.model_config.current_index or self.config.model_config.backend == "kbk": + valid_length = np.squeeze(valid_length_, axis=-1).astype(np.int64) + else: + valid_length = np.squeeze(valid_length_, axis=-1).astype(np.int32) + + input_ids = first_group[:, :shape_list[0][1] - 3] + gen_params_id = 1 # 改为1,正向取值,原始shape_list只有两个值,现在多加了两个 + shape_params = shape_list[gen_params_id] + gen_params = np.ndarray(shape_params, dtype=np.float16, buffer=gen_params_shm.buf) + + do_sample_list = gen_params[:, 0].astype(np.bool_) + top_p_list = gen_params[:, 1] + top_k_list = gen_params[:, 2].astype(np.int32) + temperature_list = gen_params[:, 3] + repetition_penalty_list = gen_params[:, 4] + decode_index_list = gen_params[:, 5].astype(np.int32) + # 添加baichuanPA block_tables_shape slot_mapping_shape + if self.config.model_config.page_attention: + block_tables_shape = shape_list[2] # 这里的shapeindex会不会变?? + slot_mapping_shape = shape_list[3] + + extra_input = [] + for i in range(1, len(shape_list) - 1): + existing_shm = shared_memory.SharedMemory(name=self.shm_names[i]) + tmp_shms.append(existing_shm) + # To Do np.int64 ? 
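+ # reinterpret each extra shared-memory segment with the shape reported by the worker;
+ # the dtype is assumed to be int64 here (see the TODO above)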
+ extra_input.append(np.ndarray((shape_list[i]), dtype=np.int64, buffer=existing_shm.buf)) + + if self.config.model_config.backend == "ge": + # pa or static model type don't need 'act_len' parameter + if self.config.model_config.page_attention or ( + self.config.model_config.model_name == 'wizard_coder' and self.config.model_config.model_type == "static"): + extra_input = [] + else: + extra_input = self.extra_input_func.get_extra_inputs(input_ids, current_index, None, True, + valid_length, + zactivate_len=self.config.model_config.zactivate_len) + + self.current_batch_size = len(input_ids) + init_reset = [] + decode_index = [] + for i in range(self.current_batch_size): + decode_params = DecodeParams( + do_sample=bool(do_sample_list[i]), + top_p=top_p_list[i], + top_k=int(top_k_list[i]), + temperature=temperature_list[i], + repetition_penalty=repetition_penalty_list[i], + decode_index=int(decode_index_list[i]), + current_index=int(current_index[i]), + valid_length=int(valid_length[i]), + init_reset=False + ) + self.decode_params_map[decode_params.decode_index] = decode_params + init_reset.append(decode_params.init_reset) + decode_index.append(decode_params.decode_index) + init_reset = np.array(init_reset, dtype=np.bool_) + decode_index_np = np.array(decode_index, dtype=np.int64) + else: + # keep decode map size equal to current batch size + # extend + current_index = [] + valid_length = [] + init_reset = [] + decode_index = [] + self.current_batch_size = current_batch + current_batch_size = self.current_batch_size + if self.current_batch_size != len(batch_valid_flag): + batch_valid_flag.clear() + batch_valid_flag = [1 for _ in range(self.current_batch_size)] + before_batch_size = len(self.decode_params_map.keys()) + if before_batch_size < current_batch_size: + input_ids = np.ndarray((before_batch_size,), dtype=np.int32, buffer=output_shm.buf) + pad_input_id = self.config.model_config.end_token + add_length = self.current_batch_size - before_batch_size + addition_input_ids = np.array(add_length * [pad_input_id], dtype=np.int32) + input_ids = np.append(input_ids, addition_input_ids) + target_batch = self.current_batch_size + pad_key = list(self.decode_params_map.keys())[-1] + # padding_obj = self.decode_params_map[pad_key] + for j in range(target_batch): + if j not in self.decode_params_map: + padding_obj = copy.deepcopy(self.decode_params_map[pad_key]) + padding_obj.current_index = 0 + padding_obj.valid_length = 1 + padding_obj.decode_index = j + self.decode_params_map[j] = padding_obj + else: + # pop + while len(self.decode_params_map.keys()) > current_batch_size: + self.decode_params_map.popitem() + input_ids = np.ndarray((current_batch_size,), dtype=np.int32, buffer=output_shm.buf) + + self.decode_params_map = dict(sorted(self.decode_params_map.items(), key=lambda x: x[0])) + for key in self.decode_params_map.keys(): + decode_params = self.decode_params_map[key] + decode_params.current_index = decode_params.current_index + 1 + decode_params.valid_length = decode_params.valid_length + 1 + decode_params.init_reset = True # 修改原始代码bug + if batch_valid_flag[key] == 1: + current_index.append(decode_params.current_index) + valid_length.append(decode_params.valid_length) + else: + current_index.append(0) + valid_length.append(1) + init_reset.append(decode_params.init_reset) + decode_index.append(decode_params.decode_index) + + if self.config.model_config.backend == "ge": + # pa or static model type don't need 'act_len' parameter + if self.config.model_config.page_attention or ( + 
self.config.model_config.model_name == 'wizard_coder' and self.config.model_config.model_type == "static"): + extra_input = [] + else: + extra_input = self.extra_input_func.get_extra_inputs(input_ids, current_index, None, False, + valid_length, + zactivate_len=self.config.model_config.zactivate_len) + + current_index = np.array(current_index, dtype=np.int32) + if self.config.model_config.current_index or self.config.model_config.backend == "kbk": + valid_length = np.array(valid_length, dtype=np.int64) + else: + valid_length = np.array(valid_length, dtype=np.int32) + init_reset = np.array(init_reset, dtype=np.bool_) + decode_index_np = np.array(decode_index, dtype=np.int64) + input_ids = input_ids.reshape((-1, 1)) + # 加入PA特性 + if self.config.model_config.page_attention: + block_tables_shape = shape_list[0] + slot_mapping_shape = shape_list[1] + + block_tables_np = None + slot_mapping_np = None + if self.config.model_config.page_attention: + block_tables_shm = shared_memory.SharedMemory(name=self.shm_names[7]) # 这里的共享内存index要改 + slot_mapping_shm = shared_memory.SharedMemory(name=self.shm_names[8]) + block_tables_np = np.ndarray((block_tables_shape), dtype=np.int32, buffer=block_tables_shm.buf) + slot_mapping_np = np.ndarray((slot_mapping_shape), dtype=np.int32, buffer=slot_mapping_shm.buf) + + + if self.config.model_config.backend == "ge": + if self.config.model_config.page_attention: + if self.is_prefill: + tmp_in = [input_ids, valid_length, slot_mapping_np] + else: + tmp_in = [input_ids, valid_length, block_tables_np, slot_mapping_np] + else: + tmp_in = self.basic_input_func.get_inputs(input_ids, current_index, init_reset, valid_length, + self.config.model_config.current_index, decode_index_np, + self.config.model_config.model_type) + if len(extra_input) > 0: + tmp_in.extend(extra_input) + + for tmp in tmp_in: + print(1) + + outputs = self.predict_for_ge(extra_input, start_time, tmp_in) + else: + seq_length = self._get_seq_length(input_ids, False) + # init kbk_targets, shape(current_batch, seq_length), default value: self.config.model_config.pad_token_id + if self.kbk_targets is None: + decode_batch_size = self.config.model_config.decode_batch_size[0] + + self.kbk_targets = np.full((decode_batch_size, seq_length), self.config.model_config.pad_token_id) + + + + # decode 时,先将 shape 与 prefill 改为一致 + if input_ids.shape[1] == 1: + # input_ids = np.concatenate((input_ids, np.zeros((input_ids.shape[0], seq_length - 1))), axis=1) + input_ids = np.pad(input_ids,((0,0),(0,seq_length - 1)),'constant',constant_values = (0,0)) # liuyang + + # 遍历decode_index + for idx, index in enumerate(decode_index): + index = int(decode_index[0]) + + if self.is_prefill: + self.kbk_targets[index] = input_ids[idx] + else: + current_index_value = int(current_index[idx]) + self.kbk_targets[index][current_index_value:current_index_value + 1] = input_ids[idx][:1] + input_ids[idx] = self.kbk_targets[index] + + outputs = self.predict_for_kbk(current_index, input_ids, valid_length, block_tables_np, slot_mapping_np) + + # post_time = time.time() + if self.rank_id == 0: + # multi_thread_time = time.time() + if self.is_prefill: + self.do_post_sampling(outputs, output_shm, output_logprob_shm, decode_index_np, prefill=True) + else: + self.do_post_sampling(outputs, output_shm, output_logprob_shm, decode_index_np, prefill=False) + + self.status &= ~AgentStatus.busy + return self.targets, tmp_shms + + def predict_for_ge(self, extra_input, start_time, tmp_in): + # 调用ms lite进行推理 + if len(extra_input) > 0: + model = self.prefill if 
self.is_prefill else self.model_choice_seq(len(extra_input[0]), self.decode) + else: + model = self.prefill if self.is_prefill else self.decode[0] + lite_inputs = [mslite.Tensor(np.ascontiguousarray(item)) for item in tmp_in] + # predict_time = time.time() + if self.config.model_config.model_name == 'wizard_coder' and self.config.model_config.model_type == "static": + if self.is_prefill: + init_reset_ms_tensor = mslite.Tensor(np.array([False], np.bool_)) + else: + init_reset_ms_tensor = mslite.Tensor(np.array([True], np.bool_)) + outputs_list = model.predict((lite_inputs[0], lite_inputs[1], init_reset_ms_tensor, lite_inputs[2])) + else: + outputs_list = model.predict(lite_inputs) + + outputs = outputs_list[0] + return outputs + + def predict_for_kbk(self, current_index, input_ids, valid_length, block_tables, slot_mapping): + # 封装调用模型参数 + model_kwargs = {"current_index": current_index} + model_inputs = self.mindspore_model.prepare_inputs_for_generation(input_ids, **model_kwargs) + # 调用mindformers进行推理 + # predict_time = time.time() + if self.mindspore_model.config.use_past: + if self.is_prefill: + self.mindspore_model.is_first_iteration = True + res, current_index = self.mindspore_model.forward(input_ids=input_ids, + valid_length_each_example=valid_length, + generation_config=self.mindspore_model.config, + block_tables=block_tables, + slot_mapping=slot_mapping, + prefill=self.is_prefill, + **model_kwargs) + else: + res = self.mindspore_model(**model_inputs) + + outputs = res[0] if isinstance(res, tuple) else res + return outputs + + +# def start_agent_socket_server(i, cfg: ServingConfig, startup_queue): +# """启动agent进程, 由_agent_process进行调用, 创建agent进程""" +# if IMPORT_LITE_FAILED: +# print(1) +# work_agent = WorkAgent(i, cfg) + +# agent_ports = cfg.serving_config.agent_ports +# agent_ip = cfg.serving_config.agent_ip +# agent_address = (agent_ip, agent_ports[i]) + + +# parent_process = psutil.Process(os.getppid()) +# server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) +# server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) +# server.bind(agent_address) +# server.listen(50) + +# startup_queue.put(i) + +# # 绑定method + + +# while True: +# if not parent_process.is_running(): +# print(1) +# server.close() +# return + +# conn, client_addr = server.accept() +# # todo workagent = WorkAgent(config) +# while True: +# if not parent_process.is_running(): +# print(1) +# server.close() +# return +# try: +# data = conn.recv(4096) +# if not data: +# break +# data = data.decode() +# # worker 和 agent建联 +# if data.startswith('#'): +# if work_agent.status & AgentStatus.unconnected == AgentStatus.unconnected: +# data = data[1:] +# work_agent.shm_names = data.split(",") +# work_agent.status = AgentStatus.connected + +# conn.sendall("succes".encode()) +# else: +# conn.sendall("failed".encode()) +# elif data.startswith('*'): +# # 全量推理 +# work_agent.is_prefill = True +# data = data[1:] +# shape_strs = data.split(",") +# input_shapes = [] +# for shape_str in shape_strs: +# shape = list(map(int, shape_str.split(" "))) +# input_shapes.append(shape) +# _, _ = work_agent.predict(shape_list=input_shapes) +# if i == 0: +# conn.sendall("1".encode()) +# elif data.startswith('a'): +# # 增量推理 +# decode_data = data.split('_') +# # 增加PA的判断 +# current_batch_dyn = int(decode_data[-4]) if cfg.model_config.page_attention else int( +# decode_data[-2]) +# batch_valid_flag = [] +# batch_valid = decode_data[-3] if cfg.model_config.page_attention else decode_data[-1] +# for ele in batch_valid.split(" "): +# 
batch_valid_flag.append(int(ele)) +# # 增加 block_tables和slot_mapping 的 shape +# input_shapes = [] +# if cfg.model_config.page_attention: +# for shape_str in [decode_data[-2], decode_data[-1]]: +# shape = list(map(int, shape_str.split(" "))) +# input_shapes.append(shape) +# work_agent.is_prefill = False +# _, _ = work_agent.predict(current_batch=current_batch_dyn, batch_valid_flag=batch_valid_flag, +# shape_list=input_shapes) +# if i == 0: +# conn.sendall("1".encode()) +# elif data.startswith('e'): +# # worker退出获取agent状态,free状态下才允许退出 +# if work_agent.status & AgentStatus.busy == AgentStatus.busy: +# conn.sendall("busy".encode()) +# else: +# work_agent.status = AgentStatus.unconnected +# conn.sendall("free".encode()) +# elif data.startswith('r'): +# # reset agents status +# work_agent.status = AgentStatus.unconnected +# conn.sendall("succes".encode()) +# except ConnectionResetError: +# break +# except RuntimeError: +# conn.sendall("2".encode()) +# conn.close() + +def start_agent_socket_server(i, cfg: ServingConfig, startup_queue): + logging.basicConfig(level=logging.ERROR, + filename=f"./output/agent_{i}.log", + filemode='w', + format= + '%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s') + """启动agent进程, 由_agent_process进行调用, 创建agent进程""" + if IMPORT_LITE_FAILED: + logging.warning("import mindspore_lite failed, using kbk backend.") + work_agent = WorkAgent(i, cfg) # 创建一个WorkAgent实例,传入当前agent的索引和配置。 + + agent_ports = cfg.serving_config.agent_ports + agent_ip = cfg.serving_config.agent_ip + agent_address = (agent_ip, agent_ports[i]) + # 设置当前agent的地址(IP和端口)。 + print(agent_address) + + parent_process = psutil.Process(os.getppid()) + server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + server.bind(agent_address) + server.listen(50) # 开始监听连接,允许最多50个待处理连接 + + startup_queue.put(i) + + # 绑定method + # print("start agent socket server in rank{}".format(i), flush=True) + # logging.info("Agent socket server started on {}".format(agent_address)) + + task_queue = queue.PriorityQueue() + + def handle_client(conn): + while True: + if not parent_process.is_running(): + logging.warning( + f"detect parent pid={parent_process.pid} has exited, child begin to exit") + conn.close() + return + + try: + data = conn.recv(4096) + if not data: + break + data = data.decode() + # logging.debug(f"Data received: {data}") + + if data.startswith('#') or data.startswith('*') or data.startswith('e') or data.startswith('r'): + priority = 0 # 高优先级 + else: + priority = 1 # 低优先级 + + task_queue.put((priority, data, conn)) + # logging.info(f"Task added to queue with priority {priority}: {data}") + + except ConnectionResetError: + break + except RuntimeError as e: + logging.error(f"Runtime error: {e}") + conn.sendall("2".encode()) + break + + def process_tasks(): + while True: + priority, data, conn = task_queue.get() + # logging.info(f"Processing task with priority {priority}: {data}") + + if data.startswith('#'): + if work_agent.status & AgentStatus.unconnected == AgentStatus.unconnected: + data = data[1:] + work_agent.shm_names = data.split(",") + work_agent.status = AgentStatus.connected + # logging.info("Connected successfully") + conn.sendall("success".encode()) + else: + # logging.info("Connection failed") + conn.sendall("failed".encode()) + + elif data.startswith('*'): + # 全量推理 + work_agent.is_prefill = True + data = data[1:] + shape_strs = data.split(",") + input_shapes = [] + for shape_str in shape_strs: + shape = list(map(int, 
shape_str.split(" "))) + input_shapes.append(shape) + _, _ = work_agent.predict(shape_list=input_shapes) + if i == 0: + conn.sendall("1".encode()) + + elif data.startswith('a'): + # 增量推理 + decode_data = data.split('_') + # 增加PA的判断 + current_batch_dyn = int(decode_data[-4]) if cfg.model_config.page_attention else int( + decode_data[-2]) + batch_valid_flag = [] + batch_valid = decode_data[-3] if cfg.model_config.page_attention else decode_data[-1] + for ele in batch_valid.split(" "): + batch_valid_flag.append(int(ele)) + # 增加 block_tables和slot_mapping 的 shape + input_shapes = [] + if cfg.model_config.page_attention: + for shape_str in [decode_data[-2], decode_data[-1]]: + shape = list(map(int, shape_str.split(" "))) + input_shapes.append(shape) + work_agent.is_prefill = False + _, _ = work_agent.predict(current_batch=current_batch_dyn, batch_valid_flag=batch_valid_flag, + shape_list=input_shapes) + if i == 0: + conn.sendall("1".encode()) + elif data.startswith('e'): + if work_agent.status & AgentStatus.busy == AgentStatus.busy: + # logging.info("Agent is busy") + conn.sendall("busy".encode()) + # else: + work_agent.status = AgentStatus.unconnected + # logging.info("Agent is free") + conn.sendall("free".encode()) + + elif data.startswith('r'): + work_agent.status = AgentStatus.unconnected + # logging.info("Reset successful") + conn.sendall("success".encode()) + + threading.Thread(target=process_tasks, daemon=True).start() + + while True: + if not parent_process.is_running(): + logging.warning(f"detect parent pid={parent_process.pid} has exited, child begin to exit") + server.close() + return + conn, client_addr = server.accept() + # logging.info(f"Connection accepted from {client_addr}") + threading.Thread(target=handle_client, args=(conn,), daemon=True).start() + +def handler(sig_num, addition): + os.killpg(os.getpgid(os.getpid()), signal.SIGKILL) + + +def startup_agents(config, startup_queue): + signal.signal(signal.SIGTERM, handler) + signal.signal(signal.SIGINT, handler) + agent_ports = config.serving_config.agent_ports + subprocess_list = [] + # log_dir = os.path.join(os.getcwd(), "output") + # if not os.path.exists(log_dir): + # os.mkdir(log_dir) + for i in range(len(agent_ports)): + p = Process(target=start_agent_socket_server, args=(i, config, startup_queue)) + p.start() + subprocess_list.append(p) + listen_agents_after_startup(subprocess_list) diff --git "a/2024-ascend-innovation-contest/topic3-inference/second-phase/\350\265\267\351\243\216\344\272\206/llama_7b_kbk_pa_dyn.yaml" "b/2024-ascend-innovation-contest/topic3-inference/second-phase/\350\265\267\351\243\216\344\272\206/llama_7b_kbk_pa_dyn.yaml" new file mode 100644 index 00000000..a468078c --- /dev/null +++ "b/2024-ascend-innovation-contest/topic3-inference/second-phase/\350\265\267\351\243\216\344\272\206/llama_7b_kbk_pa_dyn.yaml" @@ -0,0 +1,46 @@ +model_config: + model_name: 'llama_7b' + # max_generate_length: 600 ##快几秒 + max_generate_length: 4096 + end_token: 2 + seq_length: [4096] + vocab_size: 32000 + prefill_batch_size: [1] + # decode_batch_size: [1] + decode_batch_size: [128] + zactivate_len: [512, 1024, 2048, 4096] + model_type: 'dyn' + seq_type: 'static' + batch_waiting_time: 0.0 + decode_batch_waiting_time: 0.0 + batching_strategy: 'continuous' + current_index: False + page_attention: True + model_dtype: "DataType.FLOAT32" + pad_token_id: 0 + backend: 'kbk' # 'ge' + model_cfg_path: '/home/ma-user/work/mindformers/configs/llama2/predict_llama2_7b.yaml' + +serving_config: + agent_ports: [16002] + start_device_id: 0 + 
server_ip: '127.0.0.1' + server_port: 8835 + +pa_config: + num_blocks: 1024 + block_size: 16 + decode_seq_length: 4096 + +tokenizer: + type: LlamaTokenizer + vocab_file: '/home/ma-user/work/checkpoint_download/llama2/tokenizer.model' + +basic_inputs: + type: LlamaBasicInputs + +extra_inputs: + type: LlamaExtraInputs + +warmup_inputs: + type: LlamaWarmupInputs \ No newline at end of file diff --git "a/2024-ascend-innovation-contest/topic3-inference/second-phase/\350\265\267\351\243\216\344\272\206/llm_server_post.py" "b/2024-ascend-innovation-contest/topic3-inference/second-phase/\350\265\267\351\243\216\344\272\206/llm_server_post.py" new file mode 100644 index 00000000..c3d65494 --- /dev/null +++ "b/2024-ascend-innovation-contest/topic3-inference/second-phase/\350\265\267\351\243\216\344\272\206/llm_server_post.py" @@ -0,0 +1,151 @@ +import asyncio +import logging +import subprocess + +from mindspore_serving.master.master import AsyncMaster +from mindspore_serving.master.response_async_queue import AsyncResultsOfOneRequest +from mindspore_serving.master.utils import ResponseOutput, ModelInfo +from mindspore_serving.master.request_resister_engine import RequestEngine +from mindspore_serving.config.config import ServingConfig + + +# from mindspore_serving.serving_utils.register import import_all_modules_for_register + +# import_all_modules_for_register() + + +class LLMServer: + """ + request_queue(FIFO): add request into a async queue, and monitor request status(is_finished), + mapping inference result of each iteration to corresponding request + result queue(used in stream return). + master: Continuously getting unfinished request from request_queue, conducting batch strategy, + and doing one step inference using ms-lite, after get result of one iteration, client + get stream inference result from request_queue, and update to request_queue. + """ + + def __init__(self, config: ServingConfig): + self.request_engine = RequestEngine() + self.background_loop = None + self.master = AsyncMaster(config, self.request_engine) # liuyang + # self.master = AsyncMaster(config) + self.status = 0 + self.config = config + + @property + def is_running(self) -> bool: + return self.background_loop is not None + + async def run_loop(self): + while self.status: + await self.step() + await asyncio.sleep(0) + + # def start_background_loop(self) -> None: + # # todo + # self.status = 1 + # """Start the background loop.""" + # if self.is_running: + # raise RuntimeError("Background loop is already running.") + # self.background_loop = asyncio.get_event_loop().create_task(self.run_loop()) + + ### liuyang + def start_background_loop(self) -> None: + # todo + self.status = 1 + """Start the background loop.""" + if self.is_running: + raise RuntimeError("Background loop is already running.") + self.background_loop = asyncio.get_event_loop().create_task(self.run_loop()) + asyncio.get_event_loop().create_task(self.master.handle_detokenization_loop()) + + + async def register_request(self, + request_id: str, + **add_request_info) -> AsyncResultsOfOneRequest: + # logging.debug("background loop {}".format(self.background_loop)) + if self.background_loop is None: + self.start_background_loop() + + res_stream = self.request_engine.register_request( + request_id, + **add_request_info) + return res_stream + + def _abort(self, request_id: str) -> None: + """Abort a request. + Args: + request_id: The unique id of the request. 
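The modified `start_background_loop` runs two coroutines on one event loop: the serving step loop and the master's `handle_detokenization_loop`, so post-processing never blocks the next model step. The snippet below is a self-contained asyncio illustration of that producer/consumer split; it assumes nothing beyond the standard library, and the queue, sentinel and payload strings are placeholders rather than the real serving objects.

```
import asyncio

async def step_loop(out_queue: asyncio.Queue, steps: int = 3):
    for step in range(steps):
        await asyncio.sleep(0.01)                 # stand-in for one model step
        out_queue.put_nowait(f"token batch {step}")
    out_queue.put_nowait(None)                    # sentinel: no more work

async def detokenization_loop(out_queue: asyncio.Queue):
    while True:
        batch = await out_queue.get()
        if batch is None:
            break
        print("detokenized:", batch)              # stand-in for _postprocess

async def main():
    q = asyncio.Queue()
    await asyncio.gather(step_loop(q), detokenization_loop(q))

asyncio.run(main())
```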
+ """ + self.request_engine.abort_request(request_id) + + async def step(self): + # loop consuming from request_engine + if self.status == 0: + return + new_requests, finished_requests = self.request_engine.get_requests_from_register_pool() + for new_request in new_requests: + self.master.add_requests_to_schedule_pool(**new_request) + if finished_requests: + await self._master_abort(finished_requests) + request_outputs = await self.master.step_async() + # Put the outputs into the corresponding streams. + if request_outputs is not None: + for request_output in request_outputs: + self.request_engine.process_request_output(request_output) + + def get_total_tokens(self): + return self.master.get_number_of_total_tokens() + + def get_bs_current(self): + return self.master.get_current_batch() + + def get_queue_current(self): + return self.master.get_current_requestes_nums() + + async def generate_answer( + self, + request_id: str, + **add_request_info + ) -> ResponseOutput: + + # Preprocess the request. + try: + res_stream = await self.register_request(request_id, **add_request_info) + + async for request_output in res_stream: + yield request_output + + except Exception as e: + # If there is an exception, abort the request. + self._abort(request_id) + raise e + + async def _master_abort(self, request_ids): + self.master.abort_request(request_ids) + + def stop(self): + # 1. stop background + self.status = 0 + self.master.stop() + + def get_dockerId(self): + p = subprocess.Popen("cat /proc/self/cgroup | grep /docker | head -1 | cut -d/ -f3", shell=True, + stdout=subprocess.PIPE) + out = p.stdout.read() + id = str(out, 'utf-8') + return id + + def get_serverd_model_info( + self + ) -> ModelInfo: + max_seq_length = int(self.config.model_config.seq_length[-1]) + max_decode_batch_size = int(self.config.model_config.decode_batch_size[-1]) + docker_id = self.get_dockerId() + serverd_model_info = ModelInfo(docker_label=docker_id, + max_batch_total_tokens=max_seq_length * max_decode_batch_size, + max_concurrent_requests=self.master.get_current_requestes_nums(), + max_input_length=max_seq_length, max_total_tokens=max_decode_batch_size, + model_dtype=self.config.model_config.model_dtype, + model_id=self.config.model_config.model_name) + return serverd_model_info diff --git "a/2024-ascend-innovation-contest/topic3-inference/second-phase/\350\265\267\351\243\216\344\272\206/master.py" "b/2024-ascend-innovation-contest/topic3-inference/second-phase/\350\265\267\351\243\216\344\272\206/master.py" new file mode 100644 index 00000000..d6dda1f7 --- /dev/null +++ "b/2024-ascend-innovation-contest/topic3-inference/second-phase/\350\265\267\351\243\216\344\272\206/master.py" @@ -0,0 +1,535 @@ +from tabnanny import check +from typing import List, Optional, Tuple +import copy +import time +import random +import logging + +from mindspore_serving.serving_utils.entry import EntryData, EntryMetaData, EntryStatus +from .utils import Counter, ResponseOutput + +from mindspore_serving.schedule.schedule import Schedule +from mindspore_serving.worker.worker import Worker +from mindspore_serving.config.config import ServingConfig +from mindspore_serving.models.build_tokenizer import build_tokenizer +from mindspore_serving.schedule.cache_engine import ServingBlockMemPool +from mindspore_serving.serving_utils.constant import * +from mindformers.mindformer_book import MindFormerBook + +Eps = 30 + + +class BatchTokenIdOut: + def __init__(self,output:str = None,entry_metadata_list = None,index_list = None, skip_inference = False): + + 
self.output = output + self.entry_metadata_list = entry_metadata_list + self.index_list = index_list + self.skip_inference = skip_inference + +class Master: + def __init__(self, + config: ServingConfig): + self.config = config + self.tokenizer = None + self.counter = Counter() + self.worker = Worker(config) + self.scheduler = Schedule(config) + + self.is_running = False + self._init_workers() + self._counter_of_token = 0 + self._init_tokenizer() + self.decode_cache = {} + if self.config.model_config.page_attention: + self._init_mem_pool() # PA + + # PA + + def _init_mem_pool(self): + ServingBlockMemPool.init(self.config.pa_config.num_blocks, self.config.pa_config.block_size) + + def _init_tokenizer(self): + self.tokenizer = build_tokenizer(self.config) + if self.tokenizer is None: + logging.error('load tokenizer failed!') + # logging.debug(f'self.tokenizer is {self.tokenizer}') + + def _init_workers(self): + self.worker._init_worker() + + def _schedule(self) -> Tuple[List[EntryMetaData], int]: + return self.scheduler.schedule() + + def get_number_of_total_tokens(self): + return self._counter_of_token + + def _detokenizer(self, tokens: List[int]) -> List[str]: + """ + tokens is results of post-sampling module. + output: texts list of batch + """ + texts = [] + for token in tokens: + token_input = [token] + text = self.tokenizer.decode(token_input, skip_special_tokens=True) + # logging.debug(f'tokenizer decode result is {text}, token id is {token}') + texts.append(text) + return texts + + def _llama_detokenizer(self, outputs): + str_outputs = [] + batch_size = len(outputs) + before_batch_size = len(self.decode_cache.keys()) + if batch_size > before_batch_size: + for i in range(before_batch_size, batch_size): + self.decode_cache[i] = [] + else: + while len(self.decode_cache.keys()) > batch_size: + self.decode_cache.popitem() + for i in range(batch_size): + self.decode_cache[i].append(outputs[i]) + new_text = self.tokenizer.decode(self.decode_cache[i], skip_special_tokens=True) + if not new_text.endswith("�"): + begin_token = self.tokenizer._convert_id_to_token(self.decode_cache[i][0]) + if begin_token == '<0x0A>': + begin_token = '\n' + elif '\u2581' in begin_token: + begin_token = ' ' + else: + begin_token = '' + + str_outputs.append(begin_token + new_text) + self.decode_cache[i] = [] + else: + str_outputs.append('') + return str_outputs + + def _llama_detokenizer_function(self, index, entry_metadata_list, skip_special_tokens=True): + + prefix_index = entry_metadata_list[index].get_entry_data().prefix_index + read_index = entry_metadata_list[index].get_entry_data().read_index + if entry_metadata_list[index].get_entry_data().get_status() != EntryStatus.RUNNING: + return "" + all_outputs_ids = entry_metadata_list[index].get_entry_data().get_output_token() + + prefix_text = self.tokenizer.decode(all_outputs_ids[prefix_index: read_index], + skip_special_tokens=skip_special_tokens) + + new_text = self.tokenizer.decode(all_outputs_ids[prefix_index:], skip_special_tokens=skip_special_tokens) + ########### add liuyang + # if prefix_index==0 and new_text.endswith("\n"): + # new_text = "\n" + # all_text = self.tokenizer.decode(all_outputs_ids, skip_special_tokens=False) + # txt_file = open('/home/ma-user/work/llm-serving/all_text.txt', 'a', encoding="utf-8") + # if prefix_index==0: + # txt_file.write(all_text) + # txt_file.write("\n") + ### read_index,每次保留一个阅读token + #### 这里加打印不行,看看采用np.save的形式 + if len(new_text) > len(prefix_text) and not new_text.endswith("�"): + new_text = 
new_text[len(prefix_text):] + entry_metadata_list[index].get_entry_data().prefix_index = read_index + entry_metadata_list[index].get_entry_data().read_index = len(all_outputs_ids) + ############### + # logging.info("new_text tttttttttttttttttttttttt {}".format(new_text)) + # if prefix_index==0 and new_text.endswith("\n"): + # return "\n" + # txt_file = open('/home/ma-user/work/llm-serving/new_text.txt', 'a', encoding="utf-8") + # if prefix_index==0 and new_text.endswith("\n") and len(new_text)>2: + # txt_file.write(new_text[0:-2]) + # txt_file.write("\n") + return new_text + else: + return "" + + def _llama_detokenizer_v2(self, + outputs, + entry_metadata_list, + index_list=None, + skip_special_tokens=True): + str_outputs = [] + # prompt + if index_list is not None: + for index in index_list: + str_outputs.append(self._llama_detokenizer_function(index, entry_metadata_list, skip_special_tokens)) + return str_outputs + # decode + ### 遍历整个batch,返回各个请求的新生成文本 + for index, output in enumerate(outputs): + str_outputs.append(self._llama_detokenizer_function(index, entry_metadata_list, skip_special_tokens)) + return str_outputs + + def _check_error_code(self, output_token): + error_code_list = [-1, -202, -203] + if output_token in error_code_list: + return True + return False + + def _postprocess(self, + outputs: List[tuple], + entry_metadata_list: List[EntryMetaData], + index_list: List[int] = None, + skip_inference=False) -> List[ResponseOutput]: + + end_token = self.config.model_config.end_token # for debug + + #### 首token是无效token,更新状态后返回,不要解码 liuyang + if len(outputs)==0 or outputs==[]: + self.scheduler.upate_entries_after_one_step_after_prefill(end_token, index_list) + return None + + output_tokens = [] + output_logprob = [] + ### 整个batch,一个迭代的输出, 一个迭代输出一个token + for output_tup in outputs: + output_tokens.append(output_tup[0]) + output_logprob.append(output_tup[1]) + + self.scheduler.upate_entries_after_one_step(output_tokens, end_token, index_list) + str_outputs = [''] * len(output_tokens) + if (self.config.model_config.model_name.startswith( + 'llama') or self.config.model_config.model_name == 'wizard_coder') and not self._check_error_code( + output_tokens[0]): + # str_outputs = self._llama_detokenizer(outputs) + str_outputs = self._llama_detokenizer_v2(output_tokens, entry_metadata_list, + index_list, skip_special_tokens=True) + + elif self.config.model_config.model_name in ( + 'internlm_7b', 'baichuan2pa', 'gpt2') and not self._check_error_code(output_tokens[0]): + str_outputs = self._detokenizer(output_tokens) + self._counter_of_token += len(output_tokens) + # logging.debug("target is {}, str_outputs is {}".format(outputs, str_outputs)) + # logging.debug("current total token numbers is {}".format(self._counter_of_token)) + # generating output + results: List[ResponseOutput] = [] + # prompt result + if index_list is not None: + # idx: index_list and outputs data index, index: batch list index. 
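`_llama_detokenizer_function` holds text back via a per-request `prefix_index`/`read_index` pair until the decoded string no longer ends in the replacement character `�`, which marks an incomplete multi-byte piece. Below is a simplified, tokenizer-agnostic sketch of that hold-back rule; the byte-level `decode` stand-in is illustrative only, not the real LlamaTokenizer.

```
def emit_new_text(decode, all_output_ids, state):
    prefix_index, read_index = state["prefix_index"], state["read_index"]
    prefix_text = decode(all_output_ids[prefix_index:read_index])
    new_text = decode(all_output_ids[prefix_index:])
    # emit only the tail beyond what was already sent, and only once it no
    # longer ends in an incomplete character
    if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
        state["prefix_index"] = read_index
        state["read_index"] = len(all_output_ids)
        return new_text[len(prefix_text):]
    return ""

# toy "tokenizer": pieces are raw UTF-8 bytes, 'ó' is split across two pieces
pieces = {1: b"hel", 2: b"l\xc3", 3: b"\xb3!"}
decode = lambda ids: b"".join(pieces[i] for i in ids).decode("utf-8", errors="replace")

state = {"prefix_index": 0, "read_index": 0}
tokens = []
for t in (1, 2, 3):
    tokens.append(t)
    print(repr(emit_new_text(decode, tokens, state)))   # 'hel', '', 'ló!'
```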
+ for idx, index in enumerate(index_list): + ### liuyang + # if entry_metadata_list[index].get_entry_data().get_output_len()==1: + # continue + + if entry_metadata_list[index].entry_data.status == EntryStatus.PADDING_INVAILED: + # logging.debug(f'generate a invalid token, index in batch is {index}') + continue + if output_tokens[0] == INPUT_OUT_OF_TOKEN[0]: + # logging.debug(f'input out of range, index in batch is {index}') + results.append(ResponseOutput.generate_result(output_tokens[idx], + 0, + entry_metadata_list[index], + str_outputs[idx], + end_token, reason='Error202: prompt out of range')) + return results + + if output_tokens[0] == INPUT_EMPTY_TOKEN[0]: + # logging.debug(f'prompt token empty, index in batch is {index}') + results.append(ResponseOutput.generate_result(output_tokens[idx], + 0, + entry_metadata_list[index], + str_outputs[idx], + end_token, reason='Error203: prompt token empty')) + return results + + results.append(ResponseOutput.generate_result(output_tokens[idx], + output_logprob[idx], + entry_metadata_list[index], + str_outputs[idx], + end_token)) + # encode result + else: + for index, output_token in enumerate(output_tokens): + output_token_logprob = output_logprob[index] + if entry_metadata_list[index].entry_data.status == EntryStatus.PADDING_INVAILED: + # logging.debug(f'generate a invalid token, index in batch is {index}') + continue + + ### liuyang + # if entry_metadata_list[index].get_entry_data().get_output_len()==1: + # continue + + ####整个batch的 + results.append(ResponseOutput.generate_result(output_token, + output_token_logprob, + entry_metadata_list[index], + str_outputs[index], + end_token)) + return results + + def get_current_batch(self): + return self.scheduler.get_dyn_batch() + + def get_current_requestes_nums(self): + return self.scheduler.get_queue_len() + + def abort_request(self, + request_id: str) -> None: + self.scheduler.abort_entry(request_id) + + def add_requests_to_schedule_pool(self, + request_id: str, + prompt: Optional[str], + do_sample, + top_k, + top_p, + temperature, + repetition_penalty, + max_token_len): + # time_tokenizer = time.time() + prompt_token_ids = None + # logging.debug("request id add_requests_to_schedule_pool {}".format(request_id)) + # 加入baichuan + if self.config.model_config.model_name in ( + 'baichuan2pa', 'wizard_coder') or self.config.model_config.model_name.startswith('llama'): + prompt_token_ids = self.tokenizer.encode(prompt) + elif self.config.model_config.model_name == 'internlm_7b': + prompt_token_ids = self.tokenizer(prompt)['input_ids'][1:] + elif self.config.model_config.model_name in MindFormerBook.get_tokenizer_support_list(): + prompt_token_ids = self.tokenizer(prompt)['input_ids'] + else: + print('incorrect model_name') + logging.debug('incorrect model_name') + + # logging.info('tokenizer result prompt_token_ids is {}'.format(prompt_token_ids)) + # logging.info('tokenizer time is {}'.format((time.time() - time_tokenizer) * 1000)) + + # if prompt_token_ids is not None and + # Create the sequences. 
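For orientation, the result routing above boils down to: one `(token, logprob)` pair per batch slot per step, padding slots skipped, and the end token retiring a slot. A stripped-down sketch with plain dicts in place of `EntryMetaData`/`ResponseOutput` (only the `end_token` value 2 is taken from the yaml; the rest is illustrative):

```
END_TOKEN = 2

def route_step_outputs(step_outputs, entries, streams):
    """step_outputs: list of (token, logprob), one per batch slot."""
    for slot, (token, logprob) in enumerate(step_outputs):
        entry = entries[slot]
        if entry["status"] == "PADDING_INVALID":
            continue                      # padded slot: nothing to report
        streams[entry["request_id"]].append((token, logprob))
        if token == END_TOKEN:
            entry["status"] = "FINISHED"  # slot becomes free for a new prompt

entries = [
    {"request_id": "req-0", "status": "RUNNING"},
    {"request_id": "req-1", "status": "PADDING_INVALID"},
]
streams = {"req-0": [], "req-1": []}
route_step_outputs([(1543, -0.12), (0, 0.0)], entries, streams)
print(streams)   # only req-0 receives a token
```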
+ entry_id = next(self.counter) + entry_data = EntryData(prompt_tokens=prompt_token_ids, + max_token_len=max_token_len, + do_sample=do_sample, + tok_k=top_k, + top_p=top_p, + temperature=temperature, + repetition_penalty=repetition_penalty) + block_size = 0 + if self.config.model_config.page_attention: + block_size = self.config.pa_config.block_size + entry_meta_data = EntryMetaData(page_attention=self.config.model_config.page_attention, + request_id=request_id, + is_prompt=True, + entry_data=entry_data, + entry_id=entry_id, + prompt=prompt, + block_size=block_size) + ### add liuyang + # txt_file = open('/home/ma-user/work/llm-serving/new_request.txt', 'a', encoding="utf-8") + # txt_file.write(request_id) + # txt_file.write("\n") + + # logging.debug("add request to schedule queue {}".format(entry_meta_data.request_id)) + self.scheduler.add_entrys(entry_meta_data) + + def step(self) -> List[ResponseOutput]: + # do inference + entry_metadata_list, batch_size = self._schedule() + # Execute the model. + # output: model infer out(token): + # output is batch_size * n_src_vocab + output = self._mock_run_workers_async(batch_size) + return self._postprocess(output, entry_metadata_list) + + def _mock_run_workers_async(self, batch_size: int): + outputs = [] + for i in range(batch_size): + output = random.randint(0, 32000) ### 生成了一个随机token + # output = 100 ##### liuyang + outputs.append(output) + time.sleep(0.15) + return outputs + +import asyncio +from mindspore_serving.master.request_resister_engine import RequestEngine +# class AsyncMaster(Master): +class AsyncMaster(Master): + def __init__( + self, + config: ServingConfig, + request_engine: RequestEngine + ): + super().__init__(config) + self.detokenizer_que = asyncio.Queue() + self.request_engine = request_engine + + def send_post_process(self,output,entry_metadata_list,index_list,skip_inference=False): + self.detokenizer_que.put_nowait( + BatchTokenIdOut(output,entry_metadata_list,index_list,skip_inference) + ) + + async def step_async(self) -> List[ResponseOutput]: + entries_metadata_list, current_batch_size = self._schedule() + valid_entry_len = 0 + for metadata in entries_metadata_list: + if metadata.entry_data.get_status() == EntryStatus.RUNNING or \ + metadata.entry_data.get_status() == EntryStatus.INPUT_OUTOFRANGE: + valid_entry_len += 1 + if valid_entry_len == 0: + return None + + output = await self._run_workers_async(current_batch_size, entry_metadata_list=entries_metadata_list) + return output + + def _get_prompt_batch_list(self, entry_metadata_list): + # PA取一个data进行prefill + if self.config.model_config.page_attention: + return self._check_prompt_predict_data_pa(entry_metadata_list) + ### 这里 index_list是指 整个running_request_list中为 prompt请求的下标 + input_entry_metadata_list, index_list = self._check_prompt_predict_data(entry_metadata_list) + prompt_data_count = len(input_entry_metadata_list) + + if prompt_data_count == 0: ### 说明当前的decode_batchsize个请求中没有 prompt请求,返回空列表 + return input_entry_metadata_list, index_list + # logging.debug("_get_prompt_batch_list prompt index_list {}, input_entry_metadata_list {}" + # .format(index_list, input_entry_metadata_list)) + + prefill_batch_size_list = self.config.model_config.prefill_batch_size + if prefill_batch_size_list is None or len(prefill_batch_size_list) == 0: + return [input_entry_metadata_list[0]], [index_list[0]] + else: # pure dyn + dyn_bach_size = prefill_batch_size_list[0] + if prompt_data_count > dyn_bach_size: + input_entry_metadata_list = input_entry_metadata_list[:dyn_bach_size] + 
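`_get_prompt_batch_list` filters the scheduled batch for prompt (prefill) entries and caps them at `prefill_batch_size` (a single entry in the page-attention path), so each step runs prefill whenever new prompts exist and otherwise runs a full decode step. A condensed sketch of that per-step decision, under the simplifying assumption that an entry is just a dict with an `is_prompt` flag and `predict_fn` stands in for `worker.predict`:

```
def run_one_step(entries, predict_fn, prefill_batch_size=1, page_attention=True):
    prompt_slots = [i for i, e in enumerate(entries) if e["is_prompt"]]
    if prompt_slots:
        cap = 1 if page_attention else prefill_batch_size
        chosen = prompt_slots[:cap]
        return predict_fn([entries[i] for i in chosen]), chosen   # prefill step
    return predict_fn(entries), None                              # decode step

entries = [{"is_prompt": False}, {"is_prompt": True}, {"is_prompt": False}]
out, index_list = run_one_step(entries, lambda batch: f"ran on {len(batch)} entries")
print(out, index_list)   # the single new prompt is prefetched first: 'ran on 1 entries [1]'
```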
index_list = index_list[:dyn_bach_size] + + return input_entry_metadata_list, index_list + + @staticmethod + def get_last_prompt_entry(entry_metadata_list): + for i in range(len(entry_metadata_list) - 1, -1, -1): + entry_meta_data = entry_metadata_list[i] + if entry_meta_data.is_prompt: + return entry_meta_data + + @staticmethod + def _get_prefill_padding_entry(index, entry_meta_data): + copy_entry_meta_data = copy.deepcopy(entry_meta_data) + copy_entry_meta_data.get_entry_data().set_status(EntryStatus.PADDING_INVAILED) + copy_entry_meta_data.get_entry_data().set_decode_index(index) + # logging.debug(f'add invalid request into prefill batch list, batch size index is {index}') + return copy_entry_meta_data + + @staticmethod + def _check_prompt_out_of_range_index_list(entry_metadata_list): + """check prompt out of range index list""" + out_of_range_index_list = [] + # for item in entries_metadata_list: + for index, item in enumerate(entry_metadata_list): + if not item.is_prompt or item.entry_data.status != EntryStatus.INPUT_OUTOFRANGE: + continue + + out_of_range_index_list.append(index) + return out_of_range_index_list + + @staticmethod + def _check_prompt_predict_data_pa(entry_metadata_list): + input_entry_metadata_list = [] + index_list = [] + for index, item in enumerate(entry_metadata_list): + if not item.is_prompt or item.entry_data.status == EntryStatus.INPUT_OUTOFRANGE: + continue + input_entry_metadata_list = [item] + index_list = [index] + break + return input_entry_metadata_list, index_list + + @staticmethod + def _check_prompt_predict_data(entry_metadata_list): + input_entry_metadata_list = [] + index_list = [] + for index, item in enumerate(entry_metadata_list): + if not item.is_prompt or item.entry_data.status == EntryStatus.INPUT_OUTOFRANGE: + continue + + input_entry_metadata_list.append(item) + index_list.append(index) + return input_entry_metadata_list, index_list + + @staticmethod + def _check_prompt_token_empty(entry_metadata_list, pad_token_id): + empty_list = [] + for index, item in enumerate(entry_metadata_list): + if item.get_entry_data().get_prompt_token() == None or item.get_entry_data().get_prompt_len() == 0: + item.get_entry_data().set_status(EntryStatus.EMPTY_PROMPT_TOKEN) + empty_list.append(index) + if pad_token_id in item.get_entry_data().get_prompt_token(): + item.get_entry_data().set_status(EntryStatus.EMPTY_PROMPT_TOKEN) + empty_list.append(index) + return empty_list + + async def _run_workers_async(self, current_batch_size, entry_metadata_list): + # e_t_e_time = time.time() + + prompt_token_empty_list = self._check_prompt_token_empty(entry_metadata_list, + self.config.model_config.pad_token_id) + # logging.debug("prompt token empty list index_list {}".format(prompt_token_empty_list)) + if len(prompt_token_empty_list) > 0: + self.send_post_process([INPUT_EMPTY_TOKEN], entry_metadata_list=entry_metadata_list, + index_list=prompt_token_empty_list, + skip_inference=True) + # return self._postprocess([INPUT_EMPTY_TOKEN], entry_metadata_list=entry_metadata_list, + # index_list=prompt_token_empty_list, + # skip_inference=True) + + # check prefill out of range data + out_of_range_index_list = self._check_prompt_out_of_range_index_list(entry_metadata_list) + # logging.debug("out of range prompt index_list {}".format(out_of_range_index_list)) + if len(out_of_range_index_list) > 0: + self.send_post_process([INPUT_OUT_OF_TOKEN], entry_metadata_list=entry_metadata_list, + index_list=out_of_range_index_list, + skip_inference=True) + # return 
self._postprocess([INPUT_OUT_OF_TOKEN], entry_metadata_list=entry_metadata_list, + # index_list=out_of_range_index_list, + # skip_inference=True) + + # filter prompt data batch list + ## entry_metadata_list就是schedule返回的decode_batchsize个请求 + ### 这里根据 prefill_batchsize从中取出prefill_batchsize个请求 + input_entry_metadata_list, index_list = self._get_prompt_batch_list(entry_metadata_list) + # logging.debug("_get_prompt_batch_list prompt index_list {}, input_entry_metadata_list {}" + # .format(index_list, input_entry_metadata_list)) + # prefill predict + if len(input_entry_metadata_list) > 0: ### 只要running_request_list存在刚进来的prompt, 就在先要进行prefill + # logging.debug('prefill len of input entry_metadata_list is {}'.format(len(input_entry_metadata_list))) + # predict + output = self.worker.predict(current_batch_size, entry_metadata_list=input_entry_metadata_list) + else: # decode predict + input_entry_metadata_list = entry_metadata_list + index_list = None + # logging.debug('decode len of input entry_metadata_list is {}'.format(len(input_entry_metadata_list))) + output = self.worker.predict(current_batch_size, entry_metadata_list=input_entry_metadata_list) + + self.send_post_process(output, entry_metadata_list=entry_metadata_list, index_list=index_list) + + # result = self._postprocess(output, entry_metadata_list=entry_metadata_list, index_list=index_list) + # logging.info('post_process_time time is {}'.format((time.time() - post_process_time) * 1000)) + # logging.info('e-to-e time is {}'.format((time.time() - e_t_e_time) * 1000)) + # return result + + async def _mock_run_workers_async(self, batch_size: int): + outputs = [] + for i in range(batch_size): + output = random.randint(0, 32000) + # output = 100 ##### liuyang + outputs.append(output) + return outputs + + def stop(self): + self.worker.stop() + + async def handle_detokenization_loop(self): + while True: + try: + recv_obj = await self.detokenizer_que.get() # BatchTokenIdOut + post_process_time = time.time() + request_outputs = self._postprocess(recv_obj.output, entry_metadata_list=recv_obj.entry_metadata_list, index_list=recv_obj.index_list,skip_inference=recv_obj.skip_inference) + logging.info('post_process_time time is {}'.format((time.time() - post_process_time) * 1000)) + + # Put the outputs into the corresponding streams. + if request_outputs is not None: + for request_output in request_outputs: + self.request_engine.process_request_output(request_output) + self.detokenizer_que.task_done() + except Exception as e: + print(e) diff --git "a/2024-ascend-innovation-contest/topic3-inference/second-phase/\350\265\267\351\243\216\344\272\206/model_init_multimodel.py" "b/2024-ascend-innovation-contest/topic3-inference/second-phase/\350\265\267\351\243\216\344\272\206/model_init_multimodel.py" new file mode 100644 index 00000000..f84e1318 --- /dev/null +++ "b/2024-ascend-innovation-contest/topic3-inference/second-phase/\350\265\267\351\243\216\344\272\206/model_init_multimodel.py" @@ -0,0 +1,534 @@ +import time +import logging +import numpy as np +from typing import List +import socket +from mindspore_serving.config.config import ServingConfig +from mindspore_serving.models.build_inputs import build_inputs + + +class BaseInputsOfInfer: + """ + BaseInputsOfInfer interface. 
+ """ + + def get_inputs(self, model, **kwargs): + pass + + @staticmethod + def get_lite_tensor_list(inputs, model): + input_list = [] + for item in inputs: + if item is None: + continue + input_list.append(item) + lite_inputs = model.get_inputs() + for input_np, tensor in zip(input_list, lite_inputs): + tensor.set_data_from_numpy(input_np) + return lite_inputs + + +class CommonInputsOfInfer(BaseInputsOfInfer): + """ + common infer inputs of llm models. + """ + + def __init__(self): + pass + + # pylint: disable=W0221 + def get_inputs(self, input_ids=None, current_index=None, valid_length=None, + init_reset=None, is_first_iteration=True, **kwargs): + if not is_first_iteration: + inputs_tmp = [] + for i in range(len(current_index)): + current_index_tmp = int(current_index[i]) - i * input_ids.shape[1] # multibatch + # use numpy to slice array to avoid complie ascend slice op + inputs_tmp.append(input_ids[i][current_index_tmp:current_index_tmp + 1]) + input_ids = np.array(inputs_tmp, dtype=np.int32) + + inputs = [input_ids, current_index, init_reset, valid_length] + return inputs + + +class CommonInputsOfInferDyn(BaseInputsOfInfer): + """ + common infer inputs of llm models. + """ + + def __init__(self): + pass + + # pylint: disable=W0221 + def get_inputs(self, input_ids=None, current_index=None, valid_length=None, + init_reset=None, is_first_iteration=True, InputExtraList=[], **kwargs): + mask = InputExtraList[0] + freq_cos = InputExtraList[1] + freq_sin = InputExtraList[2] + if not is_first_iteration: + inputs_tmp = [] + for i in range(len(current_index)): + current_index_tmp = int(current_index[i]) - i * input_ids.shape[1] # multibatch + # use numpy to slice array to avoid complie ascend slice op + + inputs_tmp.append(input_ids[i][current_index_tmp:current_index_tmp + 1]) + input_ids = np.array(inputs_tmp, dtype=np.int32) + if is_first_iteration: + # mask, freq_cos, fre_sin + inputs = [input_ids, current_index, init_reset, valid_length, mask, freq_cos, freq_sin] + else: + inputs = [input_ids, current_index, init_reset, valid_length] + return inputs + + +class CustomInputsOfInfer(BaseInputsOfInfer): + """ + common infer inputs of llm models. + """ + + def __init__(self): + self.get_input_from_config = get_inputs_custom + + # pylint: disable=W0221 + def get_inputs(self, **kwargs): + return self.get_input_from_config(**kwargs) + + # print("inputs after get_inputs:{}".format(inputs)) + # lite_inputs = BaseInputsOfInfer.get_lite_tensor_list(inputs, model) + # return lite_inputs + inputs_custom = self.get_input_from_config(**kwargs) + if inputs_custom is None: + logging.error('custom inputs definited by customer is None,please check it in server config!') + return inputs_custom + + +class InputOfInfer: + """ + Input of llm model. + """ + MAPPING = { + "bloom": CommonInputsOfInfer, + "llama": CommonInputsOfInfer, + "glm2": CommonInputsOfInfer, + "common": CommonInputsOfInfer, + "llama_dyn": CommonInputsOfInferDyn, + "wizard_coder": CommonInputsOfInferDyn, + "internlm": CommonInputsOfInfer, + "baichuan2": CommonInputsOfInfer, + "custom": CustomInputsOfInfer + } + + @classmethod + def get_inputs(cls, model_name: str, **kwargs): + """ + Get input tensor list of mslite. + + Args: + model_name: str, model name. + + Returns: + tensor list of mslite. 
+ """ + # name = "" + # if Baseconfig['input_function'] == 'custom': + # model_name = "custom" + # logging.debug('model name {}'.format(model_name)) + # if model_name not in InputOfInfer.MAPPING: + # for k in InputOfInfer.MAPPING: + # if model_name.startswith(k): + # name = k + # break + # if not name: + # logging.warning("Model name not in support maps.Common input format will be used to do inference.") + # name = "common" + # else: + # name = model_name + return InputOfInfer.MAPPING['common']().get_inputs(**kwargs) + + +class CommonWarp: + """ + common infer inputs of llm models. + """ + + def __init__(self): + pass + + # pylint: disable=W0221 + def get_warp_inputs(self, lite_inputs=None, **kwargs): + init = 0 + init_reset = [init for _ in range(Baseconfig.prefill_batch_size)] + + lite_inputs[2] = np.array(init_reset).reshape(Baseconfig.prefill_batch_size, 1).astype(np.int32) + + first_group = np.concatenate((lite_inputs[0], lite_inputs[1].reshape(Baseconfig.prefill_batch_size, 1), + lite_inputs[2], lite_inputs[3].reshape(Baseconfig.prefill_batch_size, 1)), axis=1) + second_group = [] + return first_group, second_group + + +class CommonWarpDyn: + """ + common infer inputs of llm models. + """ + + def __init__(self): + pass + + # pylint: disable=W0221 + def get_warp_inputs(self, lite_inputs=None, **kwargs): + init = 0 + init_reset = [init for _ in range(Baseconfig.prefill_batch_size)] + lite_inputs[2] = np.array(init_reset).reshape(Baseconfig.prefill_batch_size, 1).astype(np.int32) + + first_group = np.concatenate((lite_inputs[0], lite_inputs[1].reshape(Baseconfig.prefill_batch_size, 1), + lite_inputs[2], lite_inputs[3].reshape(Baseconfig.prefill_batch_size, 1)), axis=1) + + second_group = [] + for i in range(4, len(lite_inputs)): + second_group.append(lite_inputs[i]) + return first_group, second_group + + +class WarpInputOfInfer: + """ + Input of llm model. + """ + MAPPING = { + "bloom": CommonWarp, + "llama": CommonWarp, + "glm2": CommonWarp, + "common": CommonWarp, + "llama_dyn": CommonWarpDyn, + "wizard_coder": CommonWarpDyn, + "internlm": CommonWarp, + "baichuan2": CommonWarp, + } + + @classmethod + def get_warp_inputs(cls, model_name: str, **kwargs): + """ + Get warpping input tensor list of mslite. + + Args: + model_name: str, model name. + + Returns: + tensor list of mslite. 
+ """ + name = "" + if model_name not in InputOfInfer.MAPPING: + for k in InputOfInfer.MAPPING: + if model_name.startswith(k): + name = k + break + if not name: + logging.warning("Model name not in support maps.Common input format will be used to do inference.") + name = "common" + else: + name = model_name + return WarpInputOfInfer.MAPPING[name]().get_warp_inputs(**kwargs) + + +class Singleton(object): + def __init__(self, cls): + self._cls = cls + self.uniqueInstance = None + + def __call__(self): + if self.uniqueInstance is None: + self.uniqueInstance = self._cls() + return self.uniqueInstance + + +""" +全局定义一个DisModel, 保存和agents的通信管道 +""" + + +@Singleton +class DisModel: + def __init__(self): + self.agent_stubs = [] + self.model_name = None + self.config = None + + def init(self, config, shm_names: List[str] = None): + self.config = config + agent_ip = config.serving_config.agent_ip + agent_ports = config.serving_config.agent_ports + model_name = config.model_config.model_name + print(f"agent_ports is {agent_ports}") + for port in agent_ports: + print("port ip is {}".format(port)) + client = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + # socket是1对1的,设置超时机制,防止多个serving连接同一个LLM + client.settimeout(5) + client.connect((agent_ip, port)) + send_str = '#' + ",".join(str(element) for element in shm_names) + client.sendall(send_str.encode()) + data = client.recv(6, socket.MSG_WAITALL).decode() + print(data) + if data == "failed": + client.close() + for agent in self.agent_stubs: + agent.close() + raise ConnectionError("there exists another connected serving now, stop the previous serving at first") + self.agent_stubs.append(client) + client.settimeout(None) + # send shm_names + self.model_name = model_name + + @staticmethod + def reset_agent_status(config): + print("waiting to reset agents status") + agent_ip = config.serving_config.agent_ip + agent_ports = config.serving_config.agent_ports + for port in agent_ports: + client = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + # socket是1对1的,设置超时机制,防止多个serving连接同一个LLM + client.settimeout(5) + client.connect((agent_ip, port)) + client.sendall("r".encode()) + data = client.recv(6, socket.MSG_WAITALL).decode() + print(data) + if data == "succes": + print("reset") + print("reset all agents!") + + def stop(self): + print("waiting worker to exit") + for item in self.agent_stubs: + cnt = 0 + while True or cnt < 1000: + item.sendall("e".encode()) + data = item.recv(4096).decode() + print(data) + if data == "free": + print("close socket") + item.close() + break + cnt += 1 + if cnt >= 1000: + print("agent is running now, failed to stop serving, try to stop later") + print("exit!") + + def get_predict_inputs(self, input_ids, current_index=None, + valid_length=None, init_reset=None, is_first_iteration=True, **kwargs): + """Get inputs of llm model for mslite.""" + return InputOfInfer.get_inputs(self.model_name, input_ids=input_ids, current_index=current_index, + valid_length=valid_length, init_reset=init_reset, + is_first_iteration=is_first_iteration, **kwargs) + + def get_model_inputs(self, input_ids, current_index=None, + valid_length=None, is_first_iteration=True, **kwargs) -> np.array: + if is_first_iteration: + init_reset = np.array([False]) + lite_inputs = self.get_predict_inputs(input_ids, current_index, + valid_length, init_reset, is_first_iteration, **kwargs) + + else: + init_reset = np.array([True]) + lite_inputs = self.get_predict_inputs(input_ids, current_index, + valid_length, init_reset, is_first_iteration, **kwargs) + + return 
lite_inputs + + def get_warp_inputs(self, lite_inputs=None, **kwargs): + """Get inputs of llm model for mslite.""" + return WarpInputOfInfer.get_warp_inputs(self.model_name, lite_inputs=lite_inputs, **kwargs) + + @staticmethod + def get_gen_parms_np(batch_size, dtype=np.float16, **kwargs): + do_sample_list = kwargs.pop("do_sample_list") + top_k_list = kwargs.pop("top_k_list") + top_p_list = kwargs.pop("top_p_list"), + temperature_list = kwargs.pop("temperature_list"), + repetition_penalty = kwargs.pop("repetition_penalty") + decode_index_list = kwargs.pop("decode_index_list") + + do_sample_np = np.array(do_sample_list).reshape(batch_size, 1).astype(dtype) + top_p_np = np.array(top_p_list).reshape(batch_size, 1).astype(dtype) + top_k_np = np.array(top_k_list).reshape(batch_size, 1).astype(dtype) + temperature_np = np.array(temperature_list).reshape(batch_size, 1).astype(dtype) + repetition_np = np.array(repetition_penalty).reshape(batch_size, 1).astype(dtype) + decode_index_np = np.array(decode_index_list).reshape(batch_size, 1).astype(dtype) + parms_np = np.concatenate((do_sample_np, top_p_np, top_k_np, temperature_np, repetition_np, decode_index_np), + axis=-1) + return parms_np + + def _assemble_pa_inputs(self, is_first_iteration, batch_valid_length: np.array, cache_engine_list, seq_length, + valid_batch_flag): + if is_first_iteration: + return self._assemble_pa_full_inputs(batch_valid_length, cache_engine_list, seq_length, valid_batch_flag) + else: + return self._assemble_pa_inc_inputs(batch_valid_length, cache_engine_list, seq_length, valid_batch_flag) + + def _assemble_pa_full_inputs(self, batch_valid_length: np.array, cache_engine_list, seq_length, valid_batch_flag): + block_size = cache_engine_list[0].block_size + max_num_blocks_per_seq = seq_length // block_size + + bs = len(valid_batch_flag) + block_tables = [] + slot_mapping = [] + for i in range(bs): + if valid_batch_flag[i]: + cache_engine_list[i].prepare_cache(batch_valid_length[i]) + # 预留出首个块,给冗余写用,全量需要这个 TODO:后续优化ReshapeAndCache逻辑,跳过冗余位置 + block_table = cache_engine_list[i].block_table + # padded_table = block_table + [ -1 for _ in range(max_num_blocks_per_seq - len(cache_engine_list[i].block_table) + 1)] + padded_table = block_table + [-1 for _ in + range(max_num_blocks_per_seq - len(cache_engine_list[i].block_table))] + block_tables.append(padded_table) + + slots = [block_table[k // block_size] * block_size + k % block_size for k in range(batch_valid_length[i])] + null_slot_idx = 0 + slots = slots + [null_slot_idx for _ in range(seq_length - batch_valid_length[i])] + slot_mapping = slot_mapping + slots + block_tables = np.array(block_tables, dtype=np.int32) + slot_mapping = np.array(slot_mapping, dtype=np.int32) + return block_tables, slot_mapping + + def _assemble_pa_inc_inputs(self, batch_valid_length: np.array, cache_engine_list, seq_length, valid_batch_flag): + block_size = cache_engine_list[0].block_size + max_num_blocks_per_seq = seq_length // block_size + bs = len(valid_batch_flag) + block_tables = [] + slot_mapping = [] + for i in range(bs): + if valid_batch_flag[i]: + cache_engine_list[i].prepare_cache(1) # 增量推理时,每个序列新增一个token。 + valid_length = cache_engine_list[i].num_token # - block_size + else: + valid_length = 1 + block_table = cache_engine_list[i].block_table + padded_table = block_table + [-1 for _ in + range(max_num_blocks_per_seq - len(cache_engine_list[i].block_table))] + block_tables.append(padded_table) + curent_idx = valid_length - 1 + + slots = [block_table[curent_idx // block_size] * block_size + 
curent_idx % block_size] + slot_mapping = slot_mapping + slots + block_tables = np.array(block_tables, dtype=np.int32) + slot_mapping = np.array(slot_mapping, dtype=np.int32) + return block_tables, slot_mapping + + def call(self, shms: List, input_ids, current_index, + valid_length, init_reset, is_first_iteration, valid_batch_flag, extra_inputs=None, + current_batch_size=None, **kwargs): + """kvcache infer""" + time_start = time.time() + # logging.debug("is prefill {}".format(is_first_iteration)) + decode_index_list = kwargs.get("decode_index_list") + # 加入pa + if self.config.model_config.page_attention: + cache_engine_list = kwargs.get("cache_engine_list") + seq_length = kwargs.get("seq_length") + if is_first_iteration: + lite_inputs = self.get_model_inputs(input_ids, current_index, valid_length, + is_first_iteration, extra_inputs=extra_inputs, **kwargs) + # 前4个array拼接成一个 + # init_reset变成[batch_size, 1] + # first_group, second_group = self.get_warp_inputs(lite_inputs=lite_inputs, **kwargs) + init = 0 + prefill_bs = len(input_ids) + init_reset = [init for _ in range(prefill_bs)] + lite_inputs[2] = np.array(init_reset).reshape(prefill_bs, 1).astype(np.int32) + + first_group = np.concatenate((lite_inputs[0], lite_inputs[1].reshape(prefill_bs, 1), + lite_inputs[2], lite_inputs[3].reshape(prefill_bs, 1)), axis=1) + shape_list = [] + first = np.ndarray(first_group.shape, dtype=first_group.dtype, buffer=shms[0].buf) + first[:] = first_group[:] + shape_list.append(first_group.shape) + + # 如果是prefill的话,需要将另外三个array也写到共享内存中 + second_group = [] + for i in range(4, len(lite_inputs)): + second_group.append(lite_inputs[i]) + # logging.debug("second_group {}".format(second_group)) + if len(second_group) != 0: + for j in range(len(second_group)): + # logging.debug("second_group index {}".format(j)) + item = np.ndarray(second_group[j].shape, dtype=second_group[j].dtype, buffer=shms[1 + j].buf) + item[:] = second_group[j][:] + shape_list.append(second_group[j].shape) + # mem_index = len(second_group) + params_np_dtype = np.float16 + params_np = self.get_gen_parms_np(prefill_bs, params_np_dtype, **kwargs) + # gen_index = max(3, mem_index) + + gen_params = np.ndarray(params_np.shape, dtype=params_np_dtype, buffer=shms[4].buf) + gen_params[:] = params_np[:] + + shape_list.append(params_np.shape) + + shape_strs = [] + for shape in shape_list: + shape_str = " ".join(str(element) for element in shape) + shape_strs.append(shape_str) + shapes_str = "*" + ",".join(element for element in shape_strs) + else: + # logging.debug("valid_batch_flag in decode is {}".format(valid_batch_flag)) + batch_flag_str = " ".join(str(element) for element in valid_batch_flag) + shapes_str = "a" + '_' + str(current_batch_size) + '_' + batch_flag_str + + # 加入pa + if self.config.model_config.page_attention: + block_tables, slot_mapping = self._assemble_pa_inputs(is_first_iteration, valid_length, cache_engine_list, + seq_length, valid_batch_flag) + block_tables_np = np.array(block_tables, dtype=np.int32) + block_tables_shm = np.ndarray(block_tables_np.shape, dtype=block_tables_np.dtype, buffer=shms[7].buf) + block_tables_shm[:] = block_tables_np[:] + slot_mapping_np = np.array(slot_mapping, dtype=np.int32) + slot_mapping_shm = np.ndarray(slot_mapping_np.shape, dtype=slot_mapping_np.dtype, buffer=shms[8].buf) + slot_mapping_shm[:] = slot_mapping_np[:] + + shape_strs = [] + for shape in [block_tables_np.shape, slot_mapping_np.shape]: + shape_str = " ".join(str(element) for element in shape) + shape_strs.append(shape_str) + if 
is_first_iteration: + shapes_str += "," + ",".join(element for element in shape_strs) + else: + shapes_str += "_" + "_".join(element for element in shape_strs) + # logging.debug("get input lite is {} ".format((time.time() - time_start) * 1000)) + # logging.debug("server decode batch size is {} ".format(current_batch_size)) + shapes_str = shapes_str.encode() + + for item in self.agent_stubs: + item.sendall(shapes_str) + recv_data = self.agent_stubs[0].recv(1, socket.MSG_WAITALL).decode() + # if not recv_data=="1": + # recv_data = self.agent_stubs[0].recv(1, socket.MSG_WAITALL).decode() + result = [] + if recv_data == "2": + for _ in decode_index_list: + # result.append(int(Baseconfig.end_token)) + result.append((int(-1),0)) + print("--------------------predict failed, abandon current prompt, please try again----------------") + logging.error("predict failed, abandon current prompt, please try again") + return result, 1 + + ####测试精度时需要注释下面两行 + if is_first_iteration: + return result, 1 + ############ + # save_result = [] # liuyang + for decode_index in decode_index_list: + ### liuyang + # read_time = time.time() + tmp = np.ndarray((decode_index + 1,), dtype=np.int32, buffer=shms[5].buf) + tmp_logprob = np.ndarray((decode_index + 1,), dtype=np.float64, buffer=shms[6].buf) + result.append((int(tmp[decode_index:decode_index + 1]), float(tmp_logprob[decode_index:decode_index + 1]))) + ### liuyang + # save_result.append(int(tmp[decode_index:decode_index + 1])) + # if is_first_iteration: + # file_name = str(read_time)+'result1' + # else: + # file_name = str(read_time)+'result2' + # np.save(file_name, np.array(save_result)) + + logging.info("--------------------callV3 result value is {} ".format(result)) + logging.info("model.call time is {} ".format((time.time() - time_start) * 1000)) + return result, 1 diff --git "a/2024-ascend-innovation-contest/topic3-inference/second-phase/\350\265\267\351\243\216\344\272\206/schedule.py" "b/2024-ascend-innovation-contest/topic3-inference/second-phase/\350\265\267\351\243\216\344\272\206/schedule.py" new file mode 100644 index 00000000..eed9b84c --- /dev/null +++ "b/2024-ascend-innovation-contest/topic3-inference/second-phase/\350\265\267\351\243\216\344\272\206/schedule.py" @@ -0,0 +1,559 @@ +import time +from typing import List, Tuple, Deque +import logging + +from queue import Queue +import copy +from mindspore_serving.serving_utils.constant import * +from mindspore_serving.serving_utils.entry import EntryMetaData, EntryStatus, EntryData +from mindspore_serving.config.config import ServingConfig +from mindspore_serving.schedule.cache_engine import ServingBlockMemPool, ServingCacheEngine + + +class Schedule: + """static batch strategy""" + + def __init__(self, config: ServingConfig): + self.waiting_request_queue: Deque[EntryMetaData] = Deque([]) + self.running_request_list: List[EntryMetaData] = [] + self.count_of_invalid_sample = 0 + self.config = config + self.batch_size = config.model_config.decode_batch_size[0] + self.eos_token = config.model_config.end_token + self.batch_waiting_time = config.serving_config.prefill_batch_waiting_time + self.decode_batch_waiting_time = config.serving_config.decode_batch_waiting_time + self.batching_strategy = config.model_config.batching_strategy + self.max_input_len = config.model_config.seq_length[-1] if len(config.model_config.seq_length) > 0 else 4096 + # batch中有效token的最大index, 初始化为-1 + self.max_valid_index = -1 + self.dyn_batch = config.model_config.decode_batch_size + + def get_dyn_batch(self): + return 
self.batch_size + + def get_queue_len(self): + return len(self.waiting_request_queue) + + def add_entrys(self, entry_meta_data: EntryMetaData): + entry_meta_data.get_entry_data().set_status(EntryStatus.WAITING) + self.waiting_request_queue.append(entry_meta_data) + + def _padding_batch_size(self): + while len(self.running_request_list) < self.batch_size: + entry_meta_data = copy.deepcopy(self.running_request_list[-1]) + entry_meta_data.entry_data.set_status(EntryStatus.PADDING_INVAILED) + self.running_request_list.append(entry_meta_data) + + def _over_all_complete_entry(self): + for index, _ in enumerate(self.running_request_list): + self.running_request_list[index].get_entry_data().set_status(EntryStatus.FINISHED_STOPPED) + + def _padding_request_into_batching_list(self, index): + if not self.waiting_request_queue: + time.sleep(self.batch_waiting_time / float(len(self.running_request_list))) + if not self.waiting_request_queue: + entry_meta_data = copy.deepcopy(self.running_request_list[-1]) + + if entry_meta_data.entry_data.get_prompt_len() + entry_meta_data.entry_data.get_max_token_len() >= self.max_input_len: + entry_meta_data.get_entry_data().set_status(EntryStatus.INPUT_OUTOFRANGE) + else: + entry_meta_data.get_entry_data().set_status(EntryStatus.PADDING_INVAILED) + + entry_meta_data.get_entry_data().set_decode_index(index) + self.running_request_list.append(entry_meta_data) + logging.debug(f'waiting and add invalid request in batch init, batch size index is {index}') + else: + data = self.waiting_request_queue.popleft() + if data.entry_data.get_prompt_len() + data.entry_data.get_max_token_len() >= self.max_input_len: + data.get_entry_data().set_status(EntryStatus.INPUT_OUTOFRANGE) + else: + data.get_entry_data().set_status(EntryStatus.RUNNING) + + data.get_entry_data().set_decode_index(index) + self.running_request_list.append(data) + # logging.debug(f'add new valid request in batch, batch size index is {index}') + else: + data = self.waiting_request_queue.popleft() + # logging.debug('get_nowait2') + + if data.entry_data.get_prompt_len() + data.entry_data.get_max_token_len() >= self.max_input_len: + data.get_entry_data().set_status(EntryStatus.INPUT_OUTOFRANGE) + else: + data.get_entry_data().set_status(EntryStatus.RUNNING) + + data.get_entry_data().set_decode_index(index) + self.running_request_list.append(data) + # logging.debug(f'add new valid request in batch, batch size index is {index}') + + def _get_next_batch(self): + self.running_request_list.clear() + count = 0 + # no request in schedule queue, return + if not self.waiting_request_queue: + return + # add request into batching list + while self.waiting_request_queue: + if count >= self.batch_size: + break + data = self.waiting_request_queue.popleft() + + if data.entry_data.get_prompt_len() + data.entry_data.get_max_token_len() >= self.max_input_len: + data.get_entry_data().set_status(EntryStatus.INPUT_OUTOFRANGE) + else: + data.get_entry_data().set_status(EntryStatus.RUNNING) + + data.get_entry_data().set_decode_index(count) + self.running_request_list.append(data) + # logging.debug(f'add new valid request in batch, batch size index is {count}') + count += 1 + # if batching list not full, add invalid padding request into batching list + if len(self.running_request_list) < self.batch_size + 1: + for index in range(len(self.running_request_list), self.batch_size): + self._padding_request_into_batching_list(index) + + def _all_samples_in_batch_is_over(self) -> bool: + res = True + for _, data in 
enumerate(self.running_request_list): + if data.get_entry_data().get_status() == EntryStatus.RUNNING: + res = False + return res + + def checkout_entry(self) -> List[bool]: + """ + request in FINISHED_LENGTH_CAPPED, FINISHED_STOPPED, PADDING_INVAILED status can be cut out + """ + checkout_list = [] + for index, data in enumerate(self.running_request_list): + check_ = False + # max_length, cut out finished request in batch + if data.get_entry_data().get_status() == EntryStatus.FINISHED_LENGTH_CAPPED: + check_ = True + # eos, cut out finished request in batch + elif data.get_entry_data().get_status() == EntryStatus.FINISHED_STOPPED: + check_ = True + elif data.get_entry_data().get_status() == EntryStatus.PADDING_INVAILED: + check_ = True + checkout_list.append(check_) + return checkout_list + + def _padding_new_prompt_to_batch(self, index): + # queue is empty, no new request in schedule queue + if not self.waiting_request_queue: + # waiting + time.sleep(self.batch_waiting_time / float(len(self.running_request_list))) + # no new request, continue finished valid decode + if not self.waiting_request_queue: + # logging.debug('waiting and no new request, continue finished valid decode') + return + # new requestes in queue + else: + # logging.debug('add a new request into batching list') + data = self.waiting_request_queue.popleft() + # logging.debug('get_nowait3') + if data.entry_data.get_prompt_len() + data.entry_data.get_max_token_len() >= self.max_input_len: + data.get_entry_data().set_status(EntryStatus.INPUT_OUTOFRANGE) + + else: + data.get_entry_data().set_status(EntryStatus.RUNNING) + data.get_entry_data().set_decode_index(index) + self.running_request_list[index] = data + # logging.debug(f'add new valid request in batch, batch size index is {index}') + else: + # logging.debug('add a new request into batching list') + data = self.waiting_request_queue.popleft() + # logging.debug('get_nowait4') + if data.entry_data.get_prompt_len() + data.entry_data.get_max_token_len() >= self.max_input_len: + data.get_entry_data().set_status(EntryStatus.INPUT_OUTOFRANGE) + else: + data.get_entry_data().set_status(EntryStatus.RUNNING) + data.get_entry_data().set_decode_index(index) + self.running_request_list[index] = data + # logging.debug(f'add new valid request in batch, batch size index is {index}') + + def _update_status_after_one_itreation(self): + self.count_of_invalid_sample = 0 + """checkout and update number of invalid request in batching list""" + self.max_valid_index = -1 + for index, data in enumerate(self.running_request_list): + data_status = data.get_entry_data().get_status() + if data_status == EntryStatus.FINISHED_STOPPED or data_status == EntryStatus.FINISHED_LENGTH_CAPPED: + self.count_of_invalid_sample += 1 + elif data_status == EntryStatus.RUNNING: + self.max_valid_index = index + + def _determine_batch_size(self): + self._update_status_after_one_itreation() + bf_batch = self.batch_size + queue_len = len(self.waiting_request_queue) + bs_list_len = len(self.dyn_batch) + # 1. 请求队列长度大于当前batch_size,扩容 + if self.max_valid_index == -1 or queue_len > self.batch_size: + # 获取最接近waiting list长度的batch档位 + dyn_index = queue_len + # 2. 
请求队列长度小于count_of_invalid_sample,根据max_valid_index动态到邻近档位 + elif queue_len < self.count_of_invalid_sample: + # max_valid_index左侧有多少结束的token + left_free_num = self.count_of_invalid_sample - (self.batch_size - self.max_valid_index - 1) + if queue_len <= left_free_num: + dyn_index = self.max_valid_index + 1 + else: + # 请求队列中全部补齐会到哪个index + dyn_index = queue_len - left_free_num + self.max_valid_index + 1 + else: + dyn_index = self.max_valid_index + 1 + queue_len - self.count_of_invalid_sample + bs_after_changing = self.batch_size + if dyn_index <= 0: + # 默认值 + bs_after_changing = self.dyn_batch[0] + else: + for i in range(1, bs_list_len): + if dyn_index > self.dyn_batch[bs_list_len - i - 1]: + bs_after_changing = self.dyn_batch[bs_list_len - i] + break + self.batch_size = bs_after_changing if bs_after_changing > 0 else self.dyn_batch[0] + af_batch = self.batch_size + if af_batch != bf_batch: + logging.debug(('----bs changed from {} '.format(bf_batch))) + logging.debug(('----bs changed to {} '.format(af_batch))) + if bf_batch >= af_batch: + self.running_request_list = self.running_request_list[:af_batch] + else: + bf_batch = 0 if self.max_valid_index == -1 else bf_batch + block_size = 0 + if self.config.model_config.page_attention: + block_size = self.config.pa_config.block_size + for i in range(bf_batch, af_batch): + entry_meta_data = EntryMetaData(page_attention=self.config.model_config.page_attention, + request_id=PADDING_REQUEST_ID, + is_prompt=True, + entry_data=EntryData(prompt_tokens=[self.eos_token], + max_token_len=5000), + entry_id=-1, + prompt=PADDING_PROMPT, + block_size=block_size) + entry_meta_data.get_entry_data().set_decode_index(i) + entry_meta_data.get_entry_data().set_status(EntryStatus.PADDING_INVAILED) + self.running_request_list.append(entry_meta_data) + + def _continuous_batch(self): + # init batch size when running_request_list is empty. + if len(self.running_request_list) == 0: + self._get_next_batch() + # update invalid request number in batching list. 
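`_determine_batch_size` moves the running batch between the configured `decode_batch_size` tiers based on the waiting-queue length and how many slots are already finished; with this submission's single tier of 128 the resizing is effectively a no-op. The tier lookup itself reduces to choosing the smallest configured size that covers the wanted slot count, roughly as in the hedged sketch below (tier values are illustrative, not the submission's):

```
def pick_batch_tier(wanted_slots, tiers=(16, 32, 64, 128)):
    # smallest configured tier that covers the wanted number of active slots
    for tier in tiers:
        if wanted_slots <= tier:
            return tier
    return tiers[-1]        # cap at the largest configured tier

for wanted in (0, 10, 33, 200):
    print(wanted, "->", pick_batch_tier(wanted))
```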
+ self._update_status_after_one_itreation() + if self.count_of_invalid_sample == self.batch_size: + self._get_next_batch() + # update status after one inference step + else: + checkout_list = self.checkout_entry() + for index, data in enumerate(checkout_list): + if data and index < self.batch_size: + # logging.debug('----{}-th prefill request in batching padded to batch.'.format(index)) + self._padding_new_prompt_to_batch(index) + + def _insert_new_prompt_to_batch_pa(self, index): + # logging.debug('add a new request into batching list') + # data = self.waiting_request_queue.get_nowait() + data = self.waiting_request_queue.popleft() + if data.entry_data.get_prompt_len() + data.entry_data.get_max_token_len() >= self.max_input_len: + data.get_entry_data().set_status(EntryStatus.INPUT_OUTOFRANGE) + else: + data.get_entry_data().set_status(EntryStatus.RUNNING) + data.get_entry_data().set_decode_index(index) + self.running_request_list[index] = data + # logging.debug(f'add new valid request in batch, batch size index is {index}') + + def try_substitute_entry(self): + checkout_list = self.checkout_entry() + is_invalid_index_list = [] + for index, is_invalid in enumerate(checkout_list): + if is_invalid: + is_invalid_index_list.append(index) + if not is_invalid_index_list: + # logging.debug("no invalid entry to substitute") + return False + # 如果有空槽位,尝试替代一条新请求: + index_to_substitute = is_invalid_index_list[0] + # logging.debug("trying to substitute old entry at index: %s", index_to_substitute) + # 如果新entry需要的block数量,小于can_substitute entry的block数量 + mem pool全局剩余block数量 + new_entry = self.waiting_request_queue[0] + if new_entry.cache_engine.try_use_budget(new_entry.get_entry_data().get_len()): + self._insert_new_prompt_to_batch_pa(index_to_substitute) + return True + # logging.debug("failed inserting to existing entry") + # 如果空间不足,那么连第一条waiting的请求就无法替换,直接退出 + return False + + def reset_all_budgets(self): + # logging.debug("current running list") + for entry in self.running_request_list: + entry.cache_engine.release_budget() + + def can_predict_current_batch(self): + checkout_list = self.checkout_entry() + for index, is_invalid in enumerate(checkout_list): + # 对于batch中running的请求 + if is_invalid: + continue + entry = self.running_request_list[index] + entry_cache_engine = entry.cache_engine + if entry.is_prompt: + if not entry_cache_engine.try_use_budget(entry.get_entry_data().get_len()): + return False + else: + if not entry_cache_engine.try_use_budget(1): + return False + # logging.debug("can decode current batch return true") + return True + + def try_initialize_paddings_pa(self): + # running list和batch size不匹配时,添加padding位补充 + # 场景:1.启动server后,runninglist为空;2.升档后 + # logging.debug("try initialize paddings...") + if len(self.running_request_list) == self.batch_size: + return + elif len(self.running_request_list) > self.batch_size: + raise RuntimeError("running list size: %s larger than batch size: %s!", len(self.running_request_list), + self.batch_size) + block_size = 0 + if self.config.model_config.page_attention: + block_size = self.config.pa_config.block_size + for index in range(len(self.running_request_list), self.batch_size): + padding_entry = EntryMetaData(page_attention=self.config.model_config.page_attention, + request_id=PADDING_REQUEST_ID, + is_prompt=False, # True + entry_data=EntryData(prompt_tokens=[self.config.model_config.end_token], + max_token_len=5000), + entry_id=-1, + prompt=PADDING_PROMPT, + block_size=block_size) + padding_entry.get_entry_data().set_decode_index(index) + 
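The budget calls used here (`try_use_budget`, `release_budget`, `release_cache`) are block arithmetic over the shared `ServingBlockMemPool`: a prompt needs enough 16-token blocks for its whole sequence, a decode step needs at most one new block per entry, and each logical token position maps to a physical KV-cache slot through its block table. The toy pool below sketches that accounting using the `pa_config` values from the yaml (1024 blocks of 16 tokens); the class itself and its method names are illustrative stand-ins, not the real cache engine.

```
import math

class ToyBlockPool:
    """Toy stand-in for ServingBlockMemPool-style budgeting (pa_config values)."""
    def __init__(self, num_blocks=1024, block_size=16):
        self.block_size = block_size
        self.free_blocks = num_blocks

    def blocks_needed(self, cached_tokens, new_tokens):
        have = math.ceil(cached_tokens / self.block_size)
        need = math.ceil((cached_tokens + new_tokens) / self.block_size)
        return need - have

    def try_use_budget(self, cached_tokens, new_tokens):
        needed = self.blocks_needed(cached_tokens, new_tokens)
        if needed > self.free_blocks:
            return False            # caller must swap an entry out or wait
        self.free_blocks -= needed
        return True

def slot_for(block_table, token_index, block_size=16):
    # physical KV-cache slot of a logical token position, the same arithmetic as
    # the slot_mapping assembled in _assemble_pa_full_inputs / _assemble_pa_inc_inputs
    return block_table[token_index // block_size] * block_size + token_index % block_size

pool = ToyBlockPool()
print(pool.try_use_budget(0, 1000))    # prefill of a 1000-token prompt: 63 blocks
print(pool.try_use_budget(1000, 1))    # one decode step: 0 or 1 extra block
print(pool.free_blocks)                # 1024 - 63 = 961
print(slot_for([7, 3], 20))            # token 20 sits in block 3 -> slot 3*16 + 4 = 52
```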
padding_entry.get_entry_data().set_status(EntryStatus.PADDING_INVAILED) + cache_engine = padding_entry.cache_engine + cache_engine.assign_null_block() + self.running_request_list.append(padding_entry) + + def insert_padding_entry(self, index): + block_size = 0 + if self.config.model_config.page_attention: + block_size = self.config.pa_config.block_size + padding_entry = EntryMetaData(page_attention=self.config.model_config.page_attention, + request_id=PADDING_REQUEST_ID, + is_prompt=False, # True + entry_data=EntryData(prompt_tokens=[self.config.model_config.end_token], + max_token_len=5000), + entry_id=-1, + prompt=PADDING_PROMPT, + block_size=block_size) + padding_entry.get_entry_data().set_decode_index(index) + padding_entry.get_entry_data().set_status(EntryStatus.PADDING_INVAILED) + cache_engine = padding_entry.cache_engine + cache_engine.assign_null_block() + self.running_request_list[index] = padding_entry + + def try_swap_valid_entries(self): + is_invalid_list = self.checkout_entry() + num_tokens_index_list = [] + for index, is_invalid in enumerate(is_invalid_list): + if is_invalid: + continue + num_tokens_index_list.append((self.running_request_list[index].get_entry_data().get_len(), index)) + if not num_tokens_index_list: + raise RuntimeError("no valid entry to pop!") + + num_tokens_index_list.sort(key=lambda x: x[0]) + _, index_to_swap = num_tokens_index_list[0] + + # 释放一条长度最短的valid entries(认为是最后进来的,TODO:按照时间顺序pop掉最晚进来的entry) + entry_to_swap = self.running_request_list[index_to_swap] + entry_to_swap.get_entry_data().set_status(EntryStatus.WAITING) + entry_to_swap.get_entry_data().set_decode_index(0) + entry_to_swap.is_prompt = True + entry_to_swap.cache_engine.release_cache() + # append回waiting list + # logging.warning("swap entry out, index: %s", index_to_swap) + self.waiting_request_queue.appendleft(entry_to_swap) + # 用padding替代 + # logging.debug("inserting padding to popped entry %s", index_to_swap) + self.insert_padding_entry(index_to_swap) + + def _continuous_batch_pa(self): + ServingBlockMemPool.instance().reset_budget() + # ServingBlockMemPool.instance().log_status() + self.try_initialize_paddings_pa() + # self.log_running_list("schedule start running status") + # 判断batch内的running entry,能否进行本轮推理? + num_entry_swapped_out = 0 + while not self.can_predict_current_batch(): + # 如果不能,swap出去已有请求 + self.reset_all_budgets() + self.try_swap_valid_entries() + num_entry_swapped_out += 1 + if num_entry_swapped_out: + self.reset_all_budgets() + return + # 3. 
处理新请求
+        # logging.debug("determine if can process new request...")
+        while self.waiting_request_queue:
+            # 如果有空batch槽,尝试插入
+            # logging.debug("has new entry, trying to enter current batch")
+            if not self.try_substitute_entry():
+                # 尝试失败,退出
+                break
+        self.reset_all_budgets()
+        # ServingBlockMemPool.instance().log_status()
+
+    def _static_batch(self):
+        if self._all_samples_in_batch_is_over() or len(self.running_request_list) == 0:
+            self._get_next_batch()
+        # update status after one inference step
+        self._update_status_after_one_itreation()
+        # if all samples in batch is invalid status, a static batch is over
+        if self.count_of_invalid_sample == self.batch_size:
+            self._get_next_batch()
+
+    def schedule(self) -> Tuple[List[EntryMetaData], int]:
+        if self.dyn_batch and len(self.dyn_batch) > 1:
+            self._determine_batch_size()
+        if self.batching_strategy == 'static':
+            self._static_batch()
+        elif not self.config.model_config.page_attention and self.batching_strategy == 'continuous':
+            self._continuous_batch()
+        elif self.config.model_config.page_attention:  # 加入PA
+            self._continuous_batch_pa()
+        else:
+            raise ValueError("Invalid batching strategy! Please set static or continuous")
+        return self.running_request_list, self.batch_size
+
+    # 增加对 PA 的 _finished_request 处理
+    def _finished_pa_request(self, index, token, eos_id):
+        # eos
+        if token == eos_id:
+            # logging.debug("a request finished, token equal to {}".format(token))
+            self.running_request_list[index].get_entry_data().set_status(EntryStatus.FINISHED_STOPPED)
+            self.running_request_list[index].cache_engine.release_cache()
+            self.running_request_list[index].cache_engine.assign_null_block()
+            return
+
+        # max len
+        entry_data = self.running_request_list[index].get_entry_data()
+        if entry_data.max_token_len <= entry_data.get_output_len():
+            # logging.debug("a request reached the max generate token length")
+            self.running_request_list[index].get_entry_data().set_status(EntryStatus.FINISHED_LENGTH_CAPPED)
+            self.running_request_list[index].cache_engine.release_cache()
+            self.running_request_list[index].cache_engine.assign_null_block()
+            return
+
+        if entry_data.get_len() >= self.config.model_config.max_generate_length:
+            # logging.debug("a request reached seq len: %s, index: %s", self.config.max_generate_length, index)
+            self.running_request_list[index].get_entry_data().set_status(EntryStatus.FINISHED_LENGTH_CAPPED)
+            self.running_request_list[index].cache_engine.release_cache()
+            self.running_request_list[index].cache_engine.assign_null_block()
+            return
+
+        # input outofrange
+        if entry_data.status == EntryStatus.INPUT_OUTOFRANGE or entry_data.status == EntryStatus.EMPTY_PROMPT_TOKEN:
+            self.running_request_list[index].get_entry_data().set_status(EntryStatus.FINISHED_STOPPED)
+            self.running_request_list[index].cache_engine.release_cache()
+            self.running_request_list[index].cache_engine.assign_null_block()
+            return
+        # predict failed
+        if token == -1:
+            # logging.debug("a request predict failed, token equal to {}".format(token))
+            self.running_request_list[index].get_entry_data().set_status(EntryStatus.FINISHED_STOPPED)
+            self.running_request_list[index].cache_engine.release_cache()
+            self.running_request_list[index].cache_engine.assign_null_block()
+            return
+
+    def _finished_request(self, index, token, eos_id):
+        # eos
+        if token == eos_id:
+            # logging.debug("a request finished, token equal to {}".format(token))
+            self.running_request_list[index].get_entry_data().set_status(EntryStatus.FINISHED_STOPPED)
+            return
+
+        # max len
+        entry_data = self.running_request_list[index].get_entry_data()
+        if entry_data.max_token_len <= entry_data.get_output_len():
+            # logging.debug("a request reached the max generate token length")
+            self.running_request_list[index].get_entry_data().set_status(EntryStatus.FINISHED_LENGTH_CAPPED)
+            return
+
+        # input outofrange
+        if entry_data.status == EntryStatus.INPUT_OUTOFRANGE or entry_data.status == EntryStatus.EMPTY_PROMPT_TOKEN:
+            self.running_request_list[index].get_entry_data().set_status(EntryStatus.FINISHED_STOPPED)
+            return
+        # predict failed
+        if token == -1:
+            # logging.debug("a request predict failed, token equal to {}".format(token))
+            self.running_request_list[index].get_entry_data().set_status(EntryStatus.FINISHED_STOPPED)
+            return
+
+    def upate_entries_after_one_step(self, outputs: List[int], eos_id: int, index_list: List[int] = None):
+        """update status after every iteration"""
+        # optimize prefill multi-batch later
+        if index_list is not None:
+            # idx: index_list and outputs data index, index: batch list index.
+            for idx, index in enumerate(index_list):
+                self.running_request_list[index].is_prompt = False
+                # invalid prompt
+                if self.running_request_list[index].get_entry_data().get_status() == EntryStatus.PADDING_INVAILED:
+                    continue
+
+                if self.running_request_list[index].get_entry_data().get_status() == EntryStatus.INPUT_OUTOFRANGE:
+                    update_token = INPUT_OUT_OF_TOKEN[0]
+                elif self.running_request_list[index].get_entry_data().get_status() == EntryStatus.EMPTY_PROMPT_TOKEN:
+                    update_token = INPUT_EMPTY_TOKEN[0]
+                else:
+                    update_token = outputs[idx]
+
+                self.running_request_list[index].get_entry_data().updata_output_tokens(update_token)
+                # valid prompt 区分PA处理
+                if self.config.model_config.page_attention:
+                    self._finished_pa_request(index, update_token, eos_id)
+                else:
+                    self._finished_request(index, update_token, eos_id)
+        # decode
+        else:
+            for index, token in enumerate(outputs):
+                if self.running_request_list[index].get_entry_data().get_status() != EntryStatus.RUNNING:  # 改动:跳过非 RUNNING 状态的槽位
+                    continue
+                # update new token to result list
+                self.running_request_list[index].get_entry_data().updata_output_tokens(token)
+
+                # 区分PA处理
+                if self.config.model_config.page_attention:
+                    self._finished_pa_request(index, token, eos_id)
+                else:
+                    self._finished_request(index, token, eos_id)
+
+    def upate_entries_after_one_step_after_prefill(self, eos_id: int, index_list: List[int] = None):
+        """update status after the prefill step"""
+        # optimize prefill multi-batch later
+        if index_list is not None:
+            # idx: index_list and outputs data index, index: batch list index.
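+            # 说明:prefill 之后的这条更新路径只处理无效输入(超长 / 空 prompt),为其写入占位 token;
+            # 正常请求的输出 token 由 upate_entries_after_one_step 更新,此处直接 continue 跳过。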
+ for idx, index in enumerate(index_list): + self.running_request_list[index].is_prompt = False + # invalid prompt + if self.running_request_list[index].get_entry_data().get_status() == EntryStatus.PADDING_INVAILED: + continue + + if self.running_request_list[index].get_entry_data().get_status() == EntryStatus.INPUT_OUTOFRANGE: + update_token = INPUT_OUT_OF_TOKEN[0] + elif self.running_request_list[index].get_entry_data().get_status() == EntryStatus.EMPTY_PROMPT_TOKEN: + update_token = INPUT_EMPTY_TOKEN[0] + else: + continue + + self.running_request_list[index].get_entry_data().updata_output_tokens(update_token) + # valid prompt 区分PA处理 + if self.config.model_config.page_attention: + self._finished_pa_request(index, update_token, eos_id) + else: + self._finished_request(index, update_token, eos_id) + + def abort_entry(self, + request_id: str): + for index, data in enumerate(self.running_request_list): + if data.request_id == request_id: + self.running_request_list[index].get_entry_data().set_status(EntryStatus.FINISHED_STOPPED) diff --git "a/2024-ascend-innovation-contest/topic3-inference/second-phase/\350\265\267\351\243\216\344\272\206/test_serving_performance.py" "b/2024-ascend-innovation-contest/topic3-inference/second-phase/\350\265\267\351\243\216\344\272\206/test_serving_performance.py" new file mode 100644 index 00000000..5fea7576 --- /dev/null +++ "b/2024-ascend-innovation-contest/topic3-inference/second-phase/\350\265\267\351\243\216\344\272\206/test_serving_performance.py" @@ -0,0 +1,228 @@ +import time +import json +import requests +import argparse +import os +import numpy as np +import threading +from mindformers import LlamaTokenizer +import log + +time_now = time.strftime("%Y-%m-%d-%H_%M", time.localtime()) +LOGGER = log.logger_for_test("test_llama", f"./testLog/test_performance_{time_now}.log") +LLAMA2_tokenizer = "./tokenizer.model" # 换模型不需要换tokenizer +RESULT = [] + +CompletedProgress = 0 + + +def init_tokenizer(model_path=LLAMA2_tokenizer): + tokenizer = LlamaTokenizer(model_path) + return tokenizer + + +def get_text_token_num(tokenizer, text): + tokens = tokenizer.tokenize(text) + num_tokens = len(tokens) + # print("token num in text is ", num_tokens) + return num_tokens + + +def poisson_random_s(interval): + poisson_random_ms = np.random.poisson(interval * 1000, 1000)[0] + # LOGGER.info(f"poisson random interval time is {poisson_random_ms / 1000}s") + return poisson_random_ms / 1000 + + +Tokenizer = init_tokenizer() + + +# 延迟tms定时器 +def delayMsecond(t): + t = t * 1000 # 传入s级别 + start, end = 0, 0 + start = time.time_ns() # 精确至ns级别 + while end - start < t * 1000000: + end = time.time_ns() + + +class MyThread(threading.Thread): + def __init__(self, func, args=()): + super(MyThread, self).__init__() + self.func = func + self.args = args + + def run(self): + self.result = self.func(*self.args) + + def get_result(self): + threading.Thread.join(self) + try: + return json.loads(self.result) + except Exception: + return None + + +class LargeModelClient: + def __init__(self, port): + self.url_generate_all = f'http://localhost:{port}/models/llama2/generate' + self.url_generate_stream = f'http://localhost:{port}/models/llama2/generate_stream' + + def send_request(self, testcase, all_counts): + global CompletedProgress + # print("testcase is {}".format(testcase)) + inputs = testcase["input"] + # inputs = "<|User|>:{}\n<|Bot|>:".format(inputs) + body = {"inputs": inputs} + para = {} + return_full_text = testcase["return_full_text"] if "return_full_text" in testcase else False + 
do_sample = testcase["do_sample"] + max_new_tokens = testcase["max_new_tokens"] if "max_new_tokens" in testcase else False + topk_k = testcase["topk_k"] if "topk_k" in testcase else False + top_p = testcase["top_p"] if "top_p" in testcase else False + temperature = testcase["temperature"] if "temperature" in testcase else False + stream = testcase["stream"] + if max_new_tokens: + para["max_new_tokens"] = max_new_tokens + if temperature: + para["temperature"] = temperature + if topk_k: + para["topk_k"] = topk_k + if top_p: + para["top_p"] = top_p + para["do_sample"] = do_sample + para["return_full_text"] = return_full_text + # print(para) + body["parameters"] = para + if stream: + res = self.return_stream(body, stream) + else: + res = self.return_all(body, stream) + CompletedProgress += 1 + LOGGER.info(f"{res}\nTest Progress --> {CompletedProgress}/{all_counts}") + RESULT.append(res) + return res + + def return_all(self, request_body, stream): + url = self.url_generate_stream if stream else self.url_generate_all + headers = {"Content-Type": "application/json", "Connection": "close"} + start_time = time.time() + resp = requests.request("POST", url, data=json.dumps(request_body), headers=headers) + resp_text = resp.text + resp.close() + res_time = time.time() - start_time + # print(resp_text) + return {"input": request_body["inputs"], + "resp_text": json.loads(resp_text)["generated_text"], + "res_time": res_time} + + def return_stream(self, request_body, stream): + url = self.url_generate_stream if stream else self.url_generate_all + headers = {"Content-Type": "application/json", "Connection": "close"} + start_time = time.time() + resp = requests.request("POST", url, data=json.dumps(request_body), headers=headers, stream=True) + lis = [] + first_token_time = None + for i, line in enumerate(resp.iter_lines(decode_unicode=True)): + if line: + if i == 0: + first_token_time = time.time() - start_time + LOGGER.info(f"first_token_time is {first_token_time}") + # print(json.loads(line)) + # if + print(json.loads(line)["data"][0]["generated_text"]) + lis.append(json.loads(line)["data"][0]["generated_text"]) + # data = json.loads(line) + # print(data['id']) + res_time = time.time() - start_time + if request_body["parameters"]["return_full_text"]: + # print(request_body["parameters"]["return_full_text"]) + resp_text = lis[-1] + # print("******stream full text********") + # print(resp_text) + else: + # print("******stream completeness result********") + resp_text = "".join(lis) + # print("".join(lis)) + return {"input": request_body["inputs"], + "resp_text": resp_text, + "res_time": res_time, + "first_token_time": first_token_time} + + +def generate_thread_tasks(testcases, all_count, port): + client = LargeModelClient(port) + print(all_count) + i = 0 + thread_tasks = [] + k = 0 + while True: + print(k, ":", all_count) + if i > len(testcases) - 1: + i = 0 + thread_tasks.append(MyThread(client.send_request, (testcases[i], all_count))) + i += 1 + k += 1 + if k == all_count: + break + LOGGER.info(f"thread_tasks length is {len(thread_tasks)}") + return thread_tasks + + +def test_main(port, inputs, outputs, x, out_dir, test_all_time=3600): + if not os.path.exists(out_dir): + os.makedirs(out_dir) + print('start Test...') + testcases = [] + for i, input_string in enumerate(inputs): + testcase = {"input": f"{input_string}", "do_sample": "False", "return_full_text": "True", "stream": True, + "max_new_tokens": get_text_token_num(Tokenizer, outputs[i])} + testcases.append(testcase) + # print(testcase) + 
LOGGER.info(f"testcases length is {len(testcases)}") + # 每次发送的间隔时间 + interval = round(1 / x, 2) + all_counts = int(test_all_time * x) + # 1h内一共需要发送多少次请求 + thread_tasks = generate_thread_tasks(testcases, all_counts, port) + start_time = time.time() + LOGGER.info(f"Start send request, avg interval is {interval}") + for task in thread_tasks: + task.start() + delayMsecond(poisson_random_s(interval)) + + for task in thread_tasks: + task.join() + + end_time = time.time() + LOGGER.info(f"All Tasks Done; Exec Time is {end_time - start_time}") + gen_json = os.path.join(out_dir, f"result_{x}_x.json") + with open(gen_json, "w+") as f: + f.write(json.dumps(RESULT)) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description="test serving performance") + parser.add_argument("-X", "--qps", help='x req/s', required=True, type=float) + parser.add_argument("-P", "--port", help='port, default is 8000', required=True) + parser.add_argument("-O", "--out_dir", help='dir for saving results', required=True) + parser.add_argument("-T", "--test_time", help='test all time, default 1h', required=False, type=int, default=3600) + args = parser.parse_args() + with open("./alpaca_5010.json") as f: + alpaca_data = json.loads(f.read()) + INPUTS_DATA = [] + OUTPUTS_DATA = [] + count = 0 + input_length = [] + for data in alpaca_data: + count+=1 + if count>1500: + break + input_ = data["instruction"] + ":" + data["input"] if data["input"] else data["instruction"] + INPUTS_DATA.append(input_) + OUTPUTS_DATA.append(data["output"]) + input_length.append(len(input_)) + indexes = np.argsort(input_length) + INPUTS_DATA = [INPUTS_DATA[i] for i in indexes] + OUTPUTS_DATA = [OUTPUTS_DATA[i] for i in indexes] + test_main(args.port, INPUTS_DATA, OUTPUTS_DATA, args.qps, args.out_dir, args.test_time)